commit 88398a1491b54a95ed0874a5791c3357a09d07e3 Author: Don Aldrich Date: Thu Jul 28 11:32:57 2022 -0500 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f0fed7e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.DS_Store +*/.DS_Store diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..c44d7ec --- /dev/null +++ b/Dockerfile @@ -0,0 +1,46 @@ +FROM apache/airflow:2.0.2-python3.8 +USER root +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + git \ + pylint \ + libpq-dev \ + python-dev \ + # gcc \ + && apt-get remove python-cffi \ + && apt-get autoremove -yqq --purge \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +RUN /usr/local/bin/python -m pip install --upgrade pip + +# COPY --from=docker /usr/local/bin/docker /usr/local/bin/docker +RUN curl -sSL https://get.docker.com/ | sh +# COPY requirements.txt ./ +# RUN pip install --no-cache-dir -r requirements.txt + +# ENV PYTHONPATH "${PYTHONPATH}:/your/custom/path" +ENV AIRFLOW_UID=1000 +ENV AIRFLOW_GID=1000 + +RUN echo airflow ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/airflow +RUN chmod 0440 /etc/sudoers.d/airflow +RUN echo 'airflow:airflow' | chpasswd +RUN groupmod --gid ${AIRFLOW_GID} airflow +RUN groupmod --gid 998 docker +# adduser --quiet "airflow" --uid "${AIRFLOW_UID}" \ + # --gid "${AIRFLOW_GID}" \ + # --home "${AIRFLOW_USER_HOME_DIR}" +RUN usermod --gid "${AIRFLOW_GID}" --uid "${AIRFLOW_UID}" airflow +RUN usermod -aG docker airflow +RUN chown -R ${AIRFLOW_UID}:${AIRFLOW_GID} /home/airflow +RUN chown -R ${AIRFLOW_UID}:${AIRFLOW_GID} /opt/airflow + +WORKDIR /opt/airflow + +# RUN usermod -g ${AIRFLOW_GID} airflow -G airflow + +USER ${AIRFLOW_UID} + +COPY requirements.txt ./ +RUN pip install --no-cache-dir -r requirements.txt diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/__main__ .py b/__main__ .py new file mode 100755 index 0000000..e69de29 diff --git a/config/airflow.cfg b/config/airflow.cfg new file mode 100644 index 0000000..a177958 --- /dev/null +++ b/config/airflow.cfg @@ -0,0 +1,1032 @@ +[core] +# The folder where your airflow pipelines live, most likely a +# subfolder in a code repository. This path must be absolute. +dags_folder = /opt/airflow/dags + +# Hostname by providing a path to a callable, which will resolve the hostname. +# The format is "package.function". +# +# For example, default value "socket.getfqdn" means that result from getfqdn() of "socket" +# package will be used as hostname. +# +# No argument should be required in the function specified. +# If using IP address as hostname is preferred, use value ``airflow.utils.net.get_host_ip_address`` +hostname_callable = socket.getfqdn + +# Default timezone in case supplied date times are naive +# can be utc (default), system, or any IANA timezone string (e.g. Europe/Amsterdam) +default_timezone = America/Chicago + +# The executor class that airflow should use. Choices include +# ``SequentialExecutor``, ``LocalExecutor``, ``CeleryExecutor``, ``DaskExecutor``, +# ``KubernetesExecutor``, ``CeleryKubernetesExecutor`` or the +# full import path to the class when using a custom executor. +executor = SequentialExecutor + +# The SqlAlchemy connection string to the metadata database. 
+# SqlAlchemy supports many different database engine, more information +# their website +sql_alchemy_conn = sqlite:////opt/airflow/airflow.db + +# The encoding for the databases +sql_engine_encoding = utf-8 + +# Collation for ``dag_id``, ``task_id``, ``key`` columns in case they have different encoding. +# This is particularly useful in case of mysql with utf8mb4 encoding because +# primary keys for XCom table has too big size and ``sql_engine_collation_for_ids`` should +# be set to ``utf8mb3_general_ci``. +# sql_engine_collation_for_ids = + +# If SqlAlchemy should pool database connections. +sql_alchemy_pool_enabled = True + +# The SqlAlchemy pool size is the maximum number of database connections +# in the pool. 0 indicates no limit. +sql_alchemy_pool_size = 5 + +# The maximum overflow size of the pool. +# When the number of checked-out connections reaches the size set in pool_size, +# additional connections will be returned up to this limit. +# When those additional connections are returned to the pool, they are disconnected and discarded. +# It follows then that the total number of simultaneous connections the pool will allow +# is pool_size + max_overflow, +# and the total number of "sleeping" connections the pool will allow is pool_size. +# max_overflow can be set to ``-1`` to indicate no overflow limit; +# no limit will be placed on the total number of concurrent connections. Defaults to ``10``. +sql_alchemy_max_overflow = 10 + +# The SqlAlchemy pool recycle is the number of seconds a connection +# can be idle in the pool before it is invalidated. This config does +# not apply to sqlite. If the number of DB connections is ever exceeded, +# a lower config value will allow the system to recover faster. +sql_alchemy_pool_recycle = 1800 + +# Check connection at the start of each connection pool checkout. +# Typically, this is a simple statement like "SELECT 1". +# More information here: +# https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic +sql_alchemy_pool_pre_ping = True + +# The schema to use for the metadata database. +# SqlAlchemy supports databases with the concept of multiple schemas. +sql_alchemy_schema = + +# Import path for connect args in SqlAlchemy. Defaults to an empty dict. +# This is useful when you want to configure db engine args that SqlAlchemy won't parse +# in connection string. +# See https://docs.sqlalchemy.org/en/13/core/engines.html#sqlalchemy.create_engine.params.connect_args +# sql_alchemy_connect_args = + +# The amount of parallelism as a setting to the executor. This defines +# the max number of task instances that should run simultaneously +# on this airflow installation +parallelism = 32 + +# The number of task instances allowed to run concurrently by the scheduler +# in one DAG. Can be overridden by ``concurrency`` on DAG level. +dag_concurrency = 16 + +# Are DAGs paused by default at creation +dags_are_paused_at_creation = True + +# The maximum number of active DAG runs per DAG +max_active_runs_per_dag = 16 + +# Whether to load the DAG examples that ship with Airflow. It's good to +# get started, but you probably want to set this to ``False`` in a production +# environment +load_examples = false + +# Whether to load the default connections that ship with Airflow. 
It's good to +# get started, but you probably want to set this to ``False`` in a production +# environment +load_default_connections = false + +# Path to the folder containing Airflow plugins +plugins_folder = /opt/airflow/plugins + +# Should tasks be executed via forking of the parent process ("False", +# the speedier option) or by spawning a new python process ("True" slow, +# but means plugin changes picked up by tasks straight away) +execute_tasks_new_python_interpreter = False + +# Secret key to save connection passwords in the db +fernet_key = 7N-Rq5q9_xQkJXAQxoOaCHakHwDJdnG4KfBcxoNhhfU= + +# Whether to disable pickling dags +donot_pickle = True + +# How long before timing out a python file import +dagbag_import_timeout = 30.0 + +# Should a traceback be shown in the UI for dagbag import errors, +# instead of just the exception message +dagbag_import_error_tracebacks = True + +# If tracebacks are shown, how many entries from the traceback should be shown +dagbag_import_error_traceback_depth = 2 + +# How long before timing out a DagFileProcessor, which processes a dag file +dag_file_processor_timeout = 50 + +# The class to use for running task instances in a subprocess. +# Choices include StandardTaskRunner, CgroupTaskRunner or the full import path to the class +# when using a custom task runner. +task_runner = StandardTaskRunner + +# If set, tasks without a ``run_as_user`` argument will be run with this user +# Can be used to de-elevate a sudo user running Airflow when executing tasks +default_impersonation = + +# What security module to use (for example kerberos) +security = + +# Turn unit test mode on (overwrites many configuration options with test +# values at runtime) +unit_test_mode = False + +# Whether to enable pickling for xcom (note that this is insecure and allows for +# RCE exploits). +enable_xcom_pickling = False + +# When a task is killed forcefully, this is the amount of time in seconds that +# it has to cleanup after it is sent a SIGTERM, before it is SIGKILLED +killed_task_cleanup_time = 60 + +# Whether to override params with dag_run.conf. If you pass some key-value pairs +# through ``airflow dags backfill -c`` or +# ``airflow dags trigger -c``, the key-value pairs will override the existing ones in params. +dag_run_conf_overrides_params = True + +# When discovering DAGs, ignore any files that don't contain the strings ``DAG`` and ``airflow``. +dag_discovery_safe_mode = True + +# The number of retries each task is going to have by default. Can be overridden at dag or task level. +default_task_retries = 0 + +# Updating serialized DAG can not be faster than a minimum interval to reduce database write rate. +min_serialized_dag_update_interval = 30 + +# Fetching serialized DAG can not be faster than a minimum interval to reduce database +# read rate. This config controls when your DAGs are updated in the Webserver +min_serialized_dag_fetch_interval = 10 + +# Whether to persist DAG files code in DB. +# If set to True, Webserver reads file contents from DB instead of +# trying to access files in a DAG folder. +# Example: store_dag_code = False +# store_dag_code = + +# Maximum number of Rendered Task Instance Fields (Template Fields) per task to store +# in the Database. +# All the template_fields for each of Task Instance are stored in the Database. +# Keeping this number small may cause an error when you try to view ``Rendered`` tab in +# TaskInstance view for older tasks. 
+max_num_rendered_ti_fields_per_task = 30 + +# On each dagrun check against defined SLAs +check_slas = True + +# Path to custom XCom class that will be used to store and resolve operators results +# Example: xcom_backend = path.to.CustomXCom +xcom_backend = airflow.models.xcom.BaseXCom + +# By default Airflow plugins are lazily-loaded (only loaded when required). Set it to ``False``, +# if you want to load plugins whenever 'airflow' is invoked via cli or loaded from module. +lazy_load_plugins = True + +# By default Airflow providers are lazily-discovered (discovery and imports happen only when required). +# Set it to False, if you want to discover providers whenever 'airflow' is invoked via cli or +# loaded from module. +lazy_discover_providers = True + +# Number of times the code should be retried in case of DB Operational Errors. +# Not all transactions will be retried as it can cause undesired state. +# Currently it is only used in ``DagFileProcessor.process_file`` to retry ``dagbag.sync_to_db``. +max_db_retries = 3 + +[logging] +# The folder where airflow should store its log files +# This path must be absolute +base_log_folder = /opt/airflow/logs + +# Airflow can store logs remotely in AWS S3, Google Cloud Storage or Elastic Search. +# Set this to True if you want to enable remote logging. +remote_logging = False + +# Users must supply an Airflow connection id that provides access to the storage +# location. +remote_log_conn_id = + +# Path to Google Credential JSON file. If omitted, authorization based on `the Application Default +# Credentials +# `__ will +# be used. +google_key_path = + +# Storage bucket URL for remote logging +# S3 buckets should start with "s3://" +# Cloudwatch log groups should start with "cloudwatch://" +# GCS buckets should start with "gs://" +# WASB buckets should start with "wasb" just to help Airflow select correct handler +# Stackdriver logs should start with "stackdriver://" +remote_base_log_folder = + +# Use server-side encryption for logs stored in S3 +encrypt_s3_logs = False + +# Logging level +logging_level = INFO + +# Logging level for Flask-appbuilder UI +fab_logging_level = WARN + +# Logging class +# Specify the class that will specify the logging configuration +# This class has to be on the python classpath +# Example: logging_config_class = my.path.default_local_settings.LOGGING_CONFIG +logging_config_class = +# logging_config_class = log_config.LOGGING_CONFIG + +# Flag to enable/disable Colored logs in Console +# Colour the logs when the controlling terminal is a TTY. +colored_console_log = True + +# Log format for when Colored logs is enabled +colored_log_format = [%%(blue)s%%(asctime)s%%(reset)s] {%%(blue)s%%(filename)s:%%(reset)s%%(lineno)d} %%(log_color)s%%(levelname)s%%(reset)s - %%(log_color)s%%(message)s%%(reset)s +colored_formatter_class = airflow.utils.log.colored_log.CustomTTYColoredFormatter + +# Format of Log line +log_format = [%%(asctime)s] {%%(filename)s:%%(lineno)d} %%(levelname)s - %%(message)s +simple_log_format = %%(asctime)s %%(levelname)s - %%(message)s + +# Specify prefix pattern like mentioned below with stream handler TaskHandlerWithCustomFormatter +# Example: task_log_prefix_template = {ti.dag_id}-{ti.task_id}-{execution_date}-{try_number} +task_log_prefix_template = + +# Formatting for how airflow generates file names/paths for each task run. 
+log_filename_template = {{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log + +# Formatting for how airflow generates file names for log +log_processor_filename_template = {{ filename }}.log + +# full path of dag_processor_manager logfile +dag_processor_manager_log_location = /opt/airflow/logs/dag_processor_manager/dag_processor_manager.log + +# Name of handler to read task instance logs. +# Defaults to use ``task`` handler. +task_log_reader = task + +# A comma\-separated list of third-party logger names that will be configured to print messages to +# consoles\. +# Example: extra_loggers = connexion,sqlalchemy +extra_loggers = + +[metrics] + +# StatsD (https://github.com/etsy/statsd) integration settings. +# Enables sending metrics to StatsD. +statsd_on = False +statsd_host = localhost +statsd_port = 8125 +statsd_prefix = airflow + +# If you want to avoid sending all the available metrics to StatsD, +# you can configure an allow list of prefixes (comma separated) to send only the metrics that +# start with the elements of the list (e.g: "scheduler,executor,dagrun") +statsd_allow_list = + +# A function that validate the statsd stat name, apply changes to the stat name if necessary and return +# the transformed stat name. +# +# The function should have the following signature: +# def func_name(stat_name: str) -> str: +stat_name_handler = + +# To enable datadog integration to send airflow metrics. +statsd_datadog_enabled = False + +# List of datadog tags attached to all metrics(e.g: key1:value1,key2:value2) +statsd_datadog_tags = + +# If you want to utilise your own custom Statsd client set the relevant +# module path below. +# Note: The module path must exist on your PYTHONPATH for Airflow to pick it up +# statsd_custom_client_path = + +[secrets] +# Full class name of secrets backend to enable (will precede env vars and metastore in search path) +# Example: backend = airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend + +# backend = +backend = airflow.providers.hashicorp.secrets.vault.VaultBackend + +# The backend_kwargs param is loaded into a dictionary and passed to __init__ of secrets backend class. +# See documentation for the secrets backend you are using. JSON is expected. +# Example for AWS Systems Manager ParameterStore: +# ``{"connections_prefix": "/airflow/connections", "profile_name": "default"}`` + +# backend_kwargs = +# backend_kwargs = {"connections_path": "connections", "variables_path": "variables", "mount_point": "airflow", "url": "http://192.168.1.101:8200", "token": "token} +# https://github.com/astronomer/webinar-secrets-management/blob/4d87161c871840be70f8e62c49c51324ddc53abb/Dockerfile +backend_kwargs = {"connections_path": "airflow/connections", + "variables_path": null, + "config_path": null, + "url": "http://192.168.1.101:8200", + "token": os.env(vault_token) + } +# s.v3UKEnsjWDKtoDaurV7y9BHQ + +[cli] +# In what way should the cli access the API. The LocalClient will use the +# database directly, while the json_client will use the api running on the +# webserver +api_client = airflow.api.client.local_client + +# If you set web_server_url_prefix, do NOT forget to append it here, ex: +# ``endpoint_url = http://localhost:8080/myroot`` +# So api will look like: ``http://localhost:8080/myroot/api/experimental/...`` +endpoint_url = http://localhost:8080 + +[debug] +# Used only with ``DebugExecutor``. If set to ``True`` DAG will fail with first +# failed task. Helpful for debugging purposes. 
+fail_fast = False + +[api] +# Enables the deprecated experimental API. Please note that these APIs do not have access control. +# The authenticated user has full access. +# +# .. warning:: +# +# This `Experimental REST API `__ is +# deprecated since version 2.0. Please consider using +# `the Stable REST API `__. +# For more information on migration, see +# `UPDATING.md `_ +enable_experimental_api = False + +# How to authenticate users of the API. See +# https://airflow.apache.org/docs/stable/security.html for possible values. +# ("airflow.api.auth.backend.default" allows all requests for historic reasons) +auth_backend = airflow.api.auth.backend.deny_all + +# Used to set the maximum page limit for API requests +maximum_page_limit = 100 + +# Used to set the default page limit when limit is zero. A default limit +# of 100 is set on OpenApi spec. However, this particular default limit +# only work when limit is set equal to zero(0) from API requests. +# If no limit is supplied, the OpenApi spec default is used. +fallback_page_limit = 100 + +# The intended audience for JWT token credentials used for authorization. This value must match on the client and server sides. If empty, audience will not be tested. +# Example: google_oauth2_audience = project-id-random-value.apps.googleusercontent.com +google_oauth2_audience = + +# Path to Google Cloud Service Account key file (JSON). If omitted, authorization based on +# `the Application Default Credentials +# `__ will +# be used. +# Example: google_key_path = /files/service-account-json +google_key_path = + +[lineage] +# what lineage backend to use +backend = + +[atlas] +sasl_enabled = False +host = +port = 21000 +username = +password = + +[operators] +# The default owner assigned to each new operator, unless +# provided explicitly or passed via ``default_args`` +default_owner = airflow +default_cpus = 1 +default_ram = 512 +default_disk = 512 +default_gpus = 0 + +# Is allowed to pass additional/unused arguments (args, kwargs) to the BaseOperator operator. +# If set to False, an exception will be thrown, otherwise only the console message will be displayed. +allow_illegal_arguments = False + +[hive] +# Default mapreduce queue for HiveOperator tasks +default_hive_mapred_queue = + +# Template for mapred_job_name in HiveOperator, supports the following named parameters +# hostname, dag_id, task_id, execution_date +# mapred_job_name_template = + +[webserver] +# The base url of your website as airflow cannot guess what domain or +# cname you are using. This is used in automated emails that +# airflow sends to point links to the right web server +# base_url = https://airflow.donavanaldrich.com +base_url = http://localhost:8080 + +# Default timezone to display all dates in the UI, can be UTC, system, or +# any IANA timezone string (e.g. Europe/Amsterdam). If left empty the +# default value of core/default_timezone will be used +# Example: default_ui_timezone = America/New_York +default_ui_timezone = America/Chicago + +# The ip specified when starting the web server +web_server_host = 0.0.0.0 + +# The port on which to run the web server +web_server_port = 8080 + +# Paths to the SSL certificate and key for the web server. When both are +# provided SSL will be enabled. This does not change the web server port. +web_server_ssl_cert = + +# Paths to the SSL certificate and key for the web server. When both are +# provided SSL will be enabled. This does not change the web server port. 
+web_server_ssl_key = + +# Number of seconds the webserver waits before killing gunicorn master that doesn't respond +web_server_master_timeout = 120 + +# Number of seconds the gunicorn webserver waits before timing out on a worker +web_server_worker_timeout = 120 + +# Number of workers to refresh at a time. When set to 0, worker refresh is +# disabled. When nonzero, airflow periodically refreshes webserver workers by +# bringing up new ones and killing old ones. +worker_refresh_batch_size = 1 + +# Number of seconds to wait before refreshing a batch of workers. +worker_refresh_interval = 30 + +# If set to True, Airflow will track files in plugins_folder directory. When it detects changes, +# then reload the gunicorn. +reload_on_plugin_change = True + +# Secret key used to run your flask app +# It should be as random as possible +secret_key = /9xL+wMDDZAp0Arm0wxMeg== + +# Number of workers to run the Gunicorn web server +workers = 4 + +# The worker class gunicorn should use. Choices include +# sync (default), eventlet, gevent +worker_class = sync + +# Log files for the gunicorn webserver. '-' means log to stderr. +access_logfile = - + +# Log files for the gunicorn webserver. '-' means log to stderr. +error_logfile = - + +# Access log format for gunicorn webserver. +# default format is %%(h)s %%(l)s %%(u)s %%(t)s "%%(r)s" %%(s)s %%(b)s "%%(f)s" "%%(a)s" +# documentation - https://docs.gunicorn.org/en/stable/settings.html#access-log-format +access_logformat = + +# Expose the configuration file in the web server +expose_config = True + +# Expose hostname in the web server +expose_hostname = True + +# Expose stacktrace in the web server +expose_stacktrace = True + +# Default DAG view. Valid values are: ``tree``, ``graph``, ``duration``, ``gantt``, ``landing_times`` +dag_default_view = tree + +# Default DAG orientation. Valid values are: +# ``LR`` (Left->Right), ``TB`` (Top->Bottom), ``RL`` (Right->Left), ``BT`` (Bottom->Top) +dag_orientation = TB + +# Puts the webserver in demonstration mode; blurs the names of Operators for +# privacy. +demo_mode = False + +# The amount of time (in secs) webserver will wait for initial handshake +# while fetching logs from other worker machine +log_fetch_timeout_sec = 5 + +# Time interval (in secs) to wait before next log fetching. +log_fetch_delay_sec = 2 + +# Distance away from page bottom to enable auto tailing. +log_auto_tailing_offset = 30 + +# Animation speed for auto tailing log display. +log_animation_speed = 1000 + +# By default, the webserver shows paused DAGs. Flip this to hide paused +# DAGs by default +hide_paused_dags_by_default = False + +# Consistent page size across all listing views in the UI +page_size = 100 + +# Define the color of navigation bar +navbar_color = #00008b + +# Default dagrun to show in UI +default_dag_run_display_number = 25 + +# Enable werkzeug ``ProxyFix`` middleware for reverse proxy +enable_proxy_fix = True + +# Number of values to trust for ``X-Forwarded-For``. 
+# More info: https://werkzeug.palletsprojects.com/en/0.16.x/middleware/proxy_fix/ +proxy_fix_x_for = 1 + +# Number of values to trust for ``X-Forwarded-Proto`` +proxy_fix_x_proto = 1 + +# Number of values to trust for ``X-Forwarded-Host`` +proxy_fix_x_host = 1 + +# Number of values to trust for ``X-Forwarded-Port`` +proxy_fix_x_port = 1 + +# Number of values to trust for ``X-Forwarded-Prefix`` +proxy_fix_x_prefix = 1 + +# Set secure flag on session cookie +cookie_secure = False + +# Set samesite policy on session cookie +cookie_samesite = Lax + +# Default setting for wrap toggle on DAG code and TI log views. +default_wrap = False + +# Allow the UI to be rendered in a frame +x_frame_enabled = True + +# Send anonymous user activity to your analytics tool +# choose from google_analytics, segment, or metarouter +# analytics_tool = + +# Unique ID of your account in the analytics tool +# analytics_id = + +# 'Recent Tasks' stats will show for old DagRuns if set +show_recent_stats_for_completed_runs = True + +# Update FAB permissions and sync security manager roles +# on webserver startup +update_fab_perms = True + +# The UI cookie lifetime in minutes. User will be logged out from UI after +# ``session_lifetime_minutes`` of non-activity +session_lifetime_minutes = 43200 + +# Use FAB-based webserver with RBAC feature +rbac = True +authenticate = True +auth_backend = airflow.contrib.auth.backends.password_auth + +[email] + +# Configuration email backend and whether to +# send email alerts on retry or failure +# Email backend to use +email_backend = airflow.utils.email.send_email_smtp + +# Whether email alerts should be sent when a task is retried +default_email_on_retry = False + +# Whether email alerts should be sent when a task failed +default_email_on_failure = False + +# File that will be used as the template for Email subject (which will be rendered using Jinja2). +# If not set, Airflow uses a base template. +# Example: subject_template = /path/to/my_subject_template_file +# subject_template = + +# File that will be used as the template for Email content (which will be rendered using Jinja2). +# If not set, Airflow uses a base template. +# Example: html_content_template = /path/to/my_html_content_template_file +# html_content_template = + +[smtp] + +# If you want airflow to send emails on retries, failure, and you want to use +# the airflow.utils.email.send_email_smtp function, you have to configure an +# smtp server here +smtp_host = localhost +smtp_starttls = True +smtp_ssl = False +# Example: smtp_user = airflow +# smtp_user = +# Example: smtp_password = airflow +# smtp_password = +smtp_port = 25 +smtp_mail_from = airflow@example.com +smtp_timeout = 30 +smtp_retry_limit = 5 + +[sentry] + +# Sentry (https://docs.sentry.io) integration. Here you can supply +# additional configuration options based on the Python platform. See: +# https://docs.sentry.io/error-reporting/configuration/?platform=python. +# Unsupported options: ``integrations``, ``in_app_include``, ``in_app_exclude``, +# ``ignore_errors``, ``before_breadcrumb``, ``before_send``, ``transport``. +# Enable error reporting to Sentry +sentry_on = false +sentry_dsn = + +[celery_kubernetes_executor] + +# This section only applies if you are using the ``CeleryKubernetesExecutor`` in +# ``[core]`` section above +# Define when to send a task to ``KubernetesExecutor`` when using ``CeleryKubernetesExecutor``. 
+# When the queue of a task is the value of ``kubernetes_queue`` (default ``kubernetes``), +# the task is executed via ``KubernetesExecutor``, +# otherwise via ``CeleryExecutor`` +kubernetes_queue = kubernetes + +[celery] + +# This section only applies if you are using the CeleryExecutor in +# ``[core]`` section above +# The app name that will be used by celery +celery_app_name = airflow.executors.celery_executor + +# The concurrency that will be used when starting workers with the +# ``airflow celery worker`` command. This defines the number of task instances that +# a worker will take, so size up your workers based on the resources on +# your worker box and the nature of your tasks +worker_concurrency = 16 + +# The maximum and minimum concurrency that will be used when starting workers with the +# ``airflow celery worker`` command (always keep minimum processes, but grow +# to maximum if necessary). Note the value should be max_concurrency,min_concurrency +# Pick these numbers based on resources on worker box and the nature of the task. +# If autoscale option is available, worker_concurrency will be ignored. +# http://docs.celeryproject.org/en/latest/reference/celery.bin.worker.html#cmdoption-celery-worker-autoscale +# Example: worker_autoscale = 16,12 +# worker_autoscale = + +# Used to increase the number of tasks that a worker prefetches which can improve performance. +# The number of processes multiplied by worker_prefetch_multiplier is the number of tasks +# that are prefetched by a worker. A value greater than 1 can result in tasks being unnecessarily +# blocked if there are multiple workers and one worker prefetches tasks that sit behind long +# running tasks while another worker has unutilized processes that are unable to process the already +# claimed blocked tasks. +# https://docs.celeryproject.org/en/stable/userguide/optimizing.html#prefetch-limits +# Example: worker_prefetch_multiplier = 1 +# worker_prefetch_multiplier = + +# When you start an airflow worker, airflow starts a tiny web server +# subprocess to serve the workers local log files to the airflow main +# web server, who then builds pages and sends them to users. This defines +# the port on which the logs are served. It needs to be unused, and open +# visible from the main web server to connect into the workers. +worker_log_server_port = 8793 + +# Umask that will be used when starting workers with the ``airflow celery worker`` +# in daemon mode. This control the file-creation mode mask which determines the initial +# value of file permission bits for newly created files. +worker_umask = 0o077 + +# The Celery broker URL. Celery supports RabbitMQ, Redis and experimentally +# a sqlalchemy database. Refer to the Celery documentation for more information. +broker_url = redis://redis:6379/0 + +# The Celery result_backend. When a job finishes, it needs to update the +# metadata of the job. Therefore it will post a message on a message bus, +# or insert it into a database (depending of the backend) +# This status is used by the scheduler to update the state of the task +# The use of a database is highly recommended +# http://docs.celeryproject.org/en/latest/userguide/configuration.html#task-result-backend-settings +result_backend = db+postgresql://postgres:airflow@postgres/airflow + +# Celery Flower is a sweet UI for Celery. Airflow has a shortcut to start +# it ``airflow celery flower``. 
This defines the IP that Celery Flower runs on +flower_host = 0.0.0.0 + +# The root URL for Flower +# Example: flower_url_prefix = /flower +flower_url_prefix = + +# This defines the port that Celery Flower runs on +flower_port = 5555 + +# Securing Flower with Basic Authentication +# Accepts user:password pairs separated by a comma +# Example: flower_basic_auth = user1:password1,user2:password2 +flower_basic_auth = + +# Default queue that tasks get assigned to and that worker listen on. +default_queue = default + +# How many processes CeleryExecutor uses to sync task state. +# 0 means to use max(1, number of cores - 1) processes. +sync_parallelism = 0 + +# Import path for celery configuration options +celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG +ssl_active = False +ssl_key = +ssl_cert = +ssl_cacert = + +# Celery Pool implementation. +# Choices include: ``prefork`` (default), ``eventlet``, ``gevent`` or ``solo``. +# See: +# https://docs.celeryproject.org/en/latest/userguide/workers.html#concurrency +# https://docs.celeryproject.org/en/latest/userguide/concurrency/eventlet.html +pool = prefork + +# The number of seconds to wait before timing out ``send_task_to_executor`` or +# ``fetch_celery_task_state`` operations. +operation_timeout = 1.0 + +# Celery task will report its status as 'started' when the task is executed by a worker. +# This is used in Airflow to keep track of the running tasks and if a Scheduler is restarted +# or run in HA mode, it can adopt the orphan tasks launched by previous SchedulerJob. +task_track_started = True + +# Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear +# stalled tasks. +task_adoption_timeout = 600 + +# The Maximum number of retries for publishing task messages to the broker when failing +# due to ``AirflowTaskTimeout`` error before giving up and marking Task as failed. +task_publish_max_retries = 3 + +# Worker initialisation check to validate Metadata Database connection +worker_precheck = False + +[celery_broker_transport_options] + +# This section is for specifying options which can be passed to the +# underlying celery broker transport. See: +# http://docs.celeryproject.org/en/latest/userguide/configuration.html#std:setting-broker_transport_options +# The visibility timeout defines the number of seconds to wait for the worker +# to acknowledge the task before the message is redelivered to another worker. +# Make sure to increase the visibility timeout to match the time of the longest +# ETA you're planning to use. +# visibility_timeout is only supported for Redis and SQS celery brokers. +# See: +# http://docs.celeryproject.org/en/master/userguide/configuration.html#std:setting-broker_transport_options +# Example: visibility_timeout = 21600 +# visibility_timeout = + +[dask] + +# This section only applies if you are using the DaskExecutor in +# [core] section above +# The IP address and port of the Dask cluster's scheduler. +cluster_address = 127.0.0.1:8786 + +# TLS/ SSL settings to access a secured Dask scheduler. +tls_ca = +tls_cert = +tls_key = + +[scheduler] +# Task instances listen for external kill signal (when you clear tasks +# from the CLI or the UI), this defines the frequency at which they should +# listen (in seconds). 
+job_heartbeat_sec = 5 + +# How often (in seconds) to check and tidy up 'running' TaskInstancess +# that no longer have a matching DagRun +clean_tis_without_dagrun_interval = 15.0 + +# The scheduler constantly tries to trigger new tasks (look at the +# scheduler section in the docs for more information). This defines +# how often the scheduler should run (in seconds). +scheduler_heartbeat_sec = 5 + +# The number of times to try to schedule each DAG file +# -1 indicates unlimited number +num_runs = -1 + +# The number of seconds to wait between consecutive DAG file processing +processor_poll_interval = 1 + +# Number of seconds after which a DAG file is parsed. The DAG file is parsed every +# ``min_file_process_interval`` number of seconds. Updates to DAGs are reflected after +# this interval. Keeping this number low will increase CPU usage. +min_file_process_interval = 30 + +# How often (in seconds) to scan the DAGs directory for new files. Default to 5 minutes. +dag_dir_list_interval = 300 + +# How often should stats be printed to the logs. Setting to 0 will disable printing stats +print_stats_interval = 30 + +# How often (in seconds) should pool usage stats be sent to statsd (if statsd_on is enabled) +pool_metrics_interval = 5.0 + +# If the last scheduler heartbeat happened more than scheduler_health_check_threshold +# ago (in seconds), scheduler is considered unhealthy. +# This is used by the health check in the "/health" endpoint +scheduler_health_check_threshold = 30 + +# How often (in seconds) should the scheduler check for orphaned tasks and SchedulerJobs +orphaned_tasks_check_interval = 300.0 +child_process_log_directory = /opt/airflow/logs/scheduler + +# Local task jobs periodically heartbeat to the DB. If the job has +# not heartbeat in this many seconds, the scheduler will mark the +# associated task instance as failed and will re-schedule the task. +scheduler_zombie_task_threshold = 300 + +# Turn off scheduler catchup by setting this to ``False``. +# Default behavior is unchanged and +# Command Line Backfills still work, but the scheduler +# will not do scheduler catchup if this is ``False``, +# however it can be set on a per DAG basis in the +# DAG definition (catchup) +catchup_by_default = True + +# This changes the batch size of queries in the scheduling main loop. +# If this is too high, SQL query performance may be impacted by one +# or more of the following: +# - reversion to full table scan +# - complexity of query predicate +# - excessive locking +# Additionally, you may hit the maximum allowable query length for your db. +# Set this to 0 for no limit (not advised) +max_tis_per_query = 512 + +# Should the scheduler issue ``SELECT ... FOR UPDATE`` in relevant queries. +# If this is set to False then you should not run more than a single +# scheduler at once +use_row_level_locking = True + +# Max number of DAGs to create DagRuns for per scheduler loop +# +# Default: 10 +# max_dagruns_to_create_per_loop = + +# How many DagRuns should a scheduler examine (and lock) when scheduling +# and queuing tasks. +# +# Default: 20 +# max_dagruns_per_loop_to_schedule = + +# Should the Task supervisor process perform a "mini scheduler" to attempt to schedule more tasks of the +# same DAG. Leaving this on will mean tasks in the same DAG execute quicker, but might starve out other +# dags in some circumstances +# +# Default: True +# schedule_after_task_execution = + +# The scheduler can run multiple processes in parallel to parse dags. +# This defines how many processes will run. 
+parsing_processes = 2 + +# Turn off scheduler use of cron intervals by setting this to False. +# DAGs submitted manually in the web UI or with trigger_dag will still run. +use_job_schedule = True + +# Allow externally triggered DagRuns for Execution Dates in the future +# Only has effect if schedule_interval is set to None in DAG +allow_trigger_in_future = False + +[kerberos] +ccache = /tmp/airflow_krb5_ccache + +# gets augmented with fqdn +principal = airflow +reinit_frequency = 3600 +kinit_path = kinit +keytab = airflow.keytab + +[github_enterprise] +api_rev = v3 + +[admin] +# UI to hide sensitive variable fields when set to True +hide_sensitive_variable_fields = True + +# A comma-separated list of sensitive keywords to look for in variables names. +sensitive_variable_fields = + +[elasticsearch] +# Elasticsearch host +host = + +# Format of the log_id, which is used to query for a given tasks logs +log_id_template = {dag_id}-{task_id}-{execution_date}-{try_number} + +# Used to mark the end of a log stream for a task +end_of_log_mark = end_of_log + +# Qualified URL for an elasticsearch frontend (like Kibana) with a template argument for log_id +# Code will construct log_id using the log_id template from the argument above. +# NOTE: The code will prefix the https:// automatically, don't include that here. +frontend = + +# Write the task logs to the stdout of the worker, rather than the default files +write_stdout = False + +# Instead of the default log formatter, write the log lines as JSON +json_format = False + +# Log fields to also attach to the json output, if enabled +json_fields = asctime, filename, lineno, levelname, message + +[elasticsearch_configs] +use_ssl = False +verify_certs = True + +[kubernetes] +# Path to the YAML pod file. If set, all other kubernetes-related fields are ignored. +pod_template_file = + +# The repository of the Kubernetes Image for the Worker to Run +worker_container_repository = + +# The tag of the Kubernetes Image for the Worker to Run +worker_container_tag = + +# The Kubernetes namespace where airflow workers should be created. Defaults to ``default`` +namespace = default + +# If True, all worker pods will be deleted upon termination +delete_worker_pods = True + +# If False (and delete_worker_pods is True), +# failed worker pods will not be deleted so users can investigate them. +delete_worker_pods_on_failure = False + +# Number of Kubernetes Worker Pod creation calls per scheduler loop. +# Note that the current default of "1" will only launch a single pod +# per-heartbeat. It is HIGHLY recommended that users increase this +# number to match the tolerance of their kubernetes cluster for +# better performance. +worker_pods_creation_batch_size = 1 + +# Allows users to launch pods in multiple namespaces. +# Will require creating a cluster-role for the scheduler +multi_namespace_mode = False + +# Use the service account kubernetes gives to pods to connect to kubernetes cluster. +# It's intended for clients that expect to be running inside a pod running on kubernetes. +# It will raise an exception if called from a process not running in a kubernetes environment. +in_cluster = True + +# When running with in_cluster=False change the default cluster_context or config_file +# options to Kubernetes client. Leave blank these to use default behaviour like ``kubectl`` has. 
+# cluster_context = + +# Path to the kubernetes configfile to be used when ``in_cluster`` is set to False +# config_file = + +# Keyword parameters to pass while calling a kubernetes client core_v1_api methods +# from Kubernetes Executor provided as a single line formatted JSON dictionary string. +# List of supported params are similar for all core_v1_apis, hence a single config +# variable for all apis. See: +# https://raw.githubusercontent.com/kubernetes-client/python/41f11a09995efcd0142e25946adc7591431bfb2f/kubernetes/client/api/core_v1_api.py +kube_client_request_args = + +# Optional keyword arguments to pass to the ``delete_namespaced_pod`` kubernetes client +# ``core_v1_api`` method when using the Kubernetes Executor. +# This should be an object and can contain any of the options listed in the ``v1DeleteOptions`` +# class defined here: +# https://github.com/kubernetes-client/python/blob/41f11a09995efcd0142e25946adc7591431bfb2f/kubernetes/client/models/v1_delete_options.py#L19 +# Example: delete_option_kwargs = {"grace_period_seconds": 10} +delete_option_kwargs = + +# Enables TCP keepalive mechanism. This prevents Kubernetes API requests to hang indefinitely +# when idle connection is time-outed on services like cloud load balancers or firewalls. +enable_tcp_keepalive = False + +# When the `enable_tcp_keepalive` option is enabled, TCP probes a connection that has +# been idle for `tcp_keep_idle` seconds. +tcp_keep_idle = 120 + +# When the `enable_tcp_keepalive` option is enabled, if Kubernetes API does not respond +# to a keepalive probe, TCP retransmits the probe after `tcp_keep_intvl` seconds. +tcp_keep_intvl = 30 + +# When the `enable_tcp_keepalive` option is enabled, if Kubernetes API does not respond +# to a keepalive probe, TCP retransmits the probe `tcp_keep_cnt number` of times before +# a connection is considered to be broken. +tcp_keep_cnt = 6 + +[smart_sensor] +# When `use_smart_sensor` is True, Airflow redirects multiple qualified sensor tasks to +# smart sensor task. +use_smart_sensor = False + +# `shard_code_upper_limit` is the upper limit of `shard_code` value. The `shard_code` is generated +# by `hashcode % shard_code_upper_limit`. +shard_code_upper_limit = 10000 + +# The number of running smart sensor processes for each service. +shards = 5 + +# comma separated sensor classes support in smart_sensor. 
+sensors_enabled = NamedHivePartitionSensor diff --git a/config/airflow_local_settings.py b/config/airflow_local_settings.py new file mode 100644 index 0000000..4454dff --- /dev/null +++ b/config/airflow_local_settings.py @@ -0,0 +1,11 @@ +STATE_COLORS = { + "queued": "darkgray", + "running": "blue", + "success": "cobalt", + "failed": "firebrick", + "up_for_retry": "yellow", + "up_for_reschedule": "turquoise", + "upstream_failed": "orange", + "skipped": "darkorchid", + "scheduled": "tan", +} diff --git a/config/log_config.py b/config/log_config.py new file mode 100644 index 0000000..30e8967 --- /dev/null +++ b/config/log_config.py @@ -0,0 +1,5 @@ +from copy import deepcopy + +from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG + +LOGGING_CONFIG = deepcopy(DEFAULT_LOGGING_CONFIG) diff --git a/config/vault.py b/config/vault.py new file mode 100644 index 0000000..989687a --- /dev/null +++ b/config/vault.py @@ -0,0 +1,23 @@ +from datetime import datetime + +from airflow import DAG +from airflow.hooks.base_hook import BaseHook +from airflow.operators.python_operator import PythonOperator + + +def get_secrets(**kwargs): + conn = BaseHook.get_connection(kwargs["my_conn_id"]) + print( + f"Password: {conn.password}, Login: {conn.login}, URI: {conn.get_uri()}, Host: {conn.host}" + ) + + +with DAG( + "example_secrets_dags", start_date=datetime(2020, 1, 1), schedule_interval=None +) as dag: + + test_task = PythonOperator( + task_id="test-task", + python_callable=get_secrets, + op_kwargs={"my_conn_id": "smtp_default"}, + ) diff --git a/config/webserver_config.py b/config/webserver_config.py new file mode 100644 index 0000000..780bc31 --- /dev/null +++ b/config/webserver_config.py @@ -0,0 +1,130 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default configuration for the Airflow webserver""" +import os + +from flask_appbuilder.security.manager import AUTH_DB + +# from flask_appbuilder.security.manager import AUTH_LDAP +# from flask_appbuilder.security.manager import AUTH_OAUTH +# from flask_appbuilder.security.manager import AUTH_OID +# from flask_appbuilder.security.manager import AUTH_REMOTE_USER + + +basedir = os.path.abspath(os.path.dirname(__file__)) + +# Flask-WTF flag for CSRF +WTF_CSRF_ENABLED = True + +# ---------------------------------------------------- +# AUTHENTICATION CONFIG +# ---------------------------------------------------- +# For details on how to set up each of the following authentication, see +# http://flask-appbuilder.readthedocs.io/en/latest/security.html# authentication-methods +# for details. 
+ +# The authentication type +# AUTH_OID : Is for OpenID +# AUTH_DB : Is for database +# AUTH_LDAP : Is for LDAP +# AUTH_REMOTE_USER : Is for using REMOTE_USER from web server +# AUTH_OAUTH : Is for OAuth +AUTH_TYPE = AUTH_DB + +# Uncomment to setup Full admin role name +# AUTH_ROLE_ADMIN = 'Admin' + +# Uncomment to setup Public role name, no authentication needed +# AUTH_ROLE_PUBLIC = 'Public' + +# Will allow user self registration +# AUTH_USER_REGISTRATION = True + +# The recaptcha it's automatically enabled for user self registration is active and the keys are necessary +# RECAPTCHA_PRIVATE_KEY = PRIVATE_KEY +# RECAPTCHA_PUBLIC_KEY = PUBLIC_KEY + +# Config for Flask-Mail necessary for user self registration +# MAIL_SERVER = 'smtp.gmail.com' +# MAIL_USE_TLS = True +# MAIL_USERNAME = 'yourappemail@gmail.com' +# MAIL_PASSWORD = 'passwordformail' +# MAIL_DEFAULT_SENDER = 'sender@gmail.com' + +# The default user self registration role +# AUTH_USER_REGISTRATION_ROLE = "Public" + +# When using OAuth Auth, uncomment to setup provider(s) info +# Google OAuth example: +# OAUTH_PROVIDERS = [{ +# 'name':'google', +# 'token_key':'access_token', +# 'icon':'fa-google', +# 'remote_app': { +# 'api_base_url':'https://www.googleapis.com/oauth2/v2/', +# 'client_kwargs':{ +# 'scope': 'email profile' +# }, +# 'access_token_url':'https://accounts.google.com/o/oauth2/token', +# 'authorize_url':'https://accounts.google.com/o/oauth2/auth', +# 'request_token_url': None, +# 'client_id': GOOGLE_KEY, +# 'client_secret': GOOGLE_SECRET_KEY, +# } +# }] + +# When using LDAP Auth, setup the ldap server +# AUTH_LDAP_SERVER = "ldap://ldapserver.new" + +# When using OpenID Auth, uncomment to setup OpenID providers. +# example for OpenID authentication +# OPENID_PROVIDERS = [ +# { 'name': 'Yahoo', 'url': 'https://me.yahoo.com' }, +# { 'name': 'AOL', 'url': 'http://openid.aol.com/' }, +# { 'name': 'Flickr', 'url': 'http://www.flickr.com/' }, +# { 'name': 'MyOpenID', 'url': 'https://www.myopenid.com' }] + +# ---------------------------------------------------- +# Theme CONFIG +# ---------------------------------------------------- +# Flask App Builder comes up with a number of predefined themes +# that you can use for Apache Airflow. +# http://flask-appbuilder.readthedocs.io/en/latest/customizing.html#changing-themes +# Please make sure to remove "navbar_color" configuration from airflow.cfg +# in order to fully utilize the theme. (or use that property in conjunction with theme) +# APP_THEME = "bootstrap-theme.css" # default bootstrap +# APP_THEME = "amelia.css" +# APP_THEME = "cerulean.css" +APP_THEME = "cosmo.css" +# APP_THEME = "cyborg.css" +# APP_THEME = "darkly.css" +# APP_THEME = "flatly.css" +# APP_THEME = "journal.css" +# APP_THEME = "lumen.css" +# APP_THEME = "paper.css" +# APP_THEME = "readable.css" +# APP_THEME = "sandstone.css" +# APP_THEME = "simplex.css" +# APP_THEME = "slate.css" +# APP_THEME = "solar.css" +# APP_THEME = "spacelab.css" +# APP_THEME = "superhero.css" +# APP_THEME = "united.css" +# APP_THEME = "yeti.css" + +# FAB_STATIC_FOLDER = diff --git a/dags/complex.py b/dags/complex.py new file mode 100644 index 0000000..d9bd525 --- /dev/null +++ b/dags/complex.py @@ -0,0 +1,264 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that shows the complex DAG structure. +""" + +from airflow import models +from airflow.models.baseoperator import chain +from airflow.operators.bash import BashOperator +from airflow.operators.python import PythonOperator +from airflow.utils.dates import days_ago + +with models.DAG( + dag_id="example_complex", + schedule_interval=None, + start_date=days_ago(1), + tags=["example", "example2", "example3"], +) as dag: + + # Create + create_entry_group = BashOperator( + task_id="create_entry_group", bash_command="echo create_entry_group" + ) + + create_entry_group_result = BashOperator( + task_id="create_entry_group_result", + bash_command="echo create_entry_group_result", + ) + + create_entry_group_result2 = BashOperator( + task_id="create_entry_group_result2", + bash_command="echo create_entry_group_result2", + ) + + create_entry_gcs = BashOperator( + task_id="create_entry_gcs", bash_command="echo create_entry_gcs" + ) + + create_entry_gcs_result = BashOperator( + task_id="create_entry_gcs_result", bash_command="echo create_entry_gcs_result" + ) + + create_entry_gcs_result2 = BashOperator( + task_id="create_entry_gcs_result2", bash_command="echo create_entry_gcs_result2" + ) + + create_tag = BashOperator(task_id="create_tag", bash_command="echo create_tag") + + create_tag_result = BashOperator( + task_id="create_tag_result", bash_command="echo create_tag_result" + ) + + create_tag_result2 = BashOperator( + task_id="create_tag_result2", bash_command="echo create_tag_result2" + ) + + create_tag_template = BashOperator( + task_id="create_tag_template", bash_command="echo create_tag_template" + ) + + create_tag_template_result = BashOperator( + task_id="create_tag_template_result", + bash_command="echo create_tag_template_result", + ) + + create_tag_template_result2 = BashOperator( + task_id="create_tag_template_result2", + bash_command="echo create_tag_template_result2", + ) + + create_tag_template_field = BashOperator( + task_id="create_tag_template_field", + bash_command="echo create_tag_template_field", + ) + + create_tag_template_field_result = BashOperator( + task_id="create_tag_template_field_result", + bash_command="echo create_tag_template_field_result", + ) + + create_tag_template_field_result2 = BashOperator( + task_id="create_tag_template_field_result2", + bash_command="echo create_tag_template_field_result", + ) + + # Delete + delete_entry = BashOperator( + task_id="delete_entry", bash_command="echo delete_entry" + ) + create_entry_gcs >> delete_entry + + delete_entry_group = BashOperator( + task_id="delete_entry_group", bash_command="echo delete_entry_group" + ) + create_entry_group >> delete_entry_group + + delete_tag = BashOperator(task_id="delete_tag", bash_command="echo delete_tag") + create_tag >> delete_tag + + delete_tag_template_field = BashOperator( + task_id="delete_tag_template_field", + bash_command="echo delete_tag_template_field", + ) + + delete_tag_template = 
BashOperator( + task_id="delete_tag_template", bash_command="echo delete_tag_template" + ) + + # Get + get_entry_group = BashOperator( + task_id="get_entry_group", bash_command="echo get_entry_group" + ) + + get_entry_group_result = BashOperator( + task_id="get_entry_group_result", bash_command="echo get_entry_group_result" + ) + + get_entry = BashOperator(task_id="get_entry", bash_command="echo get_entry") + + get_entry_result = BashOperator( + task_id="get_entry_result", bash_command="echo get_entry_result" + ) + + get_tag_template = BashOperator( + task_id="get_tag_template", bash_command="echo get_tag_template" + ) + + get_tag_template_result = BashOperator( + task_id="get_tag_template_result", bash_command="echo get_tag_template_result" + ) + + # List + list_tags = BashOperator(task_id="list_tags", bash_command="echo list_tags") + + list_tags_result = BashOperator( + task_id="list_tags_result", bash_command="echo list_tags_result" + ) + + # Lookup + lookup_entry = BashOperator( + task_id="lookup_entry", bash_command="echo lookup_entry" + ) + + lookup_entry_result = BashOperator( + task_id="lookup_entry_result", bash_command="echo lookup_entry_result" + ) + + # Rename + rename_tag_template_field = BashOperator( + task_id="rename_tag_template_field", + bash_command="echo rename_tag_template_field", + ) + + # Search + search_catalog = PythonOperator( + task_id="search_catalog", python_callable=lambda: print("search_catalog") + ) + + search_catalog_result = BashOperator( + task_id="search_catalog_result", bash_command="echo search_catalog_result" + ) + + # Update + update_entry = BashOperator( + task_id="update_entry", bash_command="echo update_entry" + ) + + update_tag = BashOperator(task_id="update_tag", bash_command="echo update_tag") + + update_tag_template = BashOperator( + task_id="update_tag_template", bash_command="echo update_tag_template" + ) + + update_tag_template_field = BashOperator( + task_id="update_tag_template_field", + bash_command="echo update_tag_template_field", + ) + + # Create + create_tasks = [ + create_entry_group, + create_entry_gcs, + create_tag_template, + create_tag_template_field, + create_tag, + ] + chain(*create_tasks) + + create_entry_group >> delete_entry_group + create_entry_group >> create_entry_group_result + create_entry_group >> create_entry_group_result2 + + create_entry_gcs >> delete_entry + create_entry_gcs >> create_entry_gcs_result + create_entry_gcs >> create_entry_gcs_result2 + + create_tag_template >> delete_tag_template_field + create_tag_template >> create_tag_template_result + create_tag_template >> create_tag_template_result2 + + create_tag_template_field >> delete_tag_template_field + create_tag_template_field >> create_tag_template_field_result + create_tag_template_field >> create_tag_template_field_result2 + + create_tag >> delete_tag + create_tag >> create_tag_result + create_tag >> create_tag_result2 + + # Delete + delete_tasks = [ + delete_tag, + delete_tag_template_field, + delete_tag_template, + delete_entry_group, + delete_entry, + ] + chain(*delete_tasks) + + # Get + create_tag_template >> get_tag_template >> delete_tag_template + get_tag_template >> get_tag_template_result + + create_entry_gcs >> get_entry >> delete_entry + get_entry >> get_entry_result + + create_entry_group >> get_entry_group >> delete_entry_group + get_entry_group >> get_entry_group_result + + # List + create_tag >> list_tags >> delete_tag + list_tags >> list_tags_result + + # Lookup + create_entry_gcs >> lookup_entry >> delete_entry + lookup_entry >> 
lookup_entry_result + + # Rename + create_tag_template_field >> rename_tag_template_field >> delete_tag_template_field + + # Search + chain(create_tasks, search_catalog, delete_tasks) + search_catalog >> search_catalog_result + + # Update + create_entry_gcs >> update_entry >> delete_entry + create_tag >> update_tag >> delete_tag + create_tag_template >> update_tag_template >> delete_tag_template + create_tag_template_field >> update_tag_template_field >> rename_tag_template_field diff --git a/dags/crawler.py b/dags/crawler.py new file mode 100644 index 0000000..7419d68 --- /dev/null +++ b/dags/crawler.py @@ -0,0 +1,93 @@ +""" +Code that goes along with the Airflow located at: +http://airflow.readthedocs.org/en/latest/tutorial.html +""" +from datetime import datetime, timedelta + +from airflow import DAG + +# Operators; we need this to operate! +from airflow.operators.bash import BashOperator +from airflow.operators.dummy_operator import DummyOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "donaldrich", + "depends_on_past": False, + "start_date": datetime(2016, 7, 13), + "email": ["email@gmail.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=15), +} + +script_path = "/data/scripts" +data_path = "/data/data/archive" + +dag = DAG( + "zip_docker", + default_args=default_args, + description="A simple tutorial DAG", + schedule_interval=None, + start_date=days_ago(2), + tags=["zip", "docker"], +) + +with dag: + + start = DummyOperator(task_id="start", dag=dag) + + pull1 = BashOperator( + task_id="pull_chrome", + bash_command="sudo docker pull selenoid/chrome:latest", + # retries=3, + # dag=dag + ) + + pull2 = BashOperator( + task_id="pull_recorder", + bash_command="sudo docker pull selenoid/video-recorder:latest-release", + # retries=3, + # dag=dag + ) + + scrape = BashOperator( + task_id="scrape_listings", + bash_command="sh /data/scripts/scrape.sh -s target -k docker", + # retries=3, + # dag=dag + ) + + cleanup = BashOperator( + task_id="cleanup", + bash_command="sh /data/scripts/post-scrape.sh -s target -k devops", + # retries=3, + # dag=dag + ) + + end = DummyOperator(task_id="end", dag=dag) + + start >> [pull1, pull2] >> scrape >> cleanup >> end + + # scrape = BashOperator( + # task_id="scrape_listings", + # bash_command="python3 " + script_path + '/gather/zip.py -k "devops"', + # # retries=3, + # # dag=dag + # ) + + # cleanup = BashOperator( + # task_id="cleanup", + # bash_command=script_path + "/post-scrape.sh -s zip -k devops", + # # retries=3, + # # dag=dag + # ) + + # init >> [pre1,pre2] + + # init >> pre2 + + # pre2 >> scrape + + # scrape >> cleanup diff --git a/dags/debug.py b/dags/debug.py new file mode 100644 index 0000000..b3022c0 --- /dev/null +++ b/dags/debug.py @@ -0,0 +1,125 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +""" +### Tutorial Documentation +Documentation that goes along with the Airflow tutorial located +[here](https://airflow.apache.org/tutorial.html) +""" +# [START tutorial] +# [START import_module] +from datetime import timedelta +from textwrap import dedent + +# The DAG object; we'll need this to instantiate a DAG +from airflow import DAG + +# Operators; we need this to operate! +from airflow.operators.bash import BashOperator +from airflow.utils.dates import days_ago + +# [END import_module] + +# [START default_args] +# These args will get passed on to each operator +# You can override them on a per-task basis during operator initialization +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), + # 'queue': 'bash_queue', + # 'pool': 'backfill', + # 'priority_weight': 10, + # 'end_date': datetime(2016, 1, 1), + # 'wait_for_downstream': False, + # 'dag': dag, + # 'sla': timedelta(hours=2), + # 'execution_timeout': timedelta(seconds=300), + # 'on_failure_callback': some_function, + # 'on_success_callback': some_other_function, + # 'on_retry_callback': another_function, + # 'sla_miss_callback': yet_another_function, + # 'trigger_rule': 'all_success' +} +# [END default_args] + +# [START instantiate_dag] +with DAG( + "debug", + default_args=default_args, + description="A simple tutorial DAG", + schedule_interval=None, + start_date=days_ago(2), + tags=["example"], +) as dag: + # [END instantiate_dag] + + cleanup = BashOperator( + task_id="cleanup", + bash_command="sudo sh /data/scripts/post-scrape.sh -s target -k devops", + ) + + # t1, t2 and t3 are examples of tasks created by instantiating operators + # [START basic_task] + # t1 = BashOperator( + # task_id="print_date", + # bash_command="date", + # ) + + # t2 = BashOperator( + # task_id="sleep", + # depends_on_past=False, + # bash_command="sleep 5", + # retries=3, + # ) + # # [END basic_task] + + # # [START documentation] + # dag.doc_md = __doc__ + + cleanup.doc_md = dedent(""" + #### Task Documentation + You can document your task using the attributes `doc_md` (markdown), + `doc` (plain text), `doc_rst`, `doc_json`, `doc_yaml` which gets + rendered in the UI's Task Instance Details page. 
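+    Here it is attached to the `cleanup` task, which runs the post-scrape shell script.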
+ ![img](http://montcs.bloomu.edu/~bobmon/Semesters/2012-01/491/import%20soul.png) + """) + # [END documentation] + + # [START jinja_template] + templated_command = dedent(""" + {% for i in range(5) %} + echo "{{ ds }}" + echo "{{ macros.ds_add(ds, 7)}}" + echo "{{ params.my_param }}" + {% endfor %} + """) + + t3 = BashOperator( + task_id="templated", + depends_on_past=False, + bash_command=templated_command, + params={"my_param": "Parameter I passed in"}, + ) + # [END jinja_template] + + cleanup +# [END tutorial] diff --git a/dags/dev/evade.py b/dags/dev/evade.py new file mode 100755 index 0000000..830671e --- /dev/null +++ b/dags/dev/evade.py @@ -0,0 +1,101 @@ +import random +from time import sleep + +from selenium import webdriver +from selenium.webdriver import ActionChains +from selenium.webdriver.common.by import By +from selenium.webdriver.support import expected_conditions as EC +from selenium.webdriver.support.wait import WebDriverWait + +# INSTANCE DRIVER + + +def configure_driver(settings, user_agent): + + options = webdriver.ChromeOptions() + + # Options to try to fool the site we are a normal browser. + options.add_experimental_option("excludeSwitches", ["enable-automation"]) + options.add_experimental_option("useAutomationExtension", False) + options.add_experimental_option( + "prefs", + { + "download.default_directory": settings.files_path, + "download.prompt_for_download": False, + "download.directory_upgrade": True, + "safebrowsing_for_trusted_sources_enabled": False, + "safebrowsing.enabled": False, + }, + ) + options.add_argument("--incognito") + options.add_argument("start-maximized") + options.add_argument("user-agent={user_agent}") + options.add_argument("disable-blink-features") + options.add_argument("disable-blink-features=AutomationControlled") + + driver = webdriver.Chrome( + options=options, executable_path=settings.chromedriver_path + ) + + # #Overwrite the webdriver property + # driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", + # { + # "source": """ + # Object.defineProperty(navigator, 'webdriver', { + # get: () => undefined + # }) + # """ + # }) + + driver.execute_cdp_cmd("Network.enable", {}) + + # Overwrite the User-Agent header + driver.execute_cdp_cmd( + "Network.setExtraHTTPHeaders", {"headers": {"User-Agent": user_agent}} + ) + + driver.command_executor._commands["send_command"] = ( + "POST", + "/session/$sessionId/chromium/send_command", + ) + + return driver + + +# SOLVENTAR CAPTCHA +def solve_wait_recaptcha(driver): + + ###### Move to reCAPTCHA Iframe + WebDriverWait(driver, 5).until( + EC.frame_to_be_available_and_switch_to_it( + ( + By.CSS_SELECTOR, + "iframe[src^='https://www.google.com/recaptcha/api2/anchor?']", + ) + ) + ) + + check_selector = "span.recaptcha-checkbox.goog-inline-block.recaptcha-checkbox-unchecked.rc-anchor-checkbox" + + captcha_check = driver.find_element_by_css_selector(check_selector) + + ###### Random delay before hover & click the checkbox + sleep(random.uniform(3, 6)) + ActionChains(driver).move_to_element(captcha_check).perform() + + ###### Hover it + sleep(random.uniform(0.5, 1)) + hov = ActionChains(driver).move_to_element(captcha_check).perform() + + ###### Random delay before click the checkbox + sleep(random.uniform(0.5, 1)) + driver.execute_script("arguments[0].click()", captcha_check) + + ###### Wait for recaptcha to be in solved state + elem = None + while elem is None: + try: + elem = driver.find_element_by_class_name("recaptcha-checkbox-checked") + except: + pass + sleep(5) diff --git 
a/dags/dev/form.py b/dags/dev/form.py new file mode 100755 index 0000000..93042d1 --- /dev/null +++ b/dags/dev/form.py @@ -0,0 +1,27 @@ +import mechanicalsoup + +browser = mechanicalsoup.StatefulBrowser() +browser.open("http://httpbin.org/") + +print(browser.url) +browser.follow_link("forms") +print(browser.url) +print(browser.page) + +browser.select_form('form[action="/post"]') +browser["custname"] = "Me" +browser["custtel"] = "00 00 0001" +browser["custemail"] = "nobody@example.com" +browser["size"] = "medium" +browser["topping"] = "onion" +browser["topping"] = ("bacon", "cheese") +browser["comments"] = "This pizza looks really good :-)" + +# Uncomment to launch a real web browser on the current page. +# browser.launch_browser() + +# Uncomment to display a summary of the filled-in form +browser.form.print_summary() + +response = browser.submit_selected() +print(response.text) diff --git a/dags/dev/log_config.py b/dags/dev/log_config.py new file mode 100644 index 0000000..ad64078 --- /dev/null +++ b/dags/dev/log_config.py @@ -0,0 +1,270 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Airflow logging settings""" + +import os +from pathlib import Path +from typing import Any, Dict, Union +from urllib.parse import urlparse + +from airflow.configuration import conf +from airflow.exceptions import AirflowException + +# TODO: Logging format and level should be configured +# in this file instead of from airflow.cfg. Currently +# there are other log format and level configurations in +# settings.py and cli.py. Please see AIRFLOW-1455. +LOG_LEVEL: str = conf.get('logging', 'LOGGING_LEVEL').upper() + + +# Flask appbuilder's info level log is very verbose, +# so it's set to 'WARN' by default. 
+FAB_LOG_LEVEL: str = conf.get('logging', 'FAB_LOGGING_LEVEL').upper() + +LOG_FORMAT: str = conf.get('logging', 'LOG_FORMAT') + +COLORED_LOG_FORMAT: str = conf.get('logging', 'COLORED_LOG_FORMAT') + +COLORED_LOG: bool = conf.getboolean('logging', 'COLORED_CONSOLE_LOG') + +COLORED_FORMATTER_CLASS: str = conf.get('logging', 'COLORED_FORMATTER_CLASS') + +BASE_LOG_FOLDER: str = conf.get('logging', 'BASE_LOG_FOLDER') + +PROCESSOR_LOG_FOLDER: str = conf.get('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY') + +DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get('logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION') + +FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_FILENAME_TEMPLATE') + +PROCESSOR_FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_PROCESSOR_FILENAME_TEMPLATE') + +DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'airflow': {'format': LOG_FORMAT}, + 'airflow_coloured': { + 'format': COLORED_LOG_FORMAT if COLORED_LOG else LOG_FORMAT, + 'class': COLORED_FORMATTER_CLASS if COLORED_LOG else 'logging.Formatter', + }, + }, + 'handlers': { + 'console': { + 'class': 'airflow.utils.log.logging_mixin.RedirectStdHandler', + 'formatter': 'airflow_coloured', + 'stream': 'sys.stdout', + }, + 'task': { + 'class': 'airflow.utils.log.file_task_handler.FileTaskHandler', + 'formatter': 'airflow', + 'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER), + 'filename_template': FILENAME_TEMPLATE, + }, + 'processor': { + 'class': 'airflow.utils.log.file_processor_handler.FileProcessorHandler', + 'formatter': 'airflow', + 'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER), + 'filename_template': PROCESSOR_FILENAME_TEMPLATE, + }, + }, + 'loggers': { + 'airflow.processor': { + 'handlers': ['processor'], + 'level': LOG_LEVEL, + 'propagate': False, + }, + 'airflow.task': { + 'handlers': ['task'], + 'level': LOG_LEVEL, + 'propagate': False, + }, + 'flask_appbuilder': { + 'handler': ['console'], + 'level': FAB_LOG_LEVEL, + 'propagate': True, + }, + }, + 'root': { + 'handlers': ['console'], + 'level': LOG_LEVEL, + }, +} + +EXTRA_LOGGER_NAMES: str = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None) +if EXTRA_LOGGER_NAMES: + new_loggers = { + logger_name.strip(): { + 'handler': ['console'], + 'level': LOG_LEVEL, + 'propagate': True, + } + for logger_name in EXTRA_LOGGER_NAMES.split(",") + } + DEFAULT_LOGGING_CONFIG['loggers'].update(new_loggers) + +DEFAULT_DAG_PARSING_LOGGING_CONFIG: Dict[str, Dict[str, Dict[str, Any]]] = { + 'handlers': { + 'processor_manager': { + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'airflow', + 'filename': DAG_PROCESSOR_MANAGER_LOG_LOCATION, + 'mode': 'a', + 'maxBytes': 104857600, # 100MB + 'backupCount': 5, + } + }, + 'loggers': { + 'airflow.processor_manager': { + 'handlers': ['processor_manager'], + 'level': LOG_LEVEL, + 'propagate': False, + } + }, +} + +# Only update the handlers and loggers when CONFIG_PROCESSOR_MANAGER_LOGGER is set. +# This is to avoid exceptions when initializing RotatingFileHandler multiple times +# in multiple processes. +if os.environ.get('CONFIG_PROCESSOR_MANAGER_LOGGER') == 'True': + DEFAULT_LOGGING_CONFIG['handlers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers']) + DEFAULT_LOGGING_CONFIG['loggers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['loggers']) + + # Manually create log directory for processor_manager handler as RotatingFileHandler + # will only create file but not the directory. 
+ processor_manager_handler_config: Dict[str, Any] = DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'][ + 'processor_manager' + ] + directory: str = os.path.dirname(processor_manager_handler_config['filename']) + Path(directory).mkdir(parents=True, exist_ok=True, mode=0o755) + +################## +# Remote logging # +################## + +REMOTE_LOGGING: bool = conf.getboolean('logging', 'remote_logging') + +if REMOTE_LOGGING: + + ELASTICSEARCH_HOST: str = conf.get('elasticsearch', 'HOST') + + # Storage bucket URL for remote logging + # S3 buckets should start with "s3://" + # Cloudwatch log groups should start with "cloudwatch://" + # GCS buckets should start with "gs://" + # WASB buckets should start with "wasb" + # just to help Airflow select correct handler + REMOTE_BASE_LOG_FOLDER: str = conf.get('logging', 'REMOTE_BASE_LOG_FOLDER') + + if REMOTE_BASE_LOG_FOLDER.startswith('s3://'): + S3_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { + 'task': { + 'class': 'airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler', + 'formatter': 'airflow', + 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), + 's3_log_folder': REMOTE_BASE_LOG_FOLDER, + 'filename_template': FILENAME_TEMPLATE, + }, + } + + DEFAULT_LOGGING_CONFIG['handlers'].update(S3_REMOTE_HANDLERS) + elif REMOTE_BASE_LOG_FOLDER.startswith('cloudwatch://'): + CLOUDWATCH_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { + 'task': { + 'class': 'airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler', + 'formatter': 'airflow', + 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), + 'log_group_arn': urlparse(REMOTE_BASE_LOG_FOLDER).netloc, + 'filename_template': FILENAME_TEMPLATE, + }, + } + + DEFAULT_LOGGING_CONFIG['handlers'].update(CLOUDWATCH_REMOTE_HANDLERS) + elif REMOTE_BASE_LOG_FOLDER.startswith('gs://'): + key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None) + GCS_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { + 'task': { + 'class': 'airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler', + 'formatter': 'airflow', + 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), + 'gcs_log_folder': REMOTE_BASE_LOG_FOLDER, + 'filename_template': FILENAME_TEMPLATE, + 'gcp_key_path': key_path, + }, + } + + DEFAULT_LOGGING_CONFIG['handlers'].update(GCS_REMOTE_HANDLERS) + elif REMOTE_BASE_LOG_FOLDER.startswith('wasb'): + WASB_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = { + 'task': { + 'class': 'airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler', + 'formatter': 'airflow', + 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), + 'wasb_log_folder': REMOTE_BASE_LOG_FOLDER, + 'wasb_container': 'airflow-logs', + 'filename_template': FILENAME_TEMPLATE, + 'delete_local_copy': False, + }, + } + + DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS) + elif REMOTE_BASE_LOG_FOLDER.startswith('stackdriver://'): + key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None) + # stackdriver:///airflow-tasks => airflow-tasks + log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:] + STACKDRIVER_REMOTE_HANDLERS = { + 'task': { + 'class': 'airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler', + 'formatter': 'airflow', + 'name': log_name, + 'gcp_key_path': key_path, + } + } + + DEFAULT_LOGGING_CONFIG['handlers'].update(STACKDRIVER_REMOTE_HANDLERS) + elif ELASTICSEARCH_HOST: + ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get('elasticsearch', 'LOG_ID_TEMPLATE') + ELASTICSEARCH_END_OF_LOG_MARK: str = 
conf.get('elasticsearch', 'END_OF_LOG_MARK') + ELASTICSEARCH_FRONTEND: str = conf.get('elasticsearch', 'frontend') + ELASTICSEARCH_WRITE_STDOUT: bool = conf.getboolean('elasticsearch', 'WRITE_STDOUT') + ELASTICSEARCH_JSON_FORMAT: bool = conf.getboolean('elasticsearch', 'JSON_FORMAT') + ELASTICSEARCH_JSON_FIELDS: str = conf.get('elasticsearch', 'JSON_FIELDS') + + ELASTIC_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = { + 'task': { + 'class': 'airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler', + 'formatter': 'airflow', + 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), + 'log_id_template': ELASTICSEARCH_LOG_ID_TEMPLATE, + 'filename_template': FILENAME_TEMPLATE, + 'end_of_log_mark': ELASTICSEARCH_END_OF_LOG_MARK, + 'host': ELASTICSEARCH_HOST, + 'frontend': ELASTICSEARCH_FRONTEND, + 'write_stdout': ELASTICSEARCH_WRITE_STDOUT, + 'json_format': ELASTICSEARCH_JSON_FORMAT, + 'json_fields': ELASTICSEARCH_JSON_FIELDS, + }, + } + + DEFAULT_LOGGING_CONFIG['handlers'].update(ELASTIC_REMOTE_HANDLERS) + else: + raise AirflowException( + "Incorrect remote log configuration. Please check the configuration of option 'host' in " + "section 'elasticsearch' if you are using Elasticsearch. In the other case, " + "'remote_base_log_folder' option in the 'logging' section." + ) diff --git a/dags/dev/menu.py b/dags/dev/menu.py new file mode 100644 index 0000000..b1f60ac --- /dev/null +++ b/dags/dev/menu.py @@ -0,0 +1,28 @@ +from airflow.plugins_manager import AirflowPlugin +from flask_admin.base import MenuLink + +github = MenuLink( + category="Astronomer", + name="Airflow Instance Github Repo", + url="https://github.com/astronomerio/astronomer-dags", +) + +astronomer_home = MenuLink( + category="Astronomer", name="Astronomer Home", url="https://www.astronomer.io/" +) + +aiflow_plugins = MenuLink( + category="Astronomer", + name="Airflow Plugins", + url="https://github.com/airflow-plugins", +) + +# Defining the plugin class +class AirflowTestPlugin(AirflowPlugin): + name = "AstronomerMenuLinks" + operators = [] + flask_blueprints = [] + hooks = [] + executors = [] + admin_views = [] + menu_links = [github, astronomer_home, aiflow_plugins] diff --git a/dags/dev/singer.py b/dags/dev/singer.py new file mode 100644 index 0000000..c4eb286 --- /dev/null +++ b/dags/dev/singer.py @@ -0,0 +1,32 @@ +""" +Singer +This example shows how to use Singer within Airflow using a custom operator: +- SingerOperator +https://github.com/airflow-plugins/singer_plugin/blob/master/operators/singer_operator.py#L5 +A complete list of Taps and Targets can be found in the Singer.io Github org: +https://github.com/singer-io +""" + + +from datetime import datetime + +from airflow import DAG +from airflow.operators import SingerOperator + +default_args = { + "start_date": datetime(2018, 2, 22), + "retries": 0, + "email": [], + "email_on_failure": True, + "email_on_retry": False, +} + +dag = DAG( + "__singer__fixerio_to_csv", schedule_interval="@hourly", default_args=default_args +) + +with dag: + + singer = SingerOperator(task_id="singer", tap="fixerio", target="csv") + + singer diff --git a/dags/dev/test.py b/dags/dev/test.py new file mode 100644 index 0000000..798dc22 --- /dev/null +++ b/dags/dev/test.py @@ -0,0 +1,181 @@ +""" +Headless Site Navigation and File Download (Using Selenium) to S3 +This example demonstrates using Selenium (via Firefox/GeckoDriver) to: +1) Log into a website w/ credentials stored in connection labeled 'selenium_conn_id' +2) Download a file (initiated on login) +3) 
Transform the CSV into JSON formatting +4) Append the current data to each record +5) Load the corresponding file into S3 +To use this DAG, you will need to have the following installed: +[XVFB](https://www.x.org/archive/X11R7.6/doc/man/man1/Xvfb.1.xhtml) +[GeckoDriver](https://github.com/mozilla/geckodriver/releases/download) +selenium==3.11.0 +xvfbwrapper==0.2.9 +""" +import csv +import datetime +import json +import logging +import os +import time +from datetime import datetime, timedelta + +import boa +import requests +from airflow import DAG +from airflow.models import Connection +from airflow.operators.dummy_operator import DummyOperator +from airflow.operators.python_operator import PythonOperator, PythonVirtualenvOperator +from airflow.providers.docker.operators.docker import DockerOperator +from airflow.utils.dates import days_ago +from airflow.utils.db import provide_session +from bs4 import BeautifulSoup +from selenium import webdriver +from selenium.webdriver.chrome.options import Options +from selenium.webdriver.chrome.service import Service +from selenium.webdriver.common.by import By +from selenium.webdriver.common.desired_capabilities import DesiredCapabilities +from selenium.webdriver.common.keys import Keys +from tools.network import ( + browser_config, + browser_feature, + driver_object, + ip_status, + proxy_ip, + selenoid_status, + uc_test, + user_agent, + vpn_settings, +) + +from config.webdriver import browser_capabilities, browser_options + +default_args = { + "start_date": days_ago(2), + "email": [], + "email_on_failure": True, + "email_on_retry": False, + # 'retries': 2, + "retry_delay": timedelta(minutes=5), + "catchup": False, +} + + +# def hello_world_py(): +# # selenium_conn_id = kwargs.get('templates_dict', None).get('selenium_conn_id', None) +# # filename = kwargs.get('templates_dict', None).get('filename', None) +# # s3_conn_id = kwargs.get('templates_dict', None).get('s3_conn_id', None) +# # s3_bucket = kwargs.get('templates_dict', None).get('s3_bucket', None) +# # s3_key = kwargs.get('templates_dict', None).get('s3_key', None) +# # date = kwargs.get('templates_dict', None).get('date', None) +# # module_name = kwargs.get('templates_dict', None).get('module', None) + + +# module = "anon_browser_test" + +# chrome_options = browser_options() +# capabilities = browser_capabilities(module) +# logging.info('Assembling driver') +# driver = webdriver.Remote( +# command_executor="http://192.168.1.101:4444/wd/hub", +# options=chrome_options, +# desired_capabilities=capabilities, +# ) +# logging.info('proxy IP') +# proxy_ip() +# logging.info('driver') +# vpn_settings() +# logging.info('driver') +# selenoid_status() +# logging.info('driver') +# ip_status(driver) +# logging.info('driver') +# browser_config(driver) +# logging.info('driver') +# user_agent(driver) +# logging.info('driver') +# driver_object(driver) +# logging.info('driver') +# browser_feature(driver) +# logging.info('driver') + +# driver.quit() +# logging.info('driver') +# uc_test() +# # print("Finished") +# return 'Whatever you return gets printed in the logs' + +# dag = DAG( +# 'anon_browser_test', +# schedule_interval='@daily', +# default_args=default_args, +# catchup=False +# ) + +# dummy_operator = DummyOperator(task_id="dummy_task", retries=3, dag=dag) + +# selenium = PythonOperator( +# task_id='anon_browser_test', +# python_callable=hello_world_py, +# templates_dict={"module": "anon_browser_test"}, +# dag=dag +# # "s3_bucket": S3_BUCKET, +# # "s3_key": S3_KEY, +# # "date": date} +# # 
provide_context=True +# ) + +# t1 = DockerOperator( +# # api_version='1.19', +# # docker_url='tcp://localhost:2375', # Set your docker URL +# command='/bin/sleep 30', +# image='selenoid/chrome:latest', +# # network_mode='bridge', +# task_id='chrome', +# dag=dag, +# ) + +# t2 = DockerOperator( +# # api_version='1.19', +# # docker_url='tcp://localhost:2375', # Set your docker URL +# command='/bin/sleep 30', +# image='selenoid/video-recorder:latest-release', +# # network_mode='bridge', +# task_id='video_recorder', +# dag=dag, +# ) + +# [START howto_operator_python_venv] +def callable_virtualenv(): + """ + Example function that will be performed in a virtual environment. + Importing at the module level ensures that it will not attempt to import the + library before it is installed. + """ + from time import sleep + + from colorama import Back, Fore, Style + + print(Fore.RED + "some red text") + print(Back.GREEN + "and with a green background") + print(Style.DIM + "and in dim text") + print(Style.RESET_ALL) + for _ in range(10): + print(Style.DIM + "Please wait...", flush=True) + sleep(10) + print("Finished") + + +virtualenv_task = PythonVirtualenvOperator( + task_id="virtualenv_python", + python_callable=callable_virtualenv, + requirements=["colorama==0.4.0"], + system_site_packages=False, + dag=dag, +) +# [END howto_operator_python_venv] + +# selenium >> dummy_operator +# dummy_operator >> virtualenv_task +# t1 >> selenium +# t2 >> selenium diff --git a/dags/docker.py b/dags/docker.py new file mode 100644 index 0000000..9903c84 --- /dev/null +++ b/dags/docker.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
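+"""
+Example DAG mixing Bash and Docker tasks: it prints the date, sleeps, runs
+``/bin/sleep 30`` inside a ``centos:latest`` container through a remote Docker
+daemon at ``tcp://localhost:2375``, and finally echoes "hello world!!!".
+"""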
+from datetime import timedelta + +from airflow import DAG +from airflow.operators.bash import BashOperator +from airflow.providers.docker.operators.docker import DockerOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), +} + +dag = DAG( + "docker_sample", + default_args=default_args, + schedule_interval=timedelta(minutes=10), + start_date=days_ago(2), +) + +t1 = BashOperator(task_id="print_date", bash_command="date", dag=dag) + +t2 = BashOperator(task_id="sleep", bash_command="sleep 5", retries=3, dag=dag) + +t3 = DockerOperator( + api_version="1.19", + docker_url="tcp://localhost:2375", # Set your docker URL + command="/bin/sleep 30", + image="centos:latest", + network_mode="bridge", + task_id="docker_op_tester", + dag=dag, +) + + +t4 = BashOperator(task_id="print_hello", bash_command='echo "hello world!!!"', dag=dag) + + +t1 >> t2 +t1 >> t3 +t3 >> t4 diff --git a/dags/etl.py b/dags/etl.py new file mode 100644 index 0000000..413f273 --- /dev/null +++ b/dags/etl.py @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=missing-function-docstring + +# [START tutorial] +# [START import_module] +import json + +from airflow.decorators import dag, task +from airflow.utils.dates import days_ago + +# [END import_module] + +# [START default_args] +# These args will get passed on to each operator +# You can override them on a per-task basis during operator initialization +default_args = { + "owner": "airflow", +} +# [END default_args] + + +# [START instantiate_dag] +@dag( + default_args=default_args, + schedule_interval=None, + start_date=days_ago(2), + tags=["example"], +) +def tutorial_taskflow_api_etl(): + """ + ### TaskFlow API Tutorial Documentation + This is a simple ETL data pipeline example which demonstrates the use of + the TaskFlow API using three simple tasks for Extract, Transform, and Load. + Documentation that goes along with the Airflow TaskFlow API tutorial is + located + [here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html) + """ + # [END instantiate_dag] + + # [START extract] + @task() + def extract(): + """ + #### Extract task + A simple Extract task to get data ready for the rest of the data + pipeline. In this case, getting data is simulated by reading from a + hardcoded JSON string. 
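+        The returned dictionary is handed to the downstream task as an XCom by
+        the TaskFlow API.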
+ """ + data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' + + order_data_dict = json.loads(data_string) + return order_data_dict + + # [END extract] + + # [START transform] + @task(multiple_outputs=True) + def transform(order_data_dict: dict): + """ + #### Transform task + A simple Transform task which takes in the collection of order data and + computes the total order value. + """ + total_order_value = 0 + + for value in order_data_dict.values(): + total_order_value += value + + return {"total_order_value": total_order_value} + + # [END transform] + + # [START load] + @task() + def load(total_order_value: float): + """ + #### Load task + A simple Load task which takes in the result of the Transform task and + instead of saving it to end user review, just prints it out. + """ + + print(f"Total order value is: {total_order_value:.2f}") + + # [END load] + + # [START main_flow] + order_data = extract() + order_summary = transform(order_data) + load(order_summary["total_order_value"]) + # [END main_flow] + + +# [START dag_invocation] +tutorial_etl_dag = tutorial_taskflow_api_etl() +# [END dag_invocation] + +# [END tutorial] diff --git a/dags/postgres.py b/dags/postgres.py new file mode 100644 index 0000000..4fad8fe --- /dev/null +++ b/dags/postgres.py @@ -0,0 +1,61 @@ +from os import getenv + +from sqlalchemy import VARCHAR, create_engine +from sqlalchemy.engine.url import URL + + +def _psql_insert_copy(table, conn, keys, data_iter): + """ + Execute SQL statement inserting data + + Parameters + ---------- + table : pandas.io.sql.SQLTable + conn : sqlalchemy.engine.Engine or sqlalchemy.engine.Connection + keys : list of str + Column names + data_iter : Iterable that iterates the values to be inserted + """ + # Alternative to_sql() *method* for DBs that support COPY FROM + import csv + from io import StringIO + + # gets a DBAPI connection that can provide a cursor + dbapi_conn = conn.connection + with dbapi_conn.cursor() as cur: + s_buf = StringIO() + writer = csv.writer(s_buf) + writer.writerows(data_iter) + s_buf.seek(0) + + columns = ', '.join('"{}"'.format(k) for k in keys) + if table.schema: + table_name = '{}.{}'.format(table.schema, table.name) + else: + table_name = table.name + + sql = 'COPY {} ({}) FROM STDIN WITH CSV'.format( + table_name, columns) + cur.copy_expert(sql=sql, file=s_buf) + + +def load_df_into_db(data_frame, schema: str, table: str) -> None: + engine = create_engine(URL( + username=getenv('DBT_POSTGRES_USER'), + password=getenv('DBT_POSTGRES_PASSWORD'), + host=getenv('DBT_POSTGRES_HOST'), + port=getenv('DBT_POSTGRES_PORT'), + database=getenv('DBT_POSTGRES_DB'), + drivername='postgres' + )) + with engine.connect() as cursor: + cursor.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}') + data_frame.to_sql( + schema=schema, + name=table, + dtype=VARCHAR, + if_exists='append', + con=engine, + method=_psql_insert_copy, + index=False + ) diff --git a/dags/remote_browser.py b/dags/remote_browser.py new file mode 100644 index 0000000..cca66d6 --- /dev/null +++ b/dags/remote_browser.py @@ -0,0 +1,126 @@ +""" +Code that goes along with the Airflow located at: +http://airflow.readthedocs.org/en/latest/tutorial.html +""" +from datetime import datetime, timedelta +from textwrap import dedent + +from airflow import DAG + +# Operators; we need this to operate! 
+from airflow.operators.bash import BashOperator +from airflow.operators.dummy_operator import DummyOperator +from airflow.operators.python_operator import PythonOperator, PythonVirtualenvOperator +from airflow.utils.dates import days_ago +from airflow.utils.task_group import TaskGroup +from utils.network import selenoid_status, vpn_settings +from airflow.providers.postgres.operators.postgres import PostgresOperator +from utils.notify import PushoverOperator +from utils.postgres import sqlLoad + +default_args = { + "owner": "donaldrich", + "depends_on_past": False, + "start_date": datetime(2016, 7, 13), + "email": ["email@gmail.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=15), + # "on_failure_callback": some_function, + # "on_success_callback": some_other_function, + # "on_retry_callback": another_function, +} + +# [START default_args] +# default_args = { +# 'owner': 'airflow', +# 'depends_on_past': False, +# 'email': ['airflow@example.com'], +# 'email_on_failure': False, +# 'email_on_retry': False, +# 'retries': 1, +# 'retry_delay': timedelta(minutes=5), +# 'queue': 'bash_queue', +# 'pool': 'backfill', +# 'priority_weight': 10, +# 'end_date': datetime(2016, 1, 1), +# 'wait_for_downstream': False, +# 'dag': dag, +# 'sla': timedelta(hours=2), +# 'execution_timeout': timedelta(seconds=300), +# 'on_failure_callback': some_function, +# 'on_success_callback': some_other_function, +# 'on_retry_callback': another_function, +# 'sla_miss_callback': yet_another_function, +# 'trigger_rule': 'all_success' +# } +# [END default_args] + +dag = DAG( + "recon", + default_args=default_args, + description="A simple tutorial DAG", + schedule_interval=None, + start_date=days_ago(2), + tags=["target"], +) + +with dag: + + # start = DummyOperator(task_id="start", dag=dag) + + with TaskGroup("pull_images") as start2: + + pull1 = BashOperator( + task_id="chrome_latest", + bash_command="sudo docker pull selenoid/chrome:latest", + ) + + pull2 = BashOperator( + task_id="video_recorder", + bash_command= + "sudo docker pull selenoid/video-recorder:latest-release", + ) + + t1 = PythonOperator( + task_id="selenoid_status", + python_callable=selenoid_status, + ) + [pull1, pull2] >> t1 + start2.doc_md = dedent(""" + #### Task Documentation + You can document your task using the attributes `doc_md` (markdown), + `doc` (plain text), `doc_rst`, `doc_json`, `doc_yaml` which gets + rendered in the UI's Task Instance Details page. 
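+    Here the `pull_images` group pre-pulls the `selenoid/chrome` and
+    `selenoid/video-recorder` images, then checks Selenoid status.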
+ ![img](http://montcs.bloomu.edu/~bobmon/Semesters/2012-01/491/import%20soul.png) + """) + + scrape = BashOperator( + task_id="scrape_listings", + bash_command="python3 /data/scripts/scrapers/zip-scrape.py -k devops", + ) + + sqlLoad = PythonOperator(task_id="sql_load", python_callable=sqlLoad) + + # get_birth_date = PostgresOperator( + # task_id="get_birth_date", + # postgres_conn_id="postgres_default", + # sql="sql/birth_date.sql", + # params={"begin_date": "2020-01-01", "end_date": "2020-12-31"}, + # ) + + cleanup = BashOperator( + task_id="cleanup", + bash_command="sudo sh /data/scripts/post-scrape.sh -s zip -k devops", + ) + + success_notify = PushoverOperator( + task_id="finished", + title="Airflow Complete", + message="We did it!", + ) + + end = DummyOperator(task_id="end", dag=dag) + + start2 >> scrape >> sqlLoad >> cleanup >> success_notify >> end diff --git a/dags/s3.py b/dags/s3.py new file mode 100644 index 0000000..9e9935a --- /dev/null +++ b/dags/s3.py @@ -0,0 +1,41 @@ +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.operators.python_operator import PythonOperator + +# from airflow.hooks.S3_hook import S3Hook +from airflow.providers.amazon.aws.hooks.s3 import S3Hook + +DEFAULT_ARGS = { + "owner": "Airflow", + "depends_on_past": False, + "start_date": datetime(2020, 1, 13), + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), +} + +dag = DAG("create_date_dimension", default_args=DEFAULT_ARGS, schedule_interval="@once") + +# Create a task to call your processing function +def write_text_file(ds, **kwargs): + with open("/data/data/temp.txt", "w") as fp: + # Add file generation/processing step here, E.g.: + # fp.write(ds) + + # Upload generated file to Minio + # s3 = S3Hook('local_minio') + s3 = S3Hook(aws_conn_id="minio") + s3.load_file("/data/data/temp.txt", key=f"my-test-file.txt", bucket_name="airflow") + + +# Create a task to call your processing function + +t1 = PythonOperator( + task_id="generate_and_upload_to_s3", + provide_context=True, + python_callable=write_text_file, + dag=dag, +) diff --git a/dags/scraper.py b/dags/scraper.py new file mode 100644 index 0000000..d6262c2 --- /dev/null +++ b/dags/scraper.py @@ -0,0 +1,154 @@ +""" +Headless Site Navigation and File Download (Using Selenium) to S3 + +This example demonstrates using Selenium (via Firefox/GeckoDriver) to: +1) Log into a website w/ credentials stored in connection labeled 'selenium_conn_id' +2) Download a file (initiated on login) +3) Transform the CSV into JSON formatting +4) Append the current data to each record +5) Load the corresponding file into S3 + +To use this DAG, you will need to have the following installed: +[XVFB](https://www.x.org/archive/X11R7.6/doc/man/man1/Xvfb.1.xhtml) +[GeckoDriver](https://github.com/mozilla/geckodriver/releases/download) + +selenium==3.11.0 +xvfbwrapper==0.2.9 +""" +# import boa +import csv +import json +import logging +import os +import time +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.hooks import S3Hook +from airflow.models import Connection +from airflow.operators.dummy_operator import DummyOperator +from airflow.operators.python_operator import PythonOperator +from airflow.utils.db import provide_session +from selenium import webdriver +from selenium.webdriver.common.keys import Keys +from selenium.webdriver.firefox.options import Options +from xvfbwrapper import Xvfb + +S3_CONN_ID = "" +S3_BUCKET = 
"" +S3_KEY = "" + +date = "{{ ds }}" + +default_args = { + "start_date": datetime(2018, 2, 10, 0, 0), + "email": [], + "email_on_failure": True, + "email_on_retry": False, + "retries": 2, + "retry_delay": timedelta(minutes=5), + "catchup": False, +} + +dag = DAG( + "selenium_extraction_to_s3", + schedule_interval="@daily", + default_args=default_args, + catchup=False, +) + + +def imap_py(**kwargs): + selenium_conn_id = kwargs.get("templates_dict", None).get("selenium_conn_id", None) + filename = kwargs.get("templates_dict", None).get("filename", None) + s3_conn_id = kwargs.get("templates_dict", None).get("s3_conn_id", None) + s3_bucket = kwargs.get("templates_dict", None).get("s3_bucket", None) + s3_key = kwargs.get("templates_dict", None).get("s3_key", None) + date = kwargs.get("templates_dict", None).get("date", None) + + @provide_session + def get_conn(conn_id, session=None): + conn = session.query(Connection).filter(Connection.conn_id == conn_id).first() + return conn + + url = get_conn(selenium_conn_id).host + email = get_conn(selenium_conn_id).user + pwd = get_conn(selenium_conn_id).password + + vdisplay = Xvfb() + vdisplay.start() + caps = webdriver.DesiredCapabilities.FIREFOX + caps["marionette"] = True + + profile = webdriver.FirefoxProfile() + profile.set_preference("browser.download.manager.showWhenStarting", False) + profile.set_preference("browser.helperApps.neverAsk.saveToDisk", "text/csv") + + logging.info("Profile set...") + options = Options() + options.set_headless(headless=True) + logging.info("Options set...") + logging.info("Initializing Driver...") + driver = webdriver.Firefox( + firefox_profile=profile, firefox_options=options, capabilities=caps + ) + logging.info("Driver Intialized...") + driver.get(url) + logging.info("Authenticating...") + elem = driver.find_element_by_id("email") + elem.send_keys(email) + elem = driver.find_element_by_id("password") + elem.send_keys(pwd) + elem.send_keys(Keys.RETURN) + + logging.info("Successfully authenticated.") + + sleep_time = 15 + + logging.info("Downloading File....Sleeping for {} Seconds.".format(str(sleep_time))) + time.sleep(sleep_time) + + driver.close() + vdisplay.stop() + + dest_s3 = S3Hook(s3_conn_id=s3_conn_id) + + os.chdir("/root/Downloads") + + csvfile = open(filename, "r") + + output_json = "file.json" + + with open(output_json, "w") as jsonfile: + reader = csv.DictReader(csvfile) + + for row in reader: + # row = dict((boa.constrict(k), v) for k, v in row.items()) + row["run_date"] = date + json.dump(row, jsonfile) + jsonfile.write("\n") + + dest_s3.load_file( + filename=output_json, key=s3_key, bucket_name=s3_bucket, replace=True + ) + + dest_s3.connection.close() + + +with dag: + + kick_off_dag = DummyOperator(task_id="kick_off_dag") + + selenium = PythonOperator( + task_id="selenium_retrieval_to_s3", + python_callable=imap_py, + templates_dict={ + "s3_conn_id": S3_CONN_ID, + "s3_bucket": S3_BUCKET, + "s3_key": S3_KEY, + "date": date, + }, + provide_context=True, + ) + + kick_off_dag >> selenium diff --git a/dags/selenium.py b/dags/selenium.py new file mode 100644 index 0000000..2091db2 --- /dev/null +++ b/dags/selenium.py @@ -0,0 +1,65 @@ +from datetime import timedelta + +from airflow import DAG +from airflow.operators.bash import BashOperator +from airflow.operators.dummy_operator import DummyOperator +from airflow.providers.docker.operators.docker import DockerOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": 
["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), +} + +dag = DAG( + "selenoid_setup", + default_args=default_args, + schedule_interval="@daily", + start_date=days_ago(2), +) + + +with dag: + + kick_off_dag = DummyOperator(task_id="kick_off_dag") + + t1 = DockerOperator( + # api_version='1.19', + # docker_url='tcp://localhost:2375', # Set your docker URL + command="/bin/sleep 30", + image="selenoid/chrome:latest", + network_mode="bridge", + task_id="selenoid1", + # dag=dag, + ) + + t2 = DockerOperator( + # api_version='1.19', + # docker_url='tcp://localhost:2375', # Set your docker URL + command="/bin/sleep 30", + image="selenoid/video-recorder:latest-release", + network_mode="bridge", + task_id="selenoid2", + # dag=dag, + ) + + scrape = BashOperator( + task_id="pull_selenoid_video-recorder", + bash_command="docker pull selenoid/video-recorder:latest-release", + # retries=3, + # dag=dag + ) + + scrape2 = BashOperator( + task_id="pull_selenoid_chrome", + bash_command="docker pull selenoid/chrome:latest", + # retries=3, + # dag=dag + ) + + scrape >> t1 >> t2 diff --git a/dags/sql/search.sql b/dags/sql/search.sql new file mode 100644 index 0000000..efb2075 --- /dev/null +++ b/dags/sql/search.sql @@ -0,0 +1 @@ +SELECT * FROM pet WHERE birth_date BETWEEN SYMMETRIC {{ params.begin_date }} AND {{ params.end_date }}; \ No newline at end of file diff --git a/dags/sqlite.py b/dags/sqlite.py new file mode 100644 index 0000000..7a3f742 --- /dev/null +++ b/dags/sqlite.py @@ -0,0 +1,122 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example DAG for the use of the SqliteOperator. +In this example, we create two tasks that execute in sequence. +The first task calls an sql command, defined in the SQLite operator, +which when triggered, is performed on the connected sqlite database. +The second task is similar but instead calls the SQL command from an external file. +""" +import apprise +from airflow import DAG +from airflow.operators.bash import BashOperator +from airflow.operators.dummy_operator import DummyOperator +from airflow.operators.python_operator import PythonOperator +from airflow.providers.docker.operators.docker import DockerOperator +from airflow.providers.sqlite.operators.sqlite import SqliteOperator +from airflow.utils.dates import days_ago + +default_args = {"owner": "airflow"} + +dag = DAG( + dag_id="example_sqlite", + default_args=default_args, + schedule_interval="@hourly", + start_date=days_ago(2), + tags=["example"], +) + +with dag: + + # [START howto_operator_sqlite] + + # Example of creating a task that calls a common CREATE TABLE sql command. 
+ # def notify(): + + # # Create an Apprise instance + # apobj = apprise.Apprise() + + # # Add all of the notification services by their server url. + # # A sample email notification: + # # apobj.add('mailto://myuserid:mypass@gmail.com') + + # # A sample pushbullet notification + # apobj.add('pover://aejghiy6af1bshe8mbdksmkzeon3ip@umjiu36dxwwaj8pnfx3n6y2xbm3ssx') + + # # Then notify these services any time you desire. The below would + # # notify all of the services loaded into our Apprise object. + # apobj.notify( + # body='what a great notification service!', + # title='my notification title', + # ) + # return apobj + + # apprise = PythonOperator( + # task_id="apprise", + # python_callable=notify, + # dag=dag + # # "s3_bucket": S3_BUCKET, + # # "s3_key": S3_KEY, + # # "date": date} + # # provide_context=True + # ) + + # t2 = DockerOperator( + # task_id='docker_command', + # image='selenoid/chrome:latest', + # api_version='auto', + # auto_remove=False, + # command="/bin/sleep 30", + # docker_url="unix://var/run/docker.sock", + # network_mode="bridge" + # ) + + start = DummyOperator(task_id="start") + + docker = BashOperator( + task_id="pull_selenoid_2", + bash_command="sudo docker pull selenoid/chrome:latest", + # retries=3, + # dag=dag + ) + + pre2 = BashOperator( + task_id="apprise", + bash_command="apprise -vv -t 'my title' -b 'my notification body' pover://umjiu36dxwwaj8pnfx3n6y2xbm3ssx@aejghiy6af1bshe8mbdksmkzeon3ip", + # retries=3, + # dag=dag + ) + + # t3 = DockerOperator( + # api_version="1.19", + # docker_url="tcp://localhost:2375", # Set your docker URL + # command="/bin/sleep 30", + # image="centos:latest", + # network_mode="bridge", + # task_id="docker_op_tester", + # # dag=dag, + # ) + + end = DummyOperator(task_id="end") + + +start >> docker >> pre2 >> end +# [END howto_operator_sqlite_external_file] + +# create_table_sqlite_task >> external_create_table_sqlite_task +# diff --git a/dags/tutorial_etl_dag.py b/dags/tutorial_etl_dag.py new file mode 100644 index 0000000..30cc02c --- /dev/null +++ b/dags/tutorial_etl_dag.py @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=missing-function-docstring +""" +### ETL DAG Tutorial Documentation +This ETL DAG is compatible with Airflow 1.10.x (specifically tested with 1.10.12) and is referenced +as part of the documentation that goes along with the Airflow Functional DAG tutorial located +[here](https://airflow.apache.org/tutorial_decorated_flows.html) +""" +# [START tutorial] +# [START import_module] +import json +from textwrap import dedent + +# The DAG object; we'll need this to instantiate a DAG +from airflow import DAG + +# Operators; we need this to operate! 
+from airflow.operators.python import PythonOperator +from airflow.utils.dates import days_ago + +# [END import_module] + +# [START default_args] +# These args will get passed on to each operator +# You can override them on a per-task basis during operator initialization +default_args = { + "owner": "donaldrich", + "email": ["email@gmail.com"], +} +# [END default_args] + +# [START instantiate_dag] +with DAG( + "tutorial_etl_dag", + default_args=default_args, + description="ETL DAG tutorial", + schedule_interval=None, + start_date=days_ago(2), + tags=["example"], +) as dag: + # [END instantiate_dag] + # [START documentation] + dag.doc_md = __doc__ + + # [END documentation] + + # [START extract_function] + def extract(**kwargs): + ti = kwargs["ti"] + data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' + ti.xcom_push("order_data", data_string) + + # [END extract_function] + + # [START transform_function] + def transform(**kwargs): + ti = kwargs["ti"] + extract_data_string = ti.xcom_pull(task_ids="extract", + key="order_data") + order_data = json.loads(extract_data_string) + + total_order_value = 0 + for value in order_data.values(): + total_order_value += value + + total_value = {"total_order_value": total_order_value} + total_value_json_string = json.dumps(total_value) + ti.xcom_push("total_order_value", total_value_json_string) + + # [END transform_function] + + # [START load_function] + def load(**kwargs): + ti = kwargs["ti"] + total_value_string = ti.xcom_pull(task_ids="transform", + key="total_order_value") + total_order_value = json.loads(total_value_string) + + print(total_order_value) + + # [END load_function] + + # [START main_flow] + extract_task = PythonOperator( + task_id="extract", + python_callable=extract, + ) + extract_task.doc_md = dedent("""\ + #### Extract task + A simple Extract task to get data ready for the rest of the data pipeline. + In this case, getting data is simulated by reading from a hardcoded JSON string. + This data is then put into xcom, so that it can be processed by the next task. + """) + + transform_task = PythonOperator( + task_id="transform", + python_callable=transform, + ) + transform_task.doc_md = dedent("""\ + #### Transform task + A simple Transform task which takes in the collection of order data from xcom + and computes the total order value. + This computed value is then put into xcom, so that it can be processed by the next task. + """) + + load_task = PythonOperator( + task_id="load", + python_callable=load, + ) + load_task.doc_md = dedent("""\ + #### Load task + A simple Load task which takes in the result of the Transform task, by reading it + from xcom and instead of saving it to end user review, just prints it out. 
+ """) + + extract_task >> transform_task >> load_task + +# [END main_flow] + +# [END tutorial] diff --git a/dags/vault.py b/dags/vault.py new file mode 100644 index 0000000..2dc2881 --- /dev/null +++ b/dags/vault.py @@ -0,0 +1,32 @@ +import os +from datetime import datetime +from os import environ + +from airflow import DAG +from airflow.hooks.base_hook import BaseHook +from airflow.operators.python_operator import PythonOperator + +os.environ[ + "AIRFLOW__SECRETS__BACKEND" +] = "airflow.providers.hashicorp.secrets.vault.VaultBackend" +os.environ[ + "AIRFLOW__SECRETS__BACKEND_KWARGS"] = '{"connections_path": "myapp", "mount_point": "secret", "auth_type": "token", "token": "token", "url": "http://vault:8200"}' + + +def get_secrets(**kwargs): + conn = BaseHook.get_connection(kwargs["my_conn_id"]) + print("Password:", {conn.password}) + print(" Login:", {conn.login}) + print(" URI:", {conn.get_uri()}) + print("Host:", {conn.host}) + + +with DAG( + "vault_example", start_date=datetime(2020, 1, 1), schedule_interval=None +) as dag: + + test_task = PythonOperator( + task_id="test-task", + python_callable=get_secrets, + op_kwargs={"my_conn_id": "smtp_default"}, + ) diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..d377b65 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,186 @@ +version: "3.7" + +x-airflow-common: &airflow-common + image: donaldrich/airflow:latest + # build: . + environment: &airflow-common-env + AIRFLOW__CORE__STORE_SERIALIZED_DAGS: "True" + AIRFLOW__CORE__STORE_DAG_CODE: "True" + AIRFLOW__CORE__EXECUTOR: "CeleryExecutor" + AIRFLOW__CORE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow@postgres-dev:5432/airflow + AIRFLOW__CELERY__RESULT_BACKEND: db+postgresql://airflow:airflow@postgres-dev:5432/airflow + AIRFLOW__CELERY__BROKER_URL: redis://:@redis:6379/0 + AIRFLOW__CORE__FERNET_KEY: "" + AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION: "false" + AIRFLOW__CORE__LOAD_EXAMPLES: + "false" + # AIRFLOW__CORE__PARALLELISM: 4 + # AIRFLOW__CORE__DAG_CONCURRENCY: 4 + # AIRFLOW__CORE__MAX_ACTIVE_RUNS_PER_DAG: 4 + AIRFLOW_UID: "1000" + AIRFLOW_GID: "0" + _AIRFLOW_DB_UPGRADE: "true" + _AIRFLOW_WWW_USER_CREATE: "true" + _AIRFLOW_WWW_USER_USERNAME: "{{ user }}" + _AIRFLOW_WWW_USER_PASSWORD: "{{ password }}" + PYTHONPATH: "/data:$$PYTHONPATH" + user: "1000:0" + volumes: + - "./airflow:/opt/airflow" + - "./data:/data" + - "/var/run/docker.sock:/var/run/docker.sock" + # - "/usr/bin/docker:/bin/docker:ro" + networks: + - proxy + - backend + +services: + airflow-init: + <<: *airflow-common + container_name: airflow-init + environment: + <<: *airflow-common-env + # depends_on: + # - airflow-db + command: bash -c "airflow db init && airflow db upgrade && airflow users create --role Admin --username {{ user }} --email {{ email }} --firstname Don --lastname Aldrich --password {{ password }}" + + airflow-webserver: + <<: *airflow-common + container_name: airflow-webserver + hostname: airflow-webserver + command: + webserver + # command: > + # bash -c 'if [[ -z "$$AIRFLOW__API__AUTH_BACKEND" ]] && [[ $$(pip show -f apache-airflow | grep basic_auth.py) ]]; + # then export AIRFLOW__API__AUTH_BACKEND=airflow.api.auth.backend.basic_auth ; + # else export AIRFLOW__API__AUTH_BACKEND=airflow.api.auth.backend.default ; fi && + # { airflow create_user "$$@" || airflow users create "$$@" ; } && + # { airflow sync_perm || airflow sync-perm ;} && + # airflow webserver' -- -r Admin -u admin -e admin@example.com -f admin -l user -p admin + healthcheck: + test: ["CMD", 
"curl", "--fail", "http://localhost:8080/health"] + interval: 10s + timeout: 10s + retries: 5 + restart: always + privileged: true + depends_on: + - airflow-scheduler + environment: + <<: *airflow-common-env + labels: + traefik.http.services.airflow.loadbalancer.server.port: "8080" + traefik.enable: "true" + traefik.http.routers.airflow.entrypoints: "https" + traefik.http.routers.airflow.tls.certResolver: "cloudflare" + traefik.http.routers.airflow.rule: "Host(`airflow.{{ domain }}.com`)" + traefik.http.routers.airflow.middlewares: "ip-whitelist@file" + traefik.http.routers.airflow.service: "airflow" + + airflow-scheduler: + <<: *airflow-common + command: scheduler + container_name: airflow-scheduler + hostname: airflow-scheduler + restart: always + depends_on: + - airflow-init + environment: + <<: *airflow-common-env + # # entrypoint: sh -c '/app/scripts/wait-for postgres:5432 -- airflow db init && airflow scheduler' + + airflow-worker: + <<: *airflow-common + command: celery worker + restart: always + container_name: airflow-worker + hostname: airflow-worker + + airflow-queue: + <<: *airflow-common + command: celery flower + container_name: airflow-queue + hostname: airflow-queue + ports: + - 5555:5555 + healthcheck: + test: ["CMD", "curl", "--fail", "http://localhost:5555/"] + interval: 10s + timeout: 10s + retries: 5 + restart: always + + dbt: + image: fishtownanalytics/dbt:0.19.1 + container_name: dbt + volumes: + - "/home/{{ user }}/projects/jobfunnel:/data" + - "/home/{{ user }}/projects/jobfunnel/dbt:/usr/app" + # - "dbt-db:/var/lib/postgresql/data" + ports: + - "8081" + networks: + - proxy + - backend + command: docs serve --project-dir /data/transform + working_dir: "/data/transform" + environment: + DBT_SCHEMA: dbt + DBT_RAW_DATA_SCHEMA: dbt_raw_data + DBT_PROFILES_DIR: "/data/transform/profile" + # DBT_PROJECT_DIR: "/data/transform" + # DBT_PROFILES_DIR: "/data" + # DBT_POSTGRES_PASSWORD: dbt + # DBT_POSTGRES_USER : dbt + # DBT_POSTGRES_DB : dbt + # DBT_DBT_SCHEMA: dbt + # DBT_DBT_RAW_DATA_SCHEMA: dbt_raw_data + # DBT_POSTGRES_HOST: dbt-db + + meltano: + image: meltano/meltano:latest + container_name: meltano + volumes: + - "/home/{{ user }}/projects/jobfunnel:/project" + ports: + - "5000:5000" + networks: + - proxy + - backend + # environment: + # MELTANO_UI_SERVER_NAME: "etl.{{ domain }}.com" + # MELTANO_DATABASE_URI: "postgresql://meltano:meltano@meltano-db:5432/meltano" + # MELTANO_WEBAPP_POSTGRES_URL: localhost + # MELTANO_WEBAPP_POSTGRES_DB=meltano + # MELTANO_WEBAPP_POSTGRES_USER=meltano + # MELTANO_WEBAPP_POSTGRES_PASSWORD=meltano + # MELTANO_WEBAPP_POSTGRES_PORT=5501 + # MELTANO_WEBAPP_LOG_PATH="/tmp/meltano.log" + # MELTANO_WEBAPP_API_URL="http://localhost:5000" + # LOG_PATH="/tmp/meltano.log" + # API_URL="http://localhost:5000" + # MELTANO_MODEL_DIR="./model" + # MELTANO_TRANSFORM_DIR="./transform" + # MELTANO_UI_SESSION_COOKIE_DOMAIN: "etl.{{ domain }}.com" + # MELTANO_UI_SESSION_COOKIE_SECURE: "true" + labels: + traefik.http.services.meltano.loadbalancer.server.port: "5000" + traefik.enable: "true" + traefik.http.routers.meltano.entrypoints: "https" + traefik.http.routers.meltano.tls.certResolver: "cloudflare" + traefik.http.routers.meltano.rule: "Host(`etl.{{ domain }}.com`)" + traefik.http.routers.meltano.middlewares: "secured@file,ip-whitelist@file" + traefik.http.routers.meltano.service: "meltano" + +volumes: + tmp_airflow: + driver: local + airflow-db: + driver: local + dbt-db: + driver: local +networks: + proxy: + external: true + backend: + external: true 
diff --git a/plugins/__init__.py b/plugins/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/plugins/config/__init__.py b/plugins/config/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/plugins/config/logger.py b/plugins/config/logger.py new file mode 100644 index 0000000..fe24abd --- /dev/null +++ b/plugins/config/logger.py @@ -0,0 +1,49 @@ +import csv +import logging +import os +import platform +import random +import re +import time + +# from webdriver_manager.chrome import ChromeDriverManager +from datetime import datetime +from urllib.request import urlopen + +# import pyautogui +import pandas as pd +import yaml +from bs4 import BeautifulSoup + +# import mouseinfo +from selenium import webdriver +from selenium.common.exceptions import NoSuchElementException, TimeoutException +from selenium.webdriver.chrome.options import Options +from selenium.webdriver.common.by import By +from selenium.webdriver.common.desired_capabilities import DesiredCapabilities +from selenium.webdriver.common.keys import Keys +from selenium.webdriver.support import expected_conditions as EC +from selenium.webdriver.support.ui import WebDriverWait + + +def setupLogger(): + dt = datetime.strftime(datetime.now(), "%m_%d_%y %H_%M_%S ") + + if not os.path.isdir("./logs"): + os.mkdir("./logs") + + # TODO need to check if there is a log dir available or not + logging.basicConfig( + filename=("./logs/" + str(dt) + "applyJobs.log"), + filemode="w", + format="%(asctime)s::%(name)s::%(levelname)s::%(message)s", + datefmt="./logs/%d-%b-%y %H:%M:%S", + ) + log.setLevel(logging.DEBUG) + c_handler = logging.StreamHandler() + c_handler.setLevel(logging.DEBUG) + c_format = logging.Formatter( + "%(asctime)s - %(levelname)s - %(message)s", "%H:%M:%S" + ) + c_handler.setFormatter(c_format) + log.addHandler(c_handler) diff --git a/plugins/hooks/discord/discord.py b/plugins/hooks/discord/discord.py new file mode 100644 index 0000000..076adbc --- /dev/null +++ b/plugins/hooks/discord/discord.py @@ -0,0 +1,145 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import json +import re +from typing import Any, Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.http.hooks.http import HttpHook + + +class DiscordWebhookHook(HttpHook): + """ + This hook allows you to post messages to Discord using incoming webhooks. + Takes a Discord connection ID with a default relative webhook endpoint. The + default endpoint can be overridden using the webhook_endpoint parameter + (https://discordapp.com/developers/docs/resources/webhook). + Each Discord webhook can be pre-configured to use a specific username and + avatar_url. You can override these defaults in this hook. 
+ :param http_conn_id: Http connection ID with host as "https://discord.com/api/" and + default webhook endpoint in the extra field in the form of + {"webhook_endpoint": "webhooks/{webhook.id}/{webhook.token}"} + :type http_conn_id: str + :param webhook_endpoint: Discord webhook endpoint in the form of + "webhooks/{webhook.id}/{webhook.token}" + :type webhook_endpoint: str + :param message: The message you want to send to your Discord channel + (max 2000 characters) + :type message: str + :param username: Override the default username of the webhook + :type username: str + :param avatar_url: Override the default avatar of the webhook + :type avatar_url: str + :param tts: Is a text-to-speech message + :type tts: bool + :param proxy: Proxy to use to make the Discord webhook call + :type proxy: str + """ + + def __init__( + self, + http_conn_id: Optional[str] = None, + webhook_endpoint: Optional[str] = None, + message: str = "", + username: Optional[str] = None, + avatar_url: Optional[str] = None, + tts: bool = False, + proxy: Optional[str] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.http_conn_id: Any = http_conn_id + self.webhook_endpoint = self._get_webhook_endpoint( + http_conn_id, webhook_endpoint + ) + self.message = message + self.username = username + self.avatar_url = avatar_url + self.tts = tts + self.proxy = proxy + + def _get_webhook_endpoint( + self, http_conn_id: Optional[str], webhook_endpoint: Optional[str] + ) -> str: + """ + Given a Discord http_conn_id, return the default webhook endpoint or override if a + webhook_endpoint is manually supplied. + :param http_conn_id: The provided connection ID + :param webhook_endpoint: The manually provided webhook endpoint + :return: Webhook endpoint (str) to use + """ + if webhook_endpoint: + endpoint = webhook_endpoint + elif http_conn_id: + conn = self.get_connection(http_conn_id) + extra = conn.extra_dejson + endpoint = extra.get("webhook_endpoint", "") + else: + raise AirflowException( + "Cannot get webhook endpoint: No valid Discord webhook endpoint or http_conn_id supplied." + ) + + # make sure endpoint matches the expected Discord webhook format + if not re.match("^webhooks/[0-9]+/[a-zA-Z0-9_-]+$", endpoint): + raise AirflowException( + 'Expected Discord webhook endpoint in the form of "webhooks/{webhook.id}/{webhook.token}".' + ) + + return endpoint + + def _build_discord_payload(self) -> str: + """ + Construct the Discord JSON payload. All relevant parameters are combined here + to a valid Discord JSON payload. + :return: Discord payload (str) to send + """ + payload: Dict[str, Any] = {} + + if self.username: + payload["username"] = self.username + if self.avatar_url: + payload["avatar_url"] = self.avatar_url + + payload["tts"] = self.tts + + if len(self.message) <= 2000: + payload["content"] = self.message + else: + raise AirflowException( + "Discord message length must be 2000 or fewer characters." 
+ ) + + return json.dumps(payload) + + def execute(self) -> None: + """Execute the Discord webhook call""" + proxies = {} + if self.proxy: + # we only need https proxy for Discord + proxies = {"https": self.proxy} + + discord_payload = self._build_discord_payload() + + self.run( + endpoint=self.webhook_endpoint, + data=discord_payload, + headers={"Content-type": "application/json"}, + extra_options={"proxies": proxies}, + ) diff --git a/plugins/hooks/google.py b/plugins/hooks/google.py new file mode 100644 index 0000000..b57c4a5 --- /dev/null +++ b/plugins/hooks/google.py @@ -0,0 +1,108 @@ +""" +There are two ways to authenticate the Google Analytics Hook. + +If you have already obtained an OAUTH token, place it in the password field +of the relevant connection. + +If you don't have an OAUTH token, you may authenticate by passing a +'client_secrets' object to the extras section of the relevant connection. This +object will expect the following fields and use them to generate an OAUTH token +on execution. + +"type": "service_account", +"project_id": "example-project-id", +"private_key_id": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", +"private_key": "-----BEGIN PRIVATE KEY-----\nXXXXX\n-----END PRIVATE KEY-----\n", +"client_email": "google-analytics@{PROJECT_ID}.iam.gserviceaccount.com", +"client_id": "XXXXXXXXXXXXXXXXXXXXXX", +"auth_uri": "https://accounts.google.com/o/oauth2/auth", +"token_uri": "https://accounts.google.com/o/oauth2/token", +"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", +"client_x509_cert_url": "{CERT_URL}" + +More details can be found here: +https://developers.google.com/api-client-library/python/guide/aaa_client_secrets +""" + +import time + +from airflow.hooks.base_hook import BaseHook +from apiclient.discovery import build +from oauth2client.client import AccessTokenCredentials +from oauth2client.service_account import ServiceAccountCredentials + + +class GoogleHook(BaseHook): + def __init__(self, google_conn_id="google_default"): + self.google_analytics_conn_id = google_conn_id + self.connection = self.get_connection(google_conn_id) + + if self.connection.extra_dejson.get("client_secrets", None): + self.client_secrets = self.connection.extra_dejson["client_secrets"] + + def get_service_object(self, api_name, api_version, scopes=None): + if self.connection.password: + credentials = AccessTokenCredentials( + self.connection.password, "Airflow/1.0" + ) + elif self.client_secrets: + credentials = ServiceAccountCredentials.from_json_keyfile_dict( + self.client_secrets, scopes + ) + + return build(api_name, api_version, credentials=credentials) + + def get_analytics_report( + self, + view_id, + since, + until, + sampling_level, + dimensions, + metrics, + page_size, + include_empty_rows, + ): + analytics = self.get_service_object( + "analyticsreporting", + "v4", + ["https://www.googleapis.com/auth/analytics.readonly"], + ) + + reportRequest = { + "viewId": view_id, + "dateRanges": [{"startDate": since, "endDate": until}], + "samplingLevel": sampling_level or "LARGE", + "dimensions": dimensions, + "metrics": metrics, + "pageSize": page_size or 1000, + "includeEmptyRows": include_empty_rows or False, + } + + response = ( + analytics.reports() + .batchGet(body={"reportRequests": [reportRequest]}) + .execute() + ) + + if response.get("reports"): + report = response["reports"][0] + rows = report.get("data", {}).get("rows", []) + + while report.get("nextPageToken"): + time.sleep(1) + reportRequest.update({"pageToken": report["nextPageToken"]}) + response = ( 
+ analytics.reports() + .batchGet(body={"reportRequests": [reportRequest]}) + .execute() + ) + report = response["reports"][0] + rows.extend(report.get("data", {}).get("rows", [])) + + if report["data"]: + report["data"]["rows"] = rows + + return report + else: + return {} diff --git a/plugins/hooks/selenium_hook.py b/plugins/hooks/selenium_hook.py new file mode 100644 index 0000000..283c922 --- /dev/null +++ b/plugins/hooks/selenium_hook.py @@ -0,0 +1,95 @@ +import logging +import logging as log +import os +import time + +import docker +from airflow.hooks.base_hook import BaseHook +from selenium import webdriver +from selenium.webdriver.chrome.options import Options +from selenium.webdriver.common.desired_capabilities import DesiredCapabilities + + +class SeleniumHook(BaseHook): + """ + Creates a Selenium Docker container on the host and controls the + browser by sending commands to the remote server. + """ + + def __init__(self): + logging.info("initialised hook") + pass + + def create_container(self): + """ + Creates the selenium docker container + """ + logging.info("creating_container") + cwd = os.getcwd() + self.local_downloads = os.path.join(cwd, "downloads") + self.sel_downloads = "/home/seluser/downloads" + volumes = [ + "{}:{}".format(self.local_downloads, self.sel_downloads), + "/dev/shm:/dev/shm", + ] + client = docker.from_env() + container = client.containers.run( + "selenium/standalone-chrome", + volumes=volumes, + network="container_bridge", + detach=True, + ) + self.container = container + cli = docker.APIClient() + self.container_ip = cli.inspect_container(container.id)["NetworkSettings"][ + "Networks" + ]["container_bridge"]["IPAddress"] + + def create_driver(self): + """ + creates and configure the remote Selenium webdriver. + """ + logging.info("creating driver") + options = Options() + options.add_argument("--headless") + options.add_argument("--window-size=1920x1080") + chrome_driver = "{}:4444/wd/hub".format(self.container_ip) + # chrome_driver = '{}:4444/wd/hub'.format('http://127.0.0.1') # local + # wait for remote, unless timeout. + while True: + try: + driver = webdriver.Remote( + command_executor=chrome_driver, + desired_capabilities=DesiredCapabilities.CHROME, + options=options, + ) + print("remote ready") + break + except: + print("remote not ready, sleeping for ten seconds.") + time.sleep(10) + # Enable downloads in headless chrome. + driver.command_executor._commands["send_command"] = ( + "POST", + "/session/$sessionId/chromium/send_command", + ) + params = { + "cmd": "Page.setDownloadBehavior", + "params": {"behavior": "allow", "downloadPath": self.sel_downloads}, + } + driver.execute("send_command", params) + self.driver = driver + + def remove_container(self): + """ + This removes the Selenium container. + """ + self.container.remove(force=True) + print("Removed container: {}".format(self.container.id)) + + def run_script(self, script, args): + """ + This is a wrapper around the python script which sends commands to + the docker container. The first variable of the script must be the web driver. 
+ """ + script(self.driver, *args) diff --git a/plugins/operators/__init__.py b/plugins/operators/__init__.py new file mode 100644 index 0000000..8dcf697 --- /dev/null +++ b/plugins/operators/__init__.py @@ -0,0 +1,2 @@ +from operators.singer import SingerOperator +from operators.sudo_bash import SudoBashOperator diff --git a/plugins/operators/discord.py b/plugins/operators/discord.py new file mode 100644 index 0000000..495970f --- /dev/null +++ b/plugins/operators/discord.py @@ -0,0 +1,95 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook +from airflow.providers.http.operators.http import SimpleHttpOperator +from airflow.utils.decorators import apply_defaults + + +class DiscordWebhookOperator(SimpleHttpOperator): + """ + This operator allows you to post messages to Discord using incoming webhooks. + Takes a Discord connection ID with a default relative webhook endpoint. The + default endpoint can be overridden using the webhook_endpoint parameter + (https://discordapp.com/developers/docs/resources/webhook). + Each Discord webhook can be pre-configured to use a specific username and + avatar_url. You can override these defaults in this operator. + :param http_conn_id: Http connection ID with host as "https://discord.com/api/" and + default webhook endpoint in the extra field in the form of + {"webhook_endpoint": "webhooks/{webhook.id}/{webhook.token}"} + :type http_conn_id: str + :param webhook_endpoint: Discord webhook endpoint in the form of + "webhooks/{webhook.id}/{webhook.token}" + :type webhook_endpoint: str + :param message: The message you want to send to your Discord channel + (max 2000 characters). (templated) + :type message: str + :param username: Override the default username of the webhook. 
(templated) + :type username: str + :param avatar_url: Override the default avatar of the webhook + :type avatar_url: str + :param tts: Is a text-to-speech message + :type tts: bool + :param proxy: Proxy to use to make the Discord webhook call + :type proxy: str + """ + + template_fields = ["username", "message"] + + @apply_defaults + def __init__( + self, + *, + http_conn_id: Optional[str] = None, + webhook_endpoint: Optional[str] = None, + message: str = "", + username: Optional[str] = None, + avatar_url: Optional[str] = None, + tts: bool = False, + proxy: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(endpoint=webhook_endpoint, **kwargs) + + if not http_conn_id: + raise AirflowException("No valid Discord http_conn_id supplied.") + + self.http_conn_id = http_conn_id + self.webhook_endpoint = webhook_endpoint + self.message = message + self.username = username + self.avatar_url = avatar_url + self.tts = tts + self.proxy = proxy + self.hook: Optional[DiscordWebhookHook] = None + + def execute(self, context: Dict) -> None: + """Call the DiscordWebhookHook to post message""" + self.hook = DiscordWebhookHook( + self.http_conn_id, + self.webhook_endpoint, + self.message, + self.username, + self.avatar_url, + self.tts, + self.proxy, + ) + self.hook.execute() diff --git a/plugins/operators/google.py b/plugins/operators/google.py new file mode 100644 index 0000000..d2d9606 --- /dev/null +++ b/plugins/operators/google.py @@ -0,0 +1,254 @@ +import gzip +import json +import logging +import os +import time + +import boa +from airflow.hooks.base_hook import BaseHook +from airflow.hooks.S3_hook import S3Hook +from airflow.models import BaseOperator, Variable +from hooks.google import GoogleHook +from six import BytesIO + + +class GoogleSheetsToS3Operator(BaseOperator): + """ + Google Sheets To S3 Operator + + :param google_conn_id: The Google connection id. + :type google_conn_id: string + :param sheet_id: The id for associated report. + :type sheet_id: string + :param sheet_names: The name for the relevent sheets in the report. + :type sheet_names: string/array + :param range: The range of of cells containing the relevant data. + This must be the same for all sheets if multiple + are being pulled together. + Example: Sheet1!A2:E80 + :type range: string + :param include_schema: If set to true, infer the schema of the data and + output to S3 as a separate file + :type include_schema: boolean + :param s3_conn_id: The s3 connection id. + :type s3_conn_id: string + :param s3_key: The S3 key to be used to store the + retrieved data. 
+ :type s3_key: string + """ + + template_fields = ("s3_key",) + + def __init__( + self, + google_conn_id, + sheet_id, + s3_conn_id, + s3_key, + compression_bound, + include_schema=False, + sheet_names=[], + range=None, + output_format="json", + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.google_conn_id = google_conn_id + self.sheet_id = sheet_id + self.sheet_names = sheet_names + self.s3_conn_id = s3_conn_id + self.s3_key = s3_key + self.include_schema = include_schema + self.range = range + self.output_format = output_format.lower() + self.compression_bound = compression_bound + if self.output_format not in ("json"): + raise Exception("Acceptable output formats are: json.") + + if self.sheet_names and not isinstance(self.sheet_names, (str, list, tuple)): + raise Exception("Please specify the sheet names as a string or list.") + + def execute(self, context): + g_conn = GoogleHook(self.google_conn_id) + + if isinstance(self.sheet_names, str) and "," in self.sheet_names: + sheet_names = self.sheet_names.split(",") + else: + sheet_names = self.sheet_names + + sheets_object = g_conn.get_service_object("sheets", "v4") + logging.info("Retrieved Sheets Object") + + response = ( + sheets_object.spreadsheets() + .get(spreadsheetId=self.sheet_id, includeGridData=True) + .execute() + ) + + title = response.get("properties").get("title") + sheets = response.get("sheets") + + final_output = dict() + + total_sheets = [] + for sheet in sheets: + name = sheet.get("properties").get("title") + name = boa.constrict(name) + total_sheets.append(name) + + if self.sheet_names: + if name not in sheet_names: + logging.info( + "{} is not found in available sheet names.".format(name) + ) + continue + + table_name = name + data = sheet.get("data")[0].get("rowData") + output = [] + + for row in data: + row_data = [] + values = row.get("values") + for value in values: + ev = value.get("effectiveValue") + if ev is None: + row_data.append(None) + else: + for v in ev.values(): + row_data.append(v) + + output.append(row_data) + + if self.output_format == "json": + headers = output.pop(0) + output = [dict(zip(headers, row)) for row in output] + + final_output[table_name] = output + + s3 = S3Hook(self.s3_conn_id) + + for sheet in final_output: + output_data = final_output.get(sheet) + + file_name, file_extension = os.path.splitext(self.s3_key) + + output_name = "".join([file_name, "_", sheet, file_extension]) + + if self.include_schema is True: + schema_name = "".join( + [file_name, "_", sheet, "_schema", file_extension] + ) + + self.output_manager( + s3, output_name, output_data, context, sheet, schema_name + ) + + dag_id = context["ti"].dag_id + + var_key = "_".join([dag_id, self.sheet_id]) + Variable.set(key=var_key, value=json.dumps(total_sheets)) + time.sleep(10) + + return boa.constrict(title) + + def output_manager( + self, s3, output_name, output_data, context, sheet_name, schema_name=None + ): + self.s3_bucket = BaseHook.get_connection(self.s3_conn_id).host + if self.output_format == "json": + output = "\n".join( + [ + json.dumps({boa.constrict(str(k)): v for k, v in record.items()}) + for record in output_data + ] + ) + + enc_output = str.encode(output, "utf-8") + + # if file is more than bound then apply gzip compression + if len(enc_output) / 1024 / 1024 >= self.compression_bound: + logging.info( + "File is more than {}MB, gzip compression will be applied".format( + self.compression_bound + ) + ) + output = gzip.compress(enc_output, compresslevel=5) + self.xcom_push( + context, + 
key="is_compressed_{}".format(sheet_name), + value="compressed", + ) + self.load_bytes( + s3, + bytes_data=output, + key=output_name, + bucket_name=self.s3_bucket, + replace=True, + ) + else: + logging.info( + "File is less than {}MB, compression will not be applied".format( + self.compression_bound + ) + ) + self.xcom_push( + context, + key="is_compressed_{}".format(sheet_name), + value="non-compressed", + ) + s3.load_string( + string_data=output, + key=output_name, + bucket_name=self.s3_bucket, + replace=True, + ) + + if self.include_schema is True: + output_keys = output_data[0].keys() + schema = [ + {"name": boa.constrict(a), "type": "varchar(512)"} + for a in output_keys + if a is not None + ] + schema = {"columns": schema} + + s3.load_string( + string_data=json.dumps(schema), + key=schema_name, + bucket_name=self.s3_bucket, + replace=True, + ) + + logging.info('Successfully output of "{}" to S3.'.format(output_name)) + + # TODO -- Add support for csv output + + # elif self.output_format == 'csv': + # with NamedTemporaryFile("w") as f: + # writer = csv.writer(f) + # writer.writerows(output_data) + # s3.load_file( + # filename=f.name, + # key=output_name, + # bucket_name=self.s3_bucket, + # replace=True + # ) + # + # if self.include_schema is True: + # pass + + # TODO: remove when airflow version is upgraded to 1.10 + def load_bytes(self, s3, bytes_data, key, bucket_name=None, replace=False): + if not bucket_name: + (bucket_name, key) = s3.parse_s3_url(key) + + if not replace and s3.check_for_key(key, bucket_name): + raise ValueError("The key {key} already exists.".format(key=key)) + + filelike_buffer = BytesIO(bytes_data) + + client = s3.get_conn() + client.upload_fileobj(filelike_buffer, bucket_name, key, ExtraArgs={}) diff --git a/plugins/operators/selenium_operator.py b/plugins/operators/selenium_operator.py new file mode 100644 index 0000000..62d7af6 --- /dev/null +++ b/plugins/operators/selenium_operator.py @@ -0,0 +1,24 @@ +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from hooks.selenium_hook import SeleniumHook + + +class SeleniumOperator(BaseOperator): + """ + Selenium Operator + """ + + template_fields = ["script_args"] + + @apply_defaults + def __init__(self, script, script_args, *args, **kwargs): + super().__init__(*args, **kwargs) + self.script = script + self.script_args = script_args + + def execute(self, context): + hook = SeleniumHook() + hook.create_container() + hook.create_driver() + hook.run_script(self.script, self.script_args) + hook.remove_container() diff --git a/plugins/operators/singer.py b/plugins/operators/singer.py new file mode 100644 index 0000000..deb0678 --- /dev/null +++ b/plugins/operators/singer.py @@ -0,0 +1,45 @@ +from airflow.operators.bash_operator import BashOperator +from airflow.utils.decorators import apply_defaults + + +class SingerOperator(BashOperator): + """ + Singer Plugin + + :param tap: The relevant Singer Tap. + :type tap: string + :param target: The relevant Singer Target + :type target: string + :param tap_config: The config path for the Singer Tap. + :type tap_config: string + :param target_config: The config path for the Singer Target. 
+ :type target_config: string + """ + + @apply_defaults + def __init__( + self, tap, target, tap_config=None, target_config=None, *args, **kwargs + ): + + self.tap = "tap-{}".format(tap) + self.target = "target-{}".format(target) + self.tap_config = tap_config + self.target_config = target_config + + if self.tap_config: + if self.target_config: + self.bash_command = "{} -c {} | {} -c {}".format( + self.tap, self.tap_config, self.target, self.target_config + ) + else: + self.bash_command = "{} -c {} | {}".format( + self.tap, self.tap_config, self.target + ) + elif self.target_config: + self.bash_command = "{} | {} -c {}".format( + self.tap, self.target, self.target_config + ) + else: + self.bash_command = "{} | {}".format(self.tap, self.target) + + super().__init__(bash_command=self.bash_command, *args, **kwargs) diff --git a/plugins/operators/sudo_bash.py b/plugins/operators/sudo_bash.py new file mode 100644 index 0000000..b6e06ad --- /dev/null +++ b/plugins/operators/sudo_bash.py @@ -0,0 +1,110 @@ +import getpass +import logging +import os +from builtins import bytes +from subprocess import PIPE, STDOUT, Popen +from tempfile import NamedTemporaryFile, gettempdir + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.utils.file import TemporaryDirectory + + +class SudoBashOperator(BaseOperator): + """ + Execute a Bash script, command or set of commands but sudo's as another user to execute them. + + :param bash_command: The command, set of commands or reference to a + bash script (must be '.sh') to be executed. + :type bash_command: string + :param user: The user to run the command as. The Airflow worker user + must have permission to sudo as that user + :type user: string + :param env: If env is not None, it must be a mapping that defines the + environment variables for the new process; these are used instead + of inheriting the current process environment, which is the default + behavior. + :type env: dict + :type output_encoding: output encoding of bash command + """ + + template_fields = ("bash_command", "user", "env", "output_encoding") + template_ext = ( + ".sh", + ".bash", + ) + ui_color = "#f0ede4" + + @apply_defaults + def __init__( + self, + bash_command, + user, + xcom_push=False, + env=None, + output_encoding="utf-8", + *args, + **kwargs, + ): + + super(SudoBashOperator, self).__init__(*args, **kwargs) + self.bash_command = bash_command + self.user = user + self.env = env + self.xcom_push_flag = xcom_push + self.output_encoding = output_encoding + + def execute(self, context): + """ + Execute the bash command in a temporary directory which will be cleaned afterwards. + """ + logging.info("tmp dir root location: \n" + gettempdir()) + with TemporaryDirectory(prefix="airflowtmp") as tmp_dir: + os.chmod(tmp_dir, 777) + # Ensure the sudo user has perms to their current working directory for making tempfiles + # This is not really a security flaw because the only thing in that dir is the + # temp script, owned by the airflow user and any temp files made by the sudo user + # and all of those will be created with the owning user's umask + # If a process needs finer control over the tempfiles it creates, that process can chmod + # them as they are created. 
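# Note on the wrapper assembled below: when `user` differs from the worker user, the
# temp script holds a single line of the form
#     sudo -u <user> sh -c '<bash_command>'
# which is then run as `bash <tempfile>` with cwd set to this temp dir. Because the
# command is spliced into single quotes verbatim, a bash_command that itself contains
# single quotes must be escaped by the caller.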
+ with NamedTemporaryFile(dir=tmp_dir, prefix=self.task_id) as f: + + if self.user == getpass.getuser(): # don't try to sudo as yourself + f.write(bytes(self.bash_command, "utf_8")) + else: + sudo_cmd = "sudo -u {} sh -c '{}'".format( + self.user, self.bash_command + ) + f.write(bytes(sudo_cmd, "utf_8")) + f.flush() + + logging.info("Temporary script location: {0}".format(f.name)) + logging.info("Running command: {}".format(self.bash_command)) + self.sp = Popen( + ["bash", f.name], + stdout=PIPE, + stderr=STDOUT, + cwd=tmp_dir, + env=self.env, + ) + + logging.info("Output:") + line = "" + for line in iter(self.sp.stdout.readline, b""): + line = line.decode(self.output_encoding).strip() + logging.info(line) + self.sp.wait() + logging.info( + "Command exited with return code {0}".format(self.sp.returncode) + ) + + if self.sp.returncode: + raise AirflowException("Bash command failed") + + if self.xcom_push_flag: + return line + + def on_kill(self): + logging.warn("Sending SIGTERM signal to bash subprocess") + self.sp.terminate() diff --git a/plugins/tools/GoogleSheetsPlugin.py b/plugins/tools/GoogleSheetsPlugin.py new file mode 100644 index 0000000..66d11ee --- /dev/null +++ b/plugins/tools/GoogleSheetsPlugin.py @@ -0,0 +1,15 @@ +from airflow.plugins_manager import AirflowPlugin +from hooks.google import GoogleHook +from operators.google import GoogleSheetsToS3Operator + + +class google_sheets_plugin(AirflowPlugin): + name = "GoogleSheetsPlugin" + operators = [GoogleSheetsToS3Operator] + # Leave in for explicitness + hooks = [GoogleHook] + executors = [] + macros = [] + admin_views = [] + flask_blueprints = [] + menu_links = [] diff --git a/plugins/tools/singer.py b/plugins/tools/singer.py new file mode 100644 index 0000000..3b8ec53 --- /dev/null +++ b/plugins/tools/singer.py @@ -0,0 +1,13 @@ +from airflow.plugins_manager import AirflowPlugin +from operators.singer import SingerOperator + + +class SingerPlugin(AirflowPlugin): + name = "singer_plugin" + hooks = [] + operators = [SingerOperator] + executors = [] + macros = [] + admin_views = [] + flask_blueprints = [] + menu_links = [] diff --git a/reference/providers/airbyte/CHANGELOG.rst b/reference/providers/airbyte/CHANGELOG.rst new file mode 100644 index 0000000..cef7dda --- /dev/null +++ b/reference/providers/airbyte/CHANGELOG.rst @@ -0,0 +1,25 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/airbyte/__init__.py b/reference/providers/airbyte/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/airbyte/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/airbyte/example_dags/__init__.py b/reference/providers/airbyte/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/airbyte/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/airbyte/example_dags/example_airbyte_trigger_job.py b/reference/providers/airbyte/example_dags/example_airbyte_trigger_job.py new file mode 100644 index 0000000..5c8ac42 --- /dev/null +++ b/reference/providers/airbyte/example_dags/example_airbyte_trigger_job.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Example DAG demonstrating the usage of the AirbyteTriggerSyncOperator.""" + +from datetime import timedelta + +from airflow import DAG +from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator +from airflow.providers.airbyte.sensors.airbyte import AirbyteJobSensor +from airflow.utils.dates import days_ago + +args = { + "owner": "airflow", +} + +with DAG( + dag_id="example_airbyte_operator", + default_args=args, + schedule_interval=None, + start_date=days_ago(1), + dagrun_timeout=timedelta(minutes=60), + tags=["example"], +) as dag: + + # [START howto_operator_airbyte_synchronous] + sync_source_destination = AirbyteTriggerSyncOperator( + task_id="airbyte_sync_source_dest_example", + airbyte_conn_id="airbyte_default", + connection_id="15bc3800-82e4-48c3-a32d-620661273f28", + ) + # [END howto_operator_airbyte_synchronous] + + # [START howto_operator_airbyte_asynchronous] + async_source_destination = AirbyteTriggerSyncOperator( + task_id="airbyte_async_source_dest_example", + airbyte_conn_id="airbyte_default", + connection_id="15bc3800-82e4-48c3-a32d-620661273f28", + asynchronous=True, + ) + + airbyte_sensor = AirbyteJobSensor( + task_id="airbyte_sensor_source_dest_example", + airbyte_job_id="{{task_instance.xcom_pull(task_ids='airbyte_async_source_dest_example')}}", + airbyte_conn_id="airbyte_default", + ) + # [END howto_operator_airbyte_asynchronous] + + async_source_destination >> airbyte_sensor diff --git a/reference/providers/airbyte/hooks/__init__.py b/reference/providers/airbyte/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/airbyte/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/airbyte/hooks/airbyte.py b/reference/providers/airbyte/hooks/airbyte.py new file mode 100644 index 0000000..e1d0a60 --- /dev/null +++ b/reference/providers/airbyte/hooks/airbyte.py @@ -0,0 +1,123 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import time +from typing import Any, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.http.hooks.http import HttpHook + + +class AirbyteHook(HttpHook): + """ + Hook for Airbyte API + + :param airbyte_conn_id: Required. The name of the Airflow connection to get + connection information for Airbyte. + :type airbyte_conn_id: str + :param api_version: Optional. Airbyte API version. + :type api_version: str + """ + + RUNNING = "running" + SUCCEEDED = "succeeded" + CANCELLED = "cancelled" + PENDING = "pending" + FAILED = "failed" + ERROR = "error" + + def __init__( + self, + airbyte_conn_id: str = "airbyte_default", + api_version: Optional[str] = "v1", + ) -> None: + super().__init__(http_conn_id=airbyte_conn_id) + self.api_version: str = api_version + + def wait_for_job( + self, + job_id: str, + wait_seconds: Optional[float] = 3, + timeout: Optional[float] = 3600, + ) -> None: + """ + Helper method which polls a job to check if it finishes. + + :param job_id: Required. Id of the Airbyte job + :type job_id: str + :param wait_seconds: Optional. Number of seconds between checks. + :type wait_seconds: float + :param timeout: Optional. How many seconds wait for job to be ready. + Used only if ``asynchronous`` is False. + :type timeout: float + """ + state = None + start = time.monotonic() + while True: + if timeout and start + timeout < time.monotonic(): + raise AirflowException( + f"Timeout: Airbyte job {job_id} is not ready after {timeout}s" + ) + time.sleep(wait_seconds) + try: + job = self.get_job(job_id=job_id) + state = job.json()["job"]["status"] + except AirflowException as err: + self.log.info( + "Retrying. Airbyte API returned server error when waiting for job: %s", + err, + ) + continue + + if state in (self.RUNNING, self.PENDING): + continue + if state == self.SUCCEEDED: + break + if state == self.ERROR: + raise AirflowException(f"Job failed:\n{job}") + elif state == self.CANCELLED: + raise AirflowException(f"Job was cancelled:\n{job}") + else: + raise Exception( + f"Encountered unexpected state `{state}` for job_id `{job_id}`" + ) + + def submit_sync_connection(self, connection_id: str) -> Any: + """ + Submits a job to a Airbyte server. + + :param connection_id: Required. The ConnectionId of the Airbyte Connection. + :type connectiond_id: str + """ + return self.run( + endpoint=f"api/{self.api_version}/connections/sync", + json={"connectionId": connection_id}, + headers={"accept": "application/json"}, + ) + + def get_job(self, job_id: int) -> Any: + """ + Gets the resource representation for a job in Airbyte. + + :param job_id: Required. Id of the Airbyte job + :type job_id: int + """ + return self.run( + endpoint=f"api/{self.api_version}/jobs/get", + json={"id": job_id}, + headers={"accept": "application/json"}, + ) diff --git a/reference/providers/airbyte/operators/__init__.py b/reference/providers/airbyte/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/airbyte/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/airbyte/operators/airbyte.py b/reference/providers/airbyte/operators/airbyte.py new file mode 100644 index 0000000..0ea9fdf --- /dev/null +++ b/reference/providers/airbyte/operators/airbyte.py @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.airbyte.hooks.airbyte import AirbyteHook +from airflow.utils.decorators import apply_defaults + + +class AirbyteTriggerSyncOperator(BaseOperator): + """ + This operator allows you to submit a job to an Airbyte server to run a integration + process between your source and destination. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AirbyteTriggerSyncOperator` + + :param airbyte_conn_id: Required. The name of the Airflow connection to get connection + information for Airbyte. + :type airbyte_conn_id: str + :param connection_id: Required. The Airbyte ConnectionId UUID between a source and destination. + :type connection_id: str + :param asynchronous: Optional. Flag to get job_id after submitting the job to the Airbyte API. + This is useful for submitting long running jobs and + waiting on them asynchronously using the AirbyteJobSensor. + :type asynchronous: bool + :param api_version: Optional. Airbyte API version. + :type api_version: str + :param wait_seconds: Optional. Number of seconds between checks. Only used when ``asynchronous`` is False. + :type wait_seconds: float + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Only used when ``asynchronous`` is False. 
+ :type timeout: float + """ + + template_fields = ("connection_id",) + + @apply_defaults + def __init__( + self, + connection_id: str, + airbyte_conn_id: str = "airbyte_default", + asynchronous: Optional[bool] = False, + api_version: Optional[str] = "v1", + wait_seconds: Optional[float] = 3, + timeout: Optional[float] = 3600, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.airbyte_conn_id = airbyte_conn_id + self.connection_id = connection_id + self.timeout = timeout + self.api_version = api_version + self.wait_seconds = wait_seconds + self.asynchronous = asynchronous + + def execute(self, context) -> None: + """Create Airbyte Job and wait to finish""" + hook = AirbyteHook( + airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version + ) + job_object = hook.submit_sync_connection(connection_id=self.connection_id) + job_id = job_object.json()["job"]["id"] + + self.log.info("Job %s was submitted to Airbyte Server", job_id) + if not self.asynchronous: + self.log.info("Waiting for job %s to complete", job_id) + hook.wait_for_job( + job_id=job_id, wait_seconds=self.wait_seconds, timeout=self.timeout + ) + self.log.info("Job %s completed successfully", job_id) + + return job_id diff --git a/reference/providers/airbyte/provider.yaml b/reference/providers/airbyte/provider.yaml new file mode 100644 index 0000000..77b109f --- /dev/null +++ b/reference/providers/airbyte/provider.yaml @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-airbyte +name: Airbyte +description: | + `Airbyte `__ + +versions: + - 1.0.0 + +integrations: + - integration-name: Airbyte + external-doc-url: https://www.airbyte.io/ + logo: /integration-logos/airbyte/Airbyte.png + how-to-guide: + - /docs/apache-airflow-providers-airbyte/operators/airbyte.rst + tags: [service] + +operators: + - integration-name: Airbyte + python-modules: + - airflow.providers.airbyte.operators.airbyte + +hooks: + - integration-name: Airbyte + python-modules: + - airflow.providers.airbyte.hooks.airbyte + +sensors: + - integration-name: Airbyte + python-modules: + - airflow.providers.airbyte.sensors.airbyte + +hook-class-names: + - airflow.providers.airbyte.hooks.airbyte.AirbyteHook diff --git a/reference/providers/airbyte/sensors/__init__.py b/reference/providers/airbyte/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/airbyte/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/airbyte/sensors/airbyte.py b/reference/providers/airbyte/sensors/airbyte.py new file mode 100644 index 0000000..6dac4bd --- /dev/null +++ b/reference/providers/airbyte/sensors/airbyte.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Airbyte Job sensor.""" +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.airbyte.hooks.airbyte import AirbyteHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class AirbyteJobSensor(BaseSensorOperator): + """ + Check for the state of a previously submitted Airbyte job. + + :param airbyte_job_id: Required. Id of the Airbyte job + :type airbyte_job_id: str + :param airbyte_conn_id: Required. The name of the Airflow connection to get + connection information for Airbyte. + :type airbyte_conn_id: str + :param api_version: Optional. Airbyte API version. 
+ :type api_version: str + """ + + template_fields = ("airbyte_job_id",) + ui_color = "#6C51FD" + + @apply_defaults + def __init__( + self, + *, + airbyte_job_id: str, + airbyte_conn_id: str = "airbyte_default", + api_version: Optional[str] = "v1", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.airbyte_conn_id = airbyte_conn_id + self.airbyte_job_id = airbyte_job_id + self.api_version = api_version + + def poke(self, context: dict) -> bool: + hook = AirbyteHook( + airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version + ) + job = hook.get_job(job_id=self.airbyte_job_id) + status = job.json()["job"]["status"] + + if status == hook.FAILED: + raise AirflowException(f"Job failed: \n{job}") + elif status == hook.CANCELLED: + raise AirflowException(f"Job was cancelled: \n{job}") + elif status == hook.SUCCEEDED: + self.log.info("Job %s completed successfully.", self.airbyte_job_id) + return True + elif status == hook.ERROR: + self.log.info("Job %s attempt has failed.", self.airbyte_job_id) + + self.log.info("Waiting for job %s to complete.", self.airbyte_job_id) + return False diff --git a/reference/providers/amazon/CHANGELOG.rst b/reference/providers/amazon/CHANGELOG.rst new file mode 100644 index 0000000..eb9c2e3 --- /dev/null +++ b/reference/providers/amazon/CHANGELOG.rst @@ -0,0 +1,63 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.2.0 +..... + +Features +~~~~~~~~ + +* ``Avoid using threads in S3 remote logging upload (#14414)`` +* ``Allow AWS Operator RedshiftToS3Transfer To Run a Custom Query (#14177)`` +* ``includes the STS token if STS credentials are used (#11227)`` + +1.1.0 +..... + +Features +~~~~~~~~ + +* ``Adding support to put extra arguments for Glue Job. (#14027)`` +* ``Add aws ses email backend for use with EmailOperator. (#13986)`` +* ``Add bucket_name to template fileds in S3 operators (#13973)`` +* ``Add ExasolToS3Operator (#13847)`` +* ``AWS Glue Crawler Integration (#13072)`` +* ``Add acl_policy to S3CopyObjectOperator (#13773)`` +* ``AllowDiskUse parameter and docs in MongotoS3Operator (#12033)`` +* ``Add S3ToFTPOperator (#11747)`` +* ``add xcom push for ECSOperator (#12096)`` +* ``[AIRFLOW-3723] Add Gzip capability to mongo_to_S3 operator (#13187)`` +* ``Add S3KeySizeSensor (#13049)`` +* ``Add 'mongo_collection' to template_fields in MongoToS3Operator (#13361)`` +* ``Allow Tags on AWS Batch Job Submission (#13396)`` + +Bug fixes +~~~~~~~~~ + +* ``Fix bug in GCSToS3Operator (#13718)`` +* ``Fix S3KeysUnchangedSensor so that template_fields work (#13490)`` + + +1.0.0 +..... + + +Initial version of the provider. 
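The Airbyte reference files above layer AirbyteTriggerSyncOperator and AirbyteJobSensor on top of AirbyteHook. For orientation, driving the hook directly looks roughly like the sketch below; the airbyte_default connection and the connection UUID reuse the placeholders from the example DAG and are assumptions, not live identifiers.

from airflow.providers.airbyte.hooks.airbyte import AirbyteHook

def run_sync(connection_id: str = "15bc3800-82e4-48c3-a32d-620661273f28") -> int:
    # Submit a sync for the given Airbyte connection, then block until it finishes.
    hook = AirbyteHook(airbyte_conn_id="airbyte_default", api_version="v1")
    job = hook.submit_sync_connection(connection_id=connection_id)
    job_id = job.json()["job"]["id"]
    # wait_for_job polls the jobs/get endpoint and raises on error, cancellation, or timeout.
    hook.wait_for_job(job_id=job_id, wait_seconds=3, timeout=3600)
    return job_id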
diff --git a/reference/providers/amazon/__init__.py b/reference/providers/amazon/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/amazon/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/amazon/aws/__init__.py b/reference/providers/amazon/aws/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/amazon/aws/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/amazon/aws/example_dags/__init__.py b/reference/providers/amazon/aws/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/amazon/aws/example_dags/example_datasync_1.py b/reference/providers/amazon/aws/example_dags/example_datasync_1.py new file mode 100644 index 0000000..e9d2017 --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_datasync_1.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example dag for using `AWSDataSyncOperator` in a straightforward manner. + +This DAG gets an AWS TaskArn for a specified source and destination, and then attempts to execute it. +It assumes there is a single task returned and does not do error checking (eg if multiple tasks were found). + +This DAG relies on the following environment variables: + +* SOURCE_LOCATION_URI - Source location URI, usually on premises SMB or NFS +* DESTINATION_LOCATION_URI - Destination location URI, usually S3 +""" + +from os import getenv + +from airflow import models +from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator +from airflow.utils.dates import days_ago + +# [START howto_operator_datasync_1_args_1] +TASK_ARN = getenv("TASK_ARN", "my_aws_datasync_task_arn") +# [END howto_operator_datasync_1_args_1] + +# [START howto_operator_datasync_1_args_2] +SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/") + +DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix") +# [END howto_operator_datasync_1_args_2] + + +with models.DAG( + "example_datasync_1_1", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + + # [START howto_operator_datasync_1_1] + datasync_task_1 = AWSDataSyncOperator( + aws_conn_id="aws_default", task_id="datasync_task_1", task_arn=TASK_ARN + ) + # [END howto_operator_datasync_1_1] + +with models.DAG( + "example_datasync_1_2", + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs +) as dag: + # [START howto_operator_datasync_1_2] + datasync_task_2 = AWSDataSyncOperator( + aws_conn_id="aws_default", + task_id="datasync_task_2", + source_location_uri=SOURCE_LOCATION_URI, + destination_location_uri=DESTINATION_LOCATION_URI, + ) + # [END howto_operator_datasync_1_2] diff --git a/reference/providers/amazon/aws/example_dags/example_datasync_2.py b/reference/providers/amazon/aws/example_dags/example_datasync_2.py new file mode 100644 index 0000000..c604364 --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_datasync_2.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example dag for using `AWSDataSyncOperator` in a more complex manner. + +- Try to get a TaskArn. If one exists, update it. +- If no tasks exist, try to create a new DataSync Task. + - If source and destination locations don't exist for the new task, create them first +- If many tasks exist, raise an Exception +- After getting or creating a DataSync Task, run it + +This DAG relies on the following environment variables: + +* SOURCE_LOCATION_URI - Source location URI, usually on premises SMB or NFS +* DESTINATION_LOCATION_URI - Destination location URI, usually S3 +* CREATE_TASK_KWARGS - Passed to boto3.create_task(**kwargs) +* CREATE_SOURCE_LOCATION_KWARGS - Passed to boto3.create_location(**kwargs) +* CREATE_DESTINATION_LOCATION_KWARGS - Passed to boto3.create_location(**kwargs) +* UPDATE_TASK_KWARGS - Passed to boto3.update_task(**kwargs) +""" + +import json +import re +from os import getenv + +from airflow import models +from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator +from airflow.utils.dates import days_ago + +# [START howto_operator_datasync_2_args] +SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/") + +DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix") + +default_create_task_kwargs = '{"Name": "Created by Airflow"}' +CREATE_TASK_KWARGS = json.loads( + getenv("CREATE_TASK_KWARGS", default_create_task_kwargs) +) + +default_create_source_location_kwargs = "{}" +CREATE_SOURCE_LOCATION_KWARGS = json.loads( + getenv("CREATE_SOURCE_LOCATION_KWARGS", default_create_source_location_kwargs) +) + +bucket_access_role_arn = ( + "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role" +) +default_destination_location_kwargs = """\ +{"S3BucketArn": "arn:aws:s3:::mybucket", + "S3Config": {"BucketAccessRoleArn": + "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"} +}""" +CREATE_DESTINATION_LOCATION_KWARGS = json.loads( + getenv( + "CREATE_DESTINATION_LOCATION_KWARGS", + re.sub(r"[\s+]", "", default_destination_location_kwargs), + ) +) + +default_update_task_kwargs = '{"Name": "Updated by Airflow"}' +UPDATE_TASK_KWARGS = json.loads( + getenv("UPDATE_TASK_KWARGS", default_update_task_kwargs) +) + +# [END howto_operator_datasync_2_args] + +with models.DAG( + "example_datasync_2", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + + # [START howto_operator_datasync_2] + datasync_task = AWSDataSyncOperator( + aws_conn_id="aws_default", + task_id="datasync_task", + source_location_uri=SOURCE_LOCATION_URI, + destination_location_uri=DESTINATION_LOCATION_URI, + create_task_kwargs=CREATE_TASK_KWARGS, + create_source_location_kwargs=CREATE_SOURCE_LOCATION_KWARGS, + create_destination_location_kwargs=CREATE_DESTINATION_LOCATION_KWARGS, + update_task_kwargs=UPDATE_TASK_KWARGS, + delete_task_after_execution=True, + ) + # [END howto_operator_datasync_2] diff --git a/reference/providers/amazon/aws/example_dags/example_ecs_fargate.py b/reference/providers/amazon/aws/example_dags/example_ecs_fargate.py new file mode 100644 index 0000000..94cecba --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_ecs_fargate.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This is an example dag for ECSOperator. + +The task "hello_world" runs `hello-world` task in `c` cluster. +It overrides the command in the `hello-world-container` container. +""" + +import datetime +import os + +from airflow import DAG +from airflow.providers.amazon.aws.operators.ecs import ECSOperator + +DEFAULT_ARGS = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, +} + +dag = DAG( + dag_id="ecs_fargate_dag", + default_args=DEFAULT_ARGS, + default_view="graph", + schedule_interval=None, + start_date=datetime.datetime(2020, 1, 1), + tags=["example"], +) +# generate dag documentation +dag.doc_md = __doc__ + +# [START howto_operator_ecs] +hello_world = ECSOperator( + task_id="hello_world", + dag=dag, + aws_conn_id="aws_ecs", + cluster="c", + task_definition="hello-world", + launch_type="FARGATE", + overrides={ + "containerOverrides": [ + { + "name": "hello-world-container", + "command": ["echo", "hello", "world"], + }, + ], + }, + network_configuration={ + "awsvpcConfiguration": { + "securityGroups": [os.environ.get("SECURITY_GROUP_ID", "sg-123abc")], + "subnets": [os.environ.get("SUBNET_ID", "subnet-123456ab")], + }, + }, + tags={ + "Customer": "X", + "Project": "Y", + "Application": "Z", + "Version": "0.0.1", + "Environment": "Development", + }, + awslogs_group="/ecs/hello-world", + awslogs_stream_prefix="prefix_b/hello-world-container", # prefix with container name +) +# [END howto_operator_ecs] diff --git a/reference/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py b/reference/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py new file mode 100644 index 0000000..3d3039c --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example dag for a AWS EMR Pipeline with auto steps. 
+""" +from datetime import timedelta + +from airflow import DAG +from airflow.providers.amazon.aws.operators.emr_create_job_flow import ( + EmrCreateJobFlowOperator, +) +from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor +from airflow.utils.dates import days_ago + +DEFAULT_ARGS = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, +} + +# [START howto_operator_emr_automatic_steps_config] +SPARK_STEPS = [ + { + "Name": "calculate_pi", + "ActionOnFailure": "CONTINUE", + "HadoopJarStep": { + "Jar": "command-runner.jar", + "Args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"], + }, + } +] + +JOB_FLOW_OVERRIDES = { + "Name": "PiCalc", + "ReleaseLabel": "emr-5.29.0", + "Instances": { + "InstanceGroups": [ + { + "Name": "Master node", + "Market": "SPOT", + "InstanceRole": "MASTER", + "InstanceType": "m1.medium", + "InstanceCount": 1, + } + ], + "KeepJobFlowAliveWhenNoSteps": False, + "TerminationProtected": False, + }, + "Steps": SPARK_STEPS, + "JobFlowRole": "EMR_EC2_DefaultRole", + "ServiceRole": "EMR_DefaultRole", +} +# [END howto_operator_emr_automatic_steps_config] + +with DAG( + dag_id="emr_job_flow_automatic_steps_dag", + default_args=DEFAULT_ARGS, + dagrun_timeout=timedelta(hours=2), + start_date=days_ago(2), + schedule_interval="0 3 * * *", + tags=["example"], +) as dag: + + # [START howto_operator_emr_automatic_steps_tasks] + job_flow_creator = EmrCreateJobFlowOperator( + task_id="create_job_flow", + job_flow_overrides=JOB_FLOW_OVERRIDES, + aws_conn_id="aws_default", + emr_conn_id="emr_default", + ) + + job_sensor = EmrJobFlowSensor( + task_id="check_job_flow", + job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}", + aws_conn_id="aws_default", + ) + + job_flow_creator >> job_sensor + # [END howto_operator_emr_automatic_steps_tasks] diff --git a/reference/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py b/reference/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py new file mode 100644 index 0000000..393c56b --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example dag for a AWS EMR Pipeline. + +Starting by creating a cluster, adding steps/operations, checking steps and finally when finished +terminating the cluster. 
+""" +from datetime import timedelta + +from airflow import DAG +from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator +from airflow.providers.amazon.aws.operators.emr_create_job_flow import ( + EmrCreateJobFlowOperator, +) +from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import ( + EmrTerminateJobFlowOperator, +) +from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor +from airflow.utils.dates import days_ago + +DEFAULT_ARGS = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, +} + +SPARK_STEPS = [ + { + "Name": "calculate_pi", + "ActionOnFailure": "CONTINUE", + "HadoopJarStep": { + "Jar": "command-runner.jar", + "Args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"], + }, + } +] + +JOB_FLOW_OVERRIDES = { + "Name": "PiCalc", + "ReleaseLabel": "emr-5.29.0", + "Instances": { + "InstanceGroups": [ + { + "Name": "Master node", + "Market": "SPOT", + "InstanceRole": "MASTER", + "InstanceType": "m1.medium", + "InstanceCount": 1, + } + ], + "KeepJobFlowAliveWhenNoSteps": True, + "TerminationProtected": False, + }, + "JobFlowRole": "EMR_EC2_DefaultRole", + "ServiceRole": "EMR_DefaultRole", +} + +with DAG( + dag_id="emr_job_flow_manual_steps_dag", + default_args=DEFAULT_ARGS, + dagrun_timeout=timedelta(hours=2), + start_date=days_ago(2), + schedule_interval="0 3 * * *", + tags=["example"], +) as dag: + + # [START howto_operator_emr_manual_steps_tasks] + cluster_creator = EmrCreateJobFlowOperator( + task_id="create_job_flow", + job_flow_overrides=JOB_FLOW_OVERRIDES, + aws_conn_id="aws_default", + emr_conn_id="emr_default", + ) + + step_adder = EmrAddStepsOperator( + task_id="add_steps", + job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}", + aws_conn_id="aws_default", + steps=SPARK_STEPS, + ) + + step_checker = EmrStepSensor( + task_id="watch_step", + job_flow_id="{{ task_instance.xcom_pull('create_job_flow', key='return_value') }}", + step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}", + aws_conn_id="aws_default", + ) + + cluster_remover = EmrTerminateJobFlowOperator( + task_id="remove_cluster", + job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}", + aws_conn_id="aws_default", + ) + + cluster_creator >> step_adder >> step_checker >> cluster_remover + # [END howto_operator_emr_manual_steps_tasks] diff --git a/reference/providers/amazon/aws/example_dags/example_glacier_to_gcs.py b/reference/providers/amazon/aws/example_dags/example_glacier_to_gcs.py new file mode 100644 index 0000000..87d6aa6 --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_glacier_to_gcs.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +import os + +from airflow import models +from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator +from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor +from airflow.providers.amazon.aws.transfers.glacier_to_gcs import GlacierToGCSOperator +from airflow.utils.dates import days_ago + +VAULT_NAME = "airflow" +BUCKET_NAME = os.environ.get("GLACIER_GCS_BUCKET_NAME", "gs://glacier_bucket") +OBJECT_NAME = os.environ.get("GLACIER_OBJECT", "example-text.txt") + +with models.DAG( + "example_glacier_to_gcs", + schedule_interval=None, + start_date=days_ago(1), # Override to match your needs +) as dag: + # [START howto_glacier_create_job_operator] + create_glacier_job = GlacierCreateJobOperator( + task_id="create_glacier_job", + aws_conn_id="aws_default", + vault_name=VAULT_NAME, + ) + JOB_ID = '{{ task_instance.xcom_pull("create_glacier_job")["jobId"] }}' + # [END howto_glacier_create_job_operator] + + # [START howto_glacier_job_operation_sensor] + wait_for_operation_complete = GlacierJobOperationSensor( + aws_conn_id="aws_default", + vault_name=VAULT_NAME, + job_id=JOB_ID, + task_id="wait_for_operation_complete", + ) + # [END howto_glacier_job_operation_sensor] + + # [START howto_glacier_transfer_data_to_gcs] + transfer_archive_to_gcs = GlacierToGCSOperator( + task_id="transfer_archive_to_gcs", + aws_conn_id="aws_default", + gcp_conn_id="google_cloud_default", + vault_name=VAULT_NAME, + bucket_name=BUCKET_NAME, + object_name=OBJECT_NAME, + gzip=False, + # Override to match your needs + # If chunk size is bigger than actual file size + # then whole file will be downloaded + chunk_size=1024, + delegate_to=None, + google_impersonation_chain=None, + ) + # [END howto_glacier_transfer_data_to_gcs] + + create_glacier_job >> wait_for_operation_complete >> transfer_archive_to_gcs diff --git a/reference/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py b/reference/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py new file mode 100644 index 0000000..95bb6de --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" +This is a more advanced example dag for using `GoogleApiToS3Transfer` which uses xcom to pass data between +tasks to retrieve specific information about YouTube videos: + +First it searches for up to 50 videos (due to pagination) in a given time range +(YOUTUBE_VIDEO_PUBLISHED_AFTER, YOUTUBE_VIDEO_PUBLISHED_BEFORE) on a YouTube channel (YOUTUBE_CHANNEL_ID) +saves the response in S3 + passes over the YouTube IDs to the next request which then gets the information +(YOUTUBE_VIDEO_FIELDS) for the requested videos and saves them in S3 (S3_DESTINATION_KEY). + +Further information: + +YOUTUBE_VIDEO_PUBLISHED_AFTER and YOUTUBE_VIDEO_PUBLISHED_BEFORE needs to be formatted +"YYYY-MM-DDThh:mm:ss.sZ". See https://developers.google.com/youtube/v3/docs/search/list for more information. +YOUTUBE_VIDEO_PARTS depends on the fields you pass via YOUTUBE_VIDEO_FIELDS. See +https://developers.google.com/youtube/v3/docs/videos/list#parameters for more information. +YOUTUBE_CONN_ID is optional for public videos. It does only need to authenticate when there are private videos +on a YouTube channel you want to retrieve. +""" + +from os import getenv + +from airflow import DAG +from airflow.operators.dummy import DummyOperator +from airflow.operators.python import BranchPythonOperator +from airflow.providers.amazon.aws.transfers.google_api_to_s3 import ( + GoogleApiToS3Operator, +) +from airflow.utils.dates import days_ago + +# [START howto_operator_google_api_to_s3_transfer_advanced_env_variables] +YOUTUBE_CONN_ID = getenv("YOUTUBE_CONN_ID", "google_cloud_default") +YOUTUBE_CHANNEL_ID = getenv( + "YOUTUBE_CHANNEL_ID", "UCSXwxpWZQ7XZ1WL3wqevChA" +) # "Apache Airflow" +YOUTUBE_VIDEO_PUBLISHED_AFTER = getenv( + "YOUTUBE_VIDEO_PUBLISHED_AFTER", "2019-09-25T00:00:00Z" +) +YOUTUBE_VIDEO_PUBLISHED_BEFORE = getenv( + "YOUTUBE_VIDEO_PUBLISHED_BEFORE", "2019-10-18T00:00:00Z" +) +S3_DESTINATION_KEY = getenv("S3_DESTINATION_KEY", "s3://bucket/key.json") +YOUTUBE_VIDEO_PARTS = getenv("YOUTUBE_VIDEO_PARTS", "snippet") +YOUTUBE_VIDEO_FIELDS = getenv( + "YOUTUBE_VIDEO_FIELDS", "items(id,snippet(description,publishedAt,tags,title))" +) +# [END howto_operator_google_api_to_s3_transfer_advanced_env_variables] + + +# pylint: disable=unused-argument +# [START howto_operator_google_api_to_s3_transfer_advanced_task_1_2] +def _check_and_transform_video_ids(xcom_key, task_ids, task_instance, **kwargs): + video_ids_response = task_instance.xcom_pull(task_ids=task_ids, key=xcom_key) + video_ids = [item["id"]["videoId"] for item in video_ids_response["items"]] + + if video_ids: + task_instance.xcom_push(key="video_ids", value={"id": ",".join(video_ids)}) + return "video_data_to_s3" + return "no_video_ids" + + +# [END howto_operator_google_api_to_s3_transfer_advanced_task_1_2] +# pylint: enable=unused-argument + +s3_directory, s3_file = S3_DESTINATION_KEY.rsplit("/", 1) +s3_file_name, _ = s3_file.rsplit(".", 1) + +with DAG( + dag_id="example_google_api_to_s3_transfer_advanced", + schedule_interval=None, + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_google_api_to_s3_transfer_advanced_task_1] + task_video_ids_to_s3 = GoogleApiToS3Operator( + gcp_conn_id=YOUTUBE_CONN_ID, + google_api_service_name="youtube", + google_api_service_version="v3", + google_api_endpoint_path="youtube.search.list", + google_api_endpoint_params={ + "part": "snippet", + "channelId": YOUTUBE_CHANNEL_ID, + "maxResults": 50, + "publishedAfter": YOUTUBE_VIDEO_PUBLISHED_AFTER, + "publishedBefore": YOUTUBE_VIDEO_PUBLISHED_BEFORE, + 
"type": "video", + "fields": "items/id/videoId", + }, + google_api_response_via_xcom="video_ids_response", + s3_destination_key=f"{s3_directory}/youtube_search_{s3_file_name}.json", + task_id="video_ids_to_s3", + ) + # [END howto_operator_google_api_to_s3_transfer_advanced_task_1] + # [START howto_operator_google_api_to_s3_transfer_advanced_task_1_1] + task_check_and_transform_video_ids = BranchPythonOperator( + python_callable=_check_and_transform_video_ids, + op_args=[ + task_video_ids_to_s3.google_api_response_via_xcom, + task_video_ids_to_s3.task_id, + ], + task_id="check_and_transform_video_ids", + ) + # [END howto_operator_google_api_to_s3_transfer_advanced_task_1_1] + # [START howto_operator_google_api_to_s3_transfer_advanced_task_2] + task_video_data_to_s3 = GoogleApiToS3Operator( + gcp_conn_id=YOUTUBE_CONN_ID, + google_api_service_name="youtube", + google_api_service_version="v3", + google_api_endpoint_path="youtube.videos.list", + google_api_endpoint_params={ + "part": YOUTUBE_VIDEO_PARTS, + "maxResults": 50, + "fields": YOUTUBE_VIDEO_FIELDS, + }, + google_api_endpoint_params_via_xcom="video_ids", + s3_destination_key=f"{s3_directory}/youtube_videos_{s3_file_name}.json", + task_id="video_data_to_s3", + ) + # [END howto_operator_google_api_to_s3_transfer_advanced_task_2] + # [START howto_operator_google_api_to_s3_transfer_advanced_task_2_1] + task_no_video_ids = DummyOperator(task_id="no_video_ids") + # [END howto_operator_google_api_to_s3_transfer_advanced_task_2_1] + task_video_ids_to_s3 >> task_check_and_transform_video_ids >> [ + task_video_data_to_s3, + task_no_video_ids, + ] diff --git a/reference/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py b/reference/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py new file mode 100644 index 0000000..1e3a754 --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is a basic example dag for using `GoogleApiToS3Transfer` to retrieve Google Sheets data: + +You need to set all env variables to request the data. 
+""" + +from os import getenv + +from airflow import DAG +from airflow.providers.amazon.aws.transfers.google_api_to_s3 import ( + GoogleApiToS3Operator, +) +from airflow.utils.dates import days_ago + +# [START howto_operator_google_api_to_s3_transfer_basic_env_variables] +GOOGLE_SHEET_ID = getenv("GOOGLE_SHEET_ID") +GOOGLE_SHEET_RANGE = getenv("GOOGLE_SHEET_RANGE") +S3_DESTINATION_KEY = getenv("S3_DESTINATION_KEY", "s3://bucket/key.json") +# [END howto_operator_google_api_to_s3_transfer_basic_env_variables] + + +with DAG( + dag_id="example_google_api_to_s3_transfer_basic", + schedule_interval=None, + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_google_api_to_s3_transfer_basic_task_1] + task_google_sheets_values_to_s3 = GoogleApiToS3Operator( + google_api_service_name="sheets", + google_api_service_version="v4", + google_api_endpoint_path="sheets.spreadsheets.values.get", + google_api_endpoint_params={ + "spreadsheetId": GOOGLE_SHEET_ID, + "range": GOOGLE_SHEET_RANGE, + }, + s3_destination_key=S3_DESTINATION_KEY, + task_id="google_sheets_values_to_s3", + dag=dag, + ) + # [END howto_operator_google_api_to_s3_transfer_basic_task_1] diff --git a/reference/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py b/reference/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py new file mode 100644 index 0000000..e1e8aae --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This is an example dag for using `ImapAttachmentToS3Operator` to transfer an email attachment via IMAP +protocol from a mail server to S3 Bucket. 
+""" + +from os import getenv + +from airflow import DAG +from airflow.providers.amazon.aws.transfers.imap_attachment_to_s3 import ( + ImapAttachmentToS3Operator, +) +from airflow.utils.dates import days_ago + +# [START howto_operator_imap_attachment_to_s3_env_variables] +IMAP_ATTACHMENT_NAME = getenv("IMAP_ATTACHMENT_NAME", "test.txt") +IMAP_MAIL_FOLDER = getenv("IMAP_MAIL_FOLDER", "INBOX") +IMAP_MAIL_FILTER = getenv("IMAP_MAIL_FILTER", "All") +S3_DESTINATION_KEY = getenv("S3_DESTINATION_KEY", "s3://bucket/key.json") +# [END howto_operator_imap_attachment_to_s3_env_variables] + +with DAG( + dag_id="example_imap_attachment_to_s3", + start_date=days_ago(1), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_operator_imap_attachment_to_s3_task_1] + task_transfer_imap_attachment_to_s3 = ImapAttachmentToS3Operator( + imap_attachment_name=IMAP_ATTACHMENT_NAME, + s3_key=S3_DESTINATION_KEY, + imap_mail_folder=IMAP_MAIL_FOLDER, + imap_mail_filter=IMAP_MAIL_FILTER, + task_id="transfer_imap_attachment_to_s3", + dag=dag, + ) + # [END howto_operator_imap_attachment_to_s3_task_1] diff --git a/reference/providers/amazon/aws/example_dags/example_s3_bucket.py b/reference/providers/amazon/aws/example_dags/example_s3_bucket.py new file mode 100644 index 0000000..ae4b300 --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_s3_bucket.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import os + +from airflow.models.dag import DAG +from airflow.operators.python import PythonOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.operators.s3_bucket import ( + S3CreateBucketOperator, + S3DeleteBucketOperator, +) +from airflow.utils.dates import days_ago + +BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-airflow-12345") + + +def upload_keys(): + """This is a python callback to add keys into the s3 bucket""" + # add keys to bucket + s3_hook = S3Hook() + for i in range(0, 3): + s3_hook.load_string( + string_data="input", + key=f"path/data{i}", + bucket_name=BUCKET_NAME, + ) + + +with DAG( + dag_id="s3_bucket_dag", + schedule_interval=None, + start_date=days_ago(2), + max_active_runs=1, + tags=["example"], +) as dag: + + # [START howto_operator_s3_bucket] + create_bucket = S3CreateBucketOperator( + task_id="s3_bucket_dag_create", + bucket_name=BUCKET_NAME, + region_name="us-east-1", + ) + + add_keys_to_bucket = PythonOperator( + task_id="s3_bucket_dag_add_keys_to_bucket", python_callable=upload_keys + ) + + delete_bucket = S3DeleteBucketOperator( + task_id="s3_bucket_dag_delete", + bucket_name=BUCKET_NAME, + force_delete=True, + ) + # [END howto_operator_s3_bucket] + + create_bucket >> add_keys_to_bucket >> delete_bucket diff --git a/reference/providers/amazon/aws/example_dags/example_s3_bucket_tagging.py b/reference/providers/amazon/aws/example_dags/example_s3_bucket_tagging.py new file mode 100644 index 0000000..c295704 --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_s3_bucket_tagging.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
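# Side note on the s3_bucket example above: the three keys written by upload_keys are
# never read back before the bucket is force-deleted. If a verification step is wanted,
# a minimal sketch could look like the following, reusing BUCKET_NAME and the existing
# task objects from that DAG; the task id and callback name are placeholders.

from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook


def check_keys():
    """Placeholder callback: confirm the demo keys exist under the expected prefix."""
    keys = S3Hook().list_keys(bucket_name=BUCKET_NAME, prefix="path/")
    if len(keys or []) != 3:
        raise ValueError(f"expected 3 keys under 'path/', found: {keys}")


check_keys_in_bucket = PythonOperator(
    task_id="s3_bucket_dag_check_keys", python_callable=check_keys, dag=dag
)

add_keys_to_bucket >> check_keys_in_bucket >> delete_bucket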
+import os + +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.s3_bucket import ( + S3CreateBucketOperator, + S3DeleteBucketOperator, +) +from airflow.providers.amazon.aws.operators.s3_bucket_tagging import ( + S3DeleteBucketTaggingOperator, + S3GetBucketTaggingOperator, + S3PutBucketTaggingOperator, +) +from airflow.utils.dates import days_ago + +BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-s3-bucket-tagging") +TAG_KEY = os.environ.get("TAG_KEY", "test-s3-bucket-tagging-key") +TAG_VALUE = os.environ.get("TAG_VALUE", "test-s3-bucket-tagging-value") + + +with DAG( + dag_id="s3_bucket_tagging_dag", + schedule_interval=None, + start_date=days_ago(2), + max_active_runs=1, + tags=["example"], +) as dag: + + create_bucket = S3CreateBucketOperator( + task_id="s3_bucket_tagging_dag_create", + bucket_name=BUCKET_NAME, + region_name="us-east-1", + ) + + delete_bucket = S3DeleteBucketOperator( + task_id="s3_bucket_tagging_dag_delete", + bucket_name=BUCKET_NAME, + force_delete=True, + ) + + # [START howto_operator_s3_bucket_tagging] + get_tagging = S3GetBucketTaggingOperator( + task_id="s3_bucket_tagging_dag_get_tagging", bucket_name=BUCKET_NAME + ) + + put_tagging = S3PutBucketTaggingOperator( + task_id="s3_bucket_tagging_dag_put_tagging", + bucket_name=BUCKET_NAME, + key=TAG_KEY, + value=TAG_VALUE, + ) + + delete_tagging = S3DeleteBucketTaggingOperator( + task_id="s3_bucket_tagging_dag_delete_tagging", bucket_name=BUCKET_NAME + ) + # [END howto_operator_s3_bucket_tagging] + + create_bucket >> put_tagging >> get_tagging >> delete_tagging >> delete_bucket diff --git a/reference/providers/amazon/aws/example_dags/example_s3_to_redshift.py b/reference/providers/amazon/aws/example_dags/example_s3_to_redshift.py new file mode 100644 index 0000000..6231f13 --- /dev/null +++ b/reference/providers/amazon/aws/example_dags/example_s3_to_redshift.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This is an example dag for using `S3ToRedshiftOperator` to copy a S3 key into a Redshift table. 
+""" + +from os import getenv + +from airflow import DAG +from airflow.operators.python import PythonOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator +from airflow.providers.postgres.operators.postgres import PostgresOperator +from airflow.utils.dates import days_ago + +# [START howto_operator_s3_to_redshift_env_variables] +S3_BUCKET = getenv("S3_BUCKET", "test-bucket") +S3_KEY = getenv("S3_KEY", "key") +REDSHIFT_TABLE = getenv("REDSHIFT_TABLE", "test_table") +# [END howto_operator_s3_to_redshift_env_variables] + + +def _add_sample_data_to_s3(): + s3_hook = S3Hook() + s3_hook.load_string( + "0,Airflow", f"{S3_KEY}/{REDSHIFT_TABLE}", S3_BUCKET, replace=True + ) + + +def _remove_sample_data_from_s3(): + s3_hook = S3Hook() + if s3_hook.check_for_key(f"{S3_KEY}/{REDSHIFT_TABLE}", S3_BUCKET): + s3_hook.delete_objects(S3_BUCKET, f"{S3_KEY}/{REDSHIFT_TABLE}") + + +with DAG( + dag_id="example_s3_to_redshift", + start_date=days_ago(1), + schedule_interval=None, + tags=["example"], +) as dag: + setup__task_add_sample_data_to_s3 = PythonOperator( + python_callable=_add_sample_data_to_s3, task_id="setup__add_sample_data_to_s3" + ) + setup__task_create_table = PostgresOperator( + sql=f"CREATE TABLE IF NOT EXISTS {REDSHIFT_TABLE}(Id int, Name varchar)", + postgres_conn_id="redshift_default", + task_id="setup__create_table", + ) + # [START howto_operator_s3_to_redshift_task_1] + task_transfer_s3_to_redshift = S3ToRedshiftOperator( + s3_bucket=S3_BUCKET, + s3_key=S3_KEY, + schema="PUBLIC", + table=REDSHIFT_TABLE, + copy_options=["csv"], + task_id="transfer_s3_to_redshift", + ) + # [END howto_operator_s3_to_redshift_task_1] + teardown__task_drop_table = PostgresOperator( + sql=f"DROP TABLE IF EXISTS {REDSHIFT_TABLE}", + postgres_conn_id="redshift_default", + task_id="teardown__drop_table", + ) + teardown__task_remove_sample_data_from_s3 = PythonOperator( + python_callable=_remove_sample_data_from_s3, + task_id="teardown__remove_sample_data_from_s3", + ) + [ + setup__task_add_sample_data_to_s3, + setup__task_create_table, + ] >> task_transfer_s3_to_redshift >> [ + teardown__task_drop_table, + teardown__task_remove_sample_data_from_s3, + ] diff --git a/reference/providers/amazon/aws/hooks/__init__.py b/reference/providers/amazon/aws/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/reference/providers/amazon/aws/hooks/athena.py b/reference/providers/amazon/aws/hooks/athena.py new file mode 100644 index 0000000..9f9e313 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/athena.py @@ -0,0 +1,263 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains AWS Athena hook""" +from time import sleep +from typing import Any, Dict, Optional + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from botocore.paginate import PageIterator + + +class AWSAthenaHook(AwsBaseHook): + """ + Interact with AWS Athena to run, poll queries and return query results + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + + :param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena + :type sleep_time: int + """ + + INTERMEDIATE_STATES = ( + "QUEUED", + "RUNNING", + ) + FAILURE_STATES = ( + "FAILED", + "CANCELLED", + ) + SUCCESS_STATES = ("SUCCEEDED",) + + def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None: + super().__init__(client_type="athena", *args, **kwargs) # type: ignore + self.sleep_time = sleep_time + + def run_query( + self, + query: str, + query_context: Dict[str, str], + result_configuration: Dict[str, Any], + client_request_token: Optional[str] = None, + workgroup: str = "primary", + ) -> str: + """ + Run Presto query on athena with provided config and return submitted query_execution_id + + :param query: Presto query to run + :type query: str + :param query_context: Context in which query need to be run + :type query_context: dict + :param result_configuration: Dict with path to store results in and config related to encryption + :type result_configuration: dict + :param client_request_token: Unique token created by user to avoid multiple executions of same query + :type client_request_token: str + :param workgroup: Athena workgroup name, when not specified, will be 'primary' + :type workgroup: str + :return: str + """ + params = { + "QueryString": query, + "QueryExecutionContext": query_context, + "ResultConfiguration": result_configuration, + "WorkGroup": workgroup, + } + if client_request_token: + params["ClientRequestToken"] = client_request_token + response = self.get_conn().start_query_execution(**params) + query_execution_id = response["QueryExecutionId"] + return query_execution_id + + def check_query_status(self, query_execution_id: str) -> Optional[str]: + """ + Fetch the status of submitted athena query. Returns None or one of valid query states. 
+ + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :return: str + """ + response = self.get_conn().get_query_execution( + QueryExecutionId=query_execution_id + ) + state = None + try: + state = response["QueryExecution"]["Status"]["State"] + except Exception as ex: # pylint: disable=broad-except + self.log.error("Exception while getting query state %s", ex) + finally: + # The error is being absorbed here and is being handled by the caller. + # The error is being absorbed to implement retries. + return state # pylint: disable=lost-exception + + def get_state_change_reason(self, query_execution_id: str) -> Optional[str]: + """ + Fetch the reason for a state change (e.g. error message). Returns None or reason string. + + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :return: str + """ + response = self.get_conn().get_query_execution( + QueryExecutionId=query_execution_id + ) + reason = None + try: + reason = response["QueryExecution"]["Status"]["StateChangeReason"] + except Exception as ex: # pylint: disable=broad-except + self.log.error("Exception while getting query state change reason: %s", ex) + finally: + # The error is being absorbed here and is being handled by the caller. + # The error is being absorbed to implement retries. + return reason # pylint: disable=lost-exception + + def get_query_results( + self, + query_execution_id: str, + next_token_id: Optional[str] = None, + max_results: int = 1000, + ) -> Optional[dict]: + """ + Fetch submitted athena query results. returns none if query is in intermediate state or + failed/cancelled state else dict of query output + + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :param next_token_id: The token that specifies where to start pagination. + :type next_token_id: str + :param max_results: The maximum number of results (rows) to return in this request. + :type max_results: int + :return: dict + """ + query_state = self.check_query_status(query_execution_id) + if query_state is None: + self.log.error("Invalid Query state") + return None + elif ( + query_state in self.INTERMEDIATE_STATES + or query_state in self.FAILURE_STATES + ): + self.log.error('Query is in "%s" state. Cannot fetch results', query_state) + return None + result_params = { + "QueryExecutionId": query_execution_id, + "MaxResults": max_results, + } + if next_token_id: + result_params["NextToken"] = next_token_id + return self.get_conn().get_query_results(**result_params) + + def get_query_results_paginator( + self, + query_execution_id: str, + max_items: Optional[int] = None, + page_size: Optional[int] = None, + starting_token: Optional[str] = None, + ) -> Optional[PageIterator]: + """ + Fetch submitted athena query results. returns none if query is in intermediate state or + failed/cancelled state else a paginator to iterate through pages of results. If you + wish to get all results at once, call build_full_result() on the returned PageIterator + + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :param max_items: The total number of items to return. + :type max_items: int + :param page_size: The size of each page. + :type page_size: int + :param starting_token: A token to specify where to start paginating. 
+ :type starting_token: str + :return: PageIterator + """ + query_state = self.check_query_status(query_execution_id) + if query_state is None: + self.log.error("Invalid Query state (null)") + return None + if ( + query_state in self.INTERMEDIATE_STATES + or query_state in self.FAILURE_STATES + ): + self.log.error('Query is in "%s" state. Cannot fetch results', query_state) + return None + result_params = { + "QueryExecutionId": query_execution_id, + "PaginationConfig": { + "MaxItems": max_items, + "PageSize": page_size, + "StartingToken": starting_token, + }, + } + paginator = self.get_conn().get_paginator("get_query_results") + return paginator.paginate(**result_params) + + def poll_query_status( + self, query_execution_id: str, max_tries: Optional[int] = None + ) -> Optional[str]: + """ + Poll the status of submitted athena query until query state reaches final state. + Returns one of the final states + + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :param max_tries: Number of times to poll for query state before function exits + :type max_tries: int + :return: str + """ + try_number = 1 + final_query_state = ( + None # Query state when query reaches final state or max_tries reached + ) + while True: + query_state = self.check_query_status(query_execution_id) + if query_state is None: + self.log.info( + "Trial %s: Invalid query state. Retrying again", try_number + ) + elif query_state in self.INTERMEDIATE_STATES: + self.log.info( + "Trial %s: Query is still in an intermediate state - %s", + try_number, + query_state, + ) + else: + self.log.info( + "Trial %s: Query execution completed. Final state is %s}", + try_number, + query_state, + ) + final_query_state = query_state + break + if max_tries and try_number >= max_tries: # Break loop if max_tries reached + final_query_state = query_state + break + try_number += 1 + sleep(self.sleep_time) + return final_query_state + + def stop_query(self, query_execution_id: str) -> Dict: + """ + Cancel the submitted athena query + + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :return: dict + """ + return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id) diff --git a/reference/providers/amazon/aws/hooks/aws_dynamodb.py b/reference/providers/amazon/aws/hooks/aws_dynamodb.py new file mode 100644 index 0000000..fc6bd1b --- /dev/null +++ b/reference/providers/amazon/aws/hooks/aws_dynamodb.py @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module is deprecated. 
Please use `airflow.providers.amazon.aws.hooks.dynamodb`.""" + +import warnings + +# pylint: disable=unused-import +from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook # noqa + +warnings.warn( + "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/reference/providers/amazon/aws/hooks/base_aws.py b/reference/providers/amazon/aws/hooks/base_aws.py new file mode 100644 index 0000000..a22a67b --- /dev/null +++ b/reference/providers/amazon/aws/hooks/base_aws.py @@ -0,0 +1,596 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This module contains Base AWS Hook. + +.. seealso:: + For more information on how to use this hook, take a look at the guide: + :ref:`howto/connection:AWSHook` +""" + +import configparser +import datetime +import logging +from typing import Any, Dict, Optional, Tuple, Union + +import boto3 +import botocore +import botocore.session +from botocore.config import Config +from botocore.credentials import ReadOnlyCredentials + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.models.connection import Connection +from airflow.utils.log.logging_mixin import LoggingMixin +from dateutil.tz import tzlocal + + +class _SessionFactory(LoggingMixin): + def __init__( + self, conn: Connection, region_name: Optional[str], config: Config + ) -> None: + super().__init__() + self.conn = conn + self.region_name = region_name + self.config = config + self.extra_config = self.conn.extra_dejson + + def create_session(self) -> boto3.session.Session: + """Create AWS session.""" + session_kwargs = {} + if "session_kwargs" in self.extra_config: + self.log.info( + "Retrieving session_kwargs from Connection.extra_config['session_kwargs']: %s", + self.extra_config["session_kwargs"], + ) + session_kwargs = self.extra_config["session_kwargs"] + session = self._create_basic_session(session_kwargs=session_kwargs) + role_arn = self._read_role_arn_from_extra_config() + # If role_arn was specified then STS + assume_role + if role_arn is None: + return session + + return self._impersonate_to_role( + role_arn=role_arn, session=session, session_kwargs=session_kwargs + ) + + def _create_basic_session( + self, session_kwargs: Dict[str, Any] + ) -> boto3.session.Session: + ( + aws_access_key_id, + aws_secret_access_key, + ) = self._read_credentials_from_connection() + aws_session_token = self.extra_config.get("aws_session_token") + region_name = self.region_name + if self.region_name is None and "region_name" in self.extra_config: + self.log.info( + "Retrieving region_name from 
Connection.extra_config['region_name']" + ) + region_name = self.extra_config["region_name"] + self.log.info( + "Creating session with aws_access_key_id=%s region_name=%s", + aws_access_key_id, + region_name, + ) + + return boto3.session.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=region_name, + aws_session_token=aws_session_token, + **session_kwargs, + ) + + def _impersonate_to_role( + self, + role_arn: str, + session: boto3.session.Session, + session_kwargs: Dict[str, Any], + ) -> boto3.session.Session: + assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {}) + assume_role_method = self.extra_config.get("assume_role_method") + self.log.info("assume_role_method=%s", assume_role_method) + if not assume_role_method or assume_role_method == "assume_role": + sts_client = session.client("sts", config=self.config) + sts_response = self._assume_role( + sts_client=sts_client, + role_arn=role_arn, + assume_role_kwargs=assume_role_kwargs, + ) + elif assume_role_method == "assume_role_with_saml": + sts_client = session.client("sts", config=self.config) + sts_response = self._assume_role_with_saml( + sts_client=sts_client, + role_arn=role_arn, + assume_role_kwargs=assume_role_kwargs, + ) + elif assume_role_method == "assume_role_with_web_identity": + botocore_session = self._assume_role_with_web_identity( + role_arn=role_arn, + assume_role_kwargs=assume_role_kwargs, + base_session=session._session, # pylint: disable=protected-access + ) + return boto3.session.Session( + region_name=session.region_name, + botocore_session=botocore_session, + **session_kwargs, + ) + else: + raise NotImplementedError( + f"assume_role_method={assume_role_method} in Connection {self.conn.conn_id} Extra." + 'Currently "assume_role" or "assume_role_with_saml" are supported.' + '(Exclude this setting will default to "assume_role").' 
+ ) + # Use credentials retrieved from STS + credentials = sts_response["Credentials"] + aws_access_key_id = credentials["AccessKeyId"] + aws_secret_access_key = credentials["SecretAccessKey"] + aws_session_token = credentials["SessionToken"] + self.log.info( + "Creating session with aws_access_key_id=%s region_name=%s", + aws_access_key_id, + session.region_name, + ) + + return boto3.session.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=session.region_name, + aws_session_token=aws_session_token, + **session_kwargs, + ) + + def _read_role_arn_from_extra_config(self) -> Optional[str]: + aws_account_id = self.extra_config.get("aws_account_id") + aws_iam_role = self.extra_config.get("aws_iam_role") + role_arn = self.extra_config.get("role_arn") + if role_arn is None and aws_account_id is not None and aws_iam_role is not None: + self.log.info("Constructing role_arn from aws_account_id and aws_iam_role") + role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}" + self.log.info("role_arn is %s", role_arn) + return role_arn + + def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]: + aws_access_key_id = None + aws_secret_access_key = None + if self.conn.login: + aws_access_key_id = self.conn.login + aws_secret_access_key = self.conn.password + self.log.info("Credentials retrieved from login") + elif ( + "aws_access_key_id" in self.extra_config + and "aws_secret_access_key" in self.extra_config + ): + aws_access_key_id = self.extra_config["aws_access_key_id"] + aws_secret_access_key = self.extra_config["aws_secret_access_key"] + self.log.info("Credentials retrieved from extra_config") + elif "s3_config_file" in self.extra_config: + aws_access_key_id, aws_secret_access_key = _parse_s3_config( + self.extra_config["s3_config_file"], + self.extra_config.get("s3_config_format"), + self.extra_config.get("profile"), + ) + self.log.info("Credentials retrieved from extra_config['s3_config_file']") + else: + self.log.info("No credentials retrieved from Connection") + return aws_access_key_id, aws_secret_access_key + + def _assume_role( + self, + sts_client: boto3.client, + role_arn: str, + assume_role_kwargs: Dict[str, Any], + ) -> Dict: + if "external_id" in self.extra_config: # Backwards compatibility + assume_role_kwargs["ExternalId"] = self.extra_config.get("external_id") + role_session_name = f"Airflow_{self.conn.conn_id}" + self.log.info( + "Doing sts_client.assume_role to role_arn=%s (role_session_name=%s)", + role_arn, + role_session_name, + ) + return sts_client.assume_role( + RoleArn=role_arn, RoleSessionName=role_session_name, **assume_role_kwargs + ) + + def _assume_role_with_saml( + self, + sts_client: boto3.client, + role_arn: str, + assume_role_kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + saml_config = self.extra_config["assume_role_with_saml"] + principal_arn = saml_config["principal_arn"] + + idp_auth_method = saml_config["idp_auth_method"] + if idp_auth_method == "http_spegno_auth": + saml_assertion = self._fetch_saml_assertion_using_http_spegno_auth( + saml_config + ) + else: + raise NotImplementedError( + f"idp_auth_method={idp_auth_method} in Connection {self.conn.conn_id} Extra." + 'Currently only "http_spegno_auth" is supported, and must be specified.' 
+ ) + + self.log.info("Doing sts_client.assume_role_with_saml to role_arn=%s", role_arn) + return sts_client.assume_role_with_saml( + RoleArn=role_arn, + PrincipalArn=principal_arn, + SAMLAssertion=saml_assertion, + **assume_role_kwargs, + ) + + def _fetch_saml_assertion_using_http_spegno_auth( + self, saml_config: Dict[str, Any] + ) -> str: + import requests + + # requests_gssapi will need paramiko > 2.6 since you'll need + # 'gssapi' not 'python-gssapi' from PyPi. + # https://github.com/paramiko/paramiko/pull/1311 + import requests_gssapi + from lxml import etree + + idp_url = saml_config["idp_url"] + self.log.info("idp_url= %s", idp_url) + idp_request_kwargs = saml_config["idp_request_kwargs"] + auth = requests_gssapi.HTTPSPNEGOAuth() + if "mutual_authentication" in saml_config: + mutual_auth = saml_config["mutual_authentication"] + if mutual_auth == "REQUIRED": + auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.REQUIRED) + elif mutual_auth == "OPTIONAL": + auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.OPTIONAL) + elif mutual_auth == "DISABLED": + auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.DISABLED) + else: + raise NotImplementedError( + f"mutual_authentication={mutual_auth} in Connection {self.conn.conn_id} Extra." + 'Currently "REQUIRED", "OPTIONAL" and "DISABLED" are supported.' + "(Exclude this setting will default to HTTPSPNEGOAuth() )." + ) + # Query the IDP + idp_response = requests.get(idp_url, auth=auth, **idp_request_kwargs) + idp_response.raise_for_status() + # Assist with debugging. Note: contains sensitive info! + xpath = saml_config["saml_response_xpath"] + log_idp_response = ( + "log_idp_response" in saml_config and saml_config["log_idp_response"] + ) + if log_idp_response: + self.log.warning( + "The IDP response contains sensitive information, but log_idp_response is ON (%s).", + log_idp_response, + ) + self.log.info("idp_response.content= %s", idp_response.content) + self.log.info("xpath= %s", xpath) + # Extract SAML Assertion from the returned HTML / XML + xml = etree.fromstring(idp_response.content) + saml_assertion = xml.xpath(xpath) + if isinstance(saml_assertion, list): + if len(saml_assertion) == 1: + saml_assertion = saml_assertion[0] + if not saml_assertion: + raise ValueError("Invalid SAML Assertion") + return saml_assertion + + def _assume_role_with_web_identity( + self, role_arn, assume_role_kwargs, base_session + ): + base_session = base_session or botocore.session.get_session() + client_creator = base_session.create_client + federation = self.extra_config.get("assume_role_with_web_identity_federation") + if federation == "google": + web_identity_token_loader = self._get_google_identity_token_loader() + else: + raise AirflowException( + f'Unsupported federation: {federation}. Currently "google" only are supported.' 
+ ) + fetcher = botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher( + client_creator=client_creator, + web_identity_token_loader=web_identity_token_loader, + role_arn=role_arn, + extra_args=assume_role_kwargs or {}, + ) + aws_creds = botocore.credentials.DeferredRefreshableCredentials( + method="assume-role-with-web-identity", + refresh_using=fetcher.fetch_credentials, + time_fetcher=lambda: datetime.datetime.now(tz=tzlocal()), + ) + botocore_session = botocore.session.Session() + botocore_session._credentials = aws_creds # pylint: disable=protected-access + return botocore_session + + def _get_google_identity_token_loader(self): + from airflow.providers.google.common.utils.id_token_credentials import ( + get_default_id_token_credentials, + ) + from google.auth.transport import requests as requests_transport + + audience = self.extra_config.get( + "assume_role_with_web_identity_federation_audience" + ) + + google_id_token_credentials = get_default_id_token_credentials( + target_audience=audience + ) + + def web_identity_token_loader(): + if not google_id_token_credentials.valid: + request_adapter = requests_transport.Request() + google_id_token_credentials.refresh(request=request_adapter) + return google_id_token_credentials.token + + return web_identity_token_loader + + +class AwsBaseHook(BaseHook): + """ + Interact with AWS. + This class is a thin wrapper around the boto3 python library. + + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is None or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates. + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :type verify: Union[bool, str, None] + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :type region_name: Optional[str] + :param client_type: boto3.client client_type. Eg 's3', 'emr' etc + :type client_type: Optional[str] + :param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc + :type resource_type: Optional[str] + :param config: Configuration for botocore client. + (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html) + :type config: Optional[botocore.client.Config] + """ + + conn_name_attr = "aws_conn_id" + default_conn_name = "aws_default" + conn_type = "aws" + hook_name = "Amazon Web Services" + + def __init__( + self, + aws_conn_id: Optional[str] = default_conn_name, + verify: Union[bool, str, None] = None, + region_name: Optional[str] = None, + client_type: Optional[str] = None, + resource_type: Optional[str] = None, + config: Optional[Config] = None, + ) -> None: + super().__init__() + self.aws_conn_id = aws_conn_id + self.verify = verify + self.client_type = client_type + self.resource_type = resource_type + self.region_name = region_name + self.config = config + + if not (self.client_type or self.resource_type): + raise AirflowException( + "Either client_type or resource_type must be provided." 
+            )
+
+    def _get_credentials(
+        self, region_name: Optional[str]
+    ) -> Tuple[boto3.session.Session, Optional[str]]:
+
+        if not self.aws_conn_id:
+            session = boto3.session.Session(region_name=region_name)
+            return session, None
+
+        self.log.info("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
+
+        try:
+            # Fetch the Airflow connection object
+            connection_object = self.get_connection(self.aws_conn_id)
+            extra_config = connection_object.extra_dejson
+            endpoint_url = extra_config.get("host")
+
+            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html#botocore.config.Config
+            if "config_kwargs" in extra_config:
+                self.log.info(
+                    "Retrieving config_kwargs from Connection.extra_config['config_kwargs']: %s",
+                    extra_config["config_kwargs"],
+                )
+                self.config = Config(**extra_config["config_kwargs"])
+
+            session = _SessionFactory(
+                conn=connection_object, region_name=region_name, config=self.config
+            ).create_session()
+
+            return session, endpoint_url
+
+        except AirflowException:
+            self.log.warning("Unable to use Airflow Connection for credentials.")
+            self.log.info("Fallback on boto3 credential strategy")
+            # http://boto3.readthedocs.io/en/latest/guide/configuration.html
+
+        self.log.info(
+            "Creating session using boto3 credential strategy region_name=%s",
+            region_name,
+        )
+        session = boto3.session.Session(region_name=region_name)
+        return session, None
+
+    def get_client_type(
+        self,
+        client_type: str,
+        region_name: Optional[str] = None,
+        config: Optional[Config] = None,
+    ) -> boto3.client:
+        """Get the underlying boto3 client using boto3 session"""
+        session, endpoint_url = self._get_credentials(region_name)
+
+        # No AWS Operators use the config argument to this method.
+        # Keep backward compatibility with other users who might use it
+        if config is None:
+            config = self.config
+
+        return session.client(
+            client_type, endpoint_url=endpoint_url, config=config, verify=self.verify
+        )
+
+    def get_resource_type(
+        self,
+        resource_type: str,
+        region_name: Optional[str] = None,
+        config: Optional[Config] = None,
+    ) -> boto3.resource:
+        """Get the underlying boto3 resource using boto3 session"""
+        session, endpoint_url = self._get_credentials(region_name)
+
+        # No AWS Operators use the config argument to this method.
+        # Keep backward compatibility with other users who might use it
+        if config is None:
+            config = self.config
+
+        return session.resource(
+            resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify
+        )
+
+    @cached_property
+    def conn(self) -> Union[boto3.client, boto3.resource]:
+        """
+        Get the underlying boto3 client/resource (cached)
+
+        :return: boto3.client or boto3.resource
+        :rtype: Union[boto3.client, boto3.resource]
+        """
+        if self.client_type:
+            return self.get_client_type(self.client_type, region_name=self.region_name)
+        elif self.resource_type:
+            return self.get_resource_type(
+                self.resource_type, region_name=self.region_name
+            )
+        else:
+            # Rare possibility - subclasses have not specified a client_type or resource_type
+            raise NotImplementedError("Could not get boto3 connection!")
+
+    def get_conn(self) -> Union[boto3.client, boto3.resource]:
+        """
+        Get the underlying boto3 client/resource (cached)
+
+        Implemented so that caching works as intended. It exists for compatibility
+        with subclasses that rely on a super().get_conn() method.
+ + :return: boto3.client or boto3.resource + :rtype: Union[boto3.client, boto3.resource] + """ + # Compat shim + return self.conn + + def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session: + """Get the underlying boto3.session.""" + session, _ = self._get_credentials(region_name) + return session + + def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials: + """ + Get the underlying `botocore.Credentials` object. + + This contains the following authentication attributes: access_key, secret_key and token. + """ + session, _ = self._get_credentials(region_name) + # Credentials are refreshable, so accessing your access key and + # secret key separately can lead to a race condition. + # See https://stackoverflow.com/a/36291428/8283373 + return session.get_credentials().get_frozen_credentials() + + def expand_role(self, role: str) -> str: + """ + If the IAM role is a role name, get the Amazon Resource Name (ARN) for the role. + If IAM role is already an IAM role ARN, no change is made. + + :param role: IAM role name or ARN + :return: IAM role ARN + """ + if "/" in role: + return role + else: + return self.get_client_type("iam").get_role(RoleName=role)["Role"]["Arn"] + + +def _parse_s3_config( + config_file_name: str, + config_format: Optional[str] = "boto", + profile: Optional[str] = None, +) -> Tuple[Optional[str], Optional[str]]: + """ + Parses a config file for s3 credentials. Can currently + parse boto, s3cmd.conf and AWS SDK config formats + + :param config_file_name: path to the config file + :type config_file_name: str + :param config_format: config type. One of "boto", "s3cmd" or "aws". + Defaults to "boto" + :type config_format: str + :param profile: profile name in AWS type config file + :type profile: str + """ + config = configparser.ConfigParser() + if config.read(config_file_name): # pragma: no cover + sections = config.sections() + else: + raise AirflowException(f"Couldn't read {config_file_name}") + # Setting option names depending on file format + if config_format is None: + config_format = "boto" + conf_format = config_format.lower() + if conf_format == "boto": # pragma: no cover + if profile is not None and "profile " + profile in sections: + cred_section = "profile " + profile + else: + cred_section = "Credentials" + elif conf_format == "aws" and profile is not None: + cred_section = profile + else: + cred_section = "default" + # Option names + if conf_format in ("boto", "aws"): # pragma: no cover + key_id_option = "aws_access_key_id" + secret_key_option = "aws_secret_access_key" + # security_token_option = 'aws_security_token' + else: + key_id_option = "access_key" + secret_key_option = "secret_key" + # Actual Parsing + if cred_section not in sections: + raise AirflowException("This config file format is not recognized") + else: + try: + access_key = config.get(cred_section, key_id_option) + secret_key = config.get(cred_section, secret_key_option) + except Exception: + logging.warning("Option Error in parsing s3 config file") + raise + return access_key, secret_key diff --git a/reference/providers/amazon/aws/hooks/batch_client.py b/reference/providers/amazon/aws/hooks/batch_client.py new file mode 100644 index 0000000..2d63264 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/batch_client.py @@ -0,0 +1,556 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +A client for AWS batch services + +.. seealso:: + + - http://boto3.readthedocs.io/en/latest/guide/configuration.html + - http://boto3.readthedocs.io/en/latest/reference/services/batch.html + - https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html +""" + +from random import uniform +from time import sleep +from typing import Dict, List, Optional, Union + +import botocore.client +import botocore.exceptions +import botocore.waiter +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.typing_compat import Protocol, runtime_checkable + +# Add exceptions to pylint for the boto3 protocol only; ideally the boto3 library +# could provide +# protocols for all their dynamically generated classes (try to migrate this to a PR on botocore). +# Note that the use of invalid-name parameters should be restricted to the boto3 mappings only; +# all the Airflow wrappers of boto3 clients should not adopt invalid-names to match boto3. +# pylint: disable=invalid-name, unused-argument + + +@runtime_checkable +class AwsBatchProtocol(Protocol): + """ + A structured Protocol for ``boto3.client('batch') -> botocore.client.Batch``. + This is used for type hints on :py:meth:`.AwsBatchClient.client`; it covers + only the subset of client methods required. + + .. seealso:: + + - https://mypy.readthedocs.io/en/latest/protocols.html + - http://boto3.readthedocs.io/en/latest/reference/services/batch.html + """ + + def describe_jobs(self, jobs: List[str]) -> Dict: + """ + Get job descriptions from AWS batch + + :param jobs: a list of JobId to describe + :type jobs: List[str] + + :return: an API response to describe jobs + :rtype: Dict + """ + ... + + def get_waiter(self, waiterName: str) -> botocore.waiter.Waiter: + """ + Get an AWS Batch service waiter + + :param waiterName: The name of the waiter. The name should match + the name (including the casing) of the key name in the waiter + model file (typically this is CamelCasing). + :type waiterName: str + + :return: a waiter object for the named AWS batch service + :rtype: botocore.waiter.Waiter + + .. note:: + AWS batch might not have any waiters (until botocore PR-1307 is released). + + .. code-block:: python + + import boto3 + boto3.client('batch').waiter_names == [] + + .. seealso:: + + - https://boto3.amazonaws.com/v1/documentation/api/latest/guide/clients.html#waiters + - https://github.com/boto/botocore/pull/1307 + """ + ... 
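
A brief aside on the protocol above: because it is ``runtime_checkable``, a stock ``boto3`` Batch client can be checked against it structurally, and the empty ``waiter_names`` noted in ``get_waiter`` can be confirmed the same way. This is only a minimal sketch, assuming ``boto3`` is installed and an AWS region and credentials are configured in the environment.

.. code-block:: python

    import boto3

    from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchProtocol

    batch = boto3.client("batch")

    # Structural check only: isinstance() on a runtime_checkable Protocol verifies
    # that describe_jobs, get_waiter, submit_job and terminate_job are present.
    assert isinstance(batch, AwsBatchProtocol)

    # At the time of this provider, botocore ships no waiter model for AWS Batch,
    # which is why the client hook below falls back to polling.
    assert batch.waiter_names == []
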
+ + def submit_job( + self, + jobName: str, + jobQueue: str, + jobDefinition: str, + arrayProperties: Dict, + parameters: Dict, + containerOverrides: Dict, + tags: Dict, + ) -> Dict: + """ + Submit a batch job + + :param jobName: the name for the AWS batch job + :type jobName: str + + :param jobQueue: the queue name on AWS Batch + :type jobQueue: str + + :param jobDefinition: the job definition name on AWS Batch + :type jobDefinition: str + + :param arrayProperties: the same parameter that boto3 will receive + :type arrayProperties: Dict + + :param parameters: the same parameter that boto3 will receive + :type parameters: Dict + + :param containerOverrides: the same parameter that boto3 will receive + :type containerOverrides: Dict + + :param tags: the same parameter that boto3 will receive + :type tags: Dict + + :return: an API response + :rtype: Dict + """ + ... + + def terminate_job(self, jobId: str, reason: str) -> Dict: + """ + Terminate a batch job + + :param jobId: a job ID to terminate + :type jobId: str + + :param reason: a reason to terminate job ID + :type reason: str + + :return: an API response + :rtype: Dict + """ + ... + + +# Note that the use of invalid-name parameters should be restricted to the boto3 mappings only; +# all the Airflow wrappers of boto3 clients should not adopt invalid-names to match boto3. +# pylint: enable=invalid-name, unused-argument + + +class AwsBatchClientHook(AwsBaseHook): + """ + A client for AWS batch services. + + :param max_retries: exponential back-off retries, 4200 = 48 hours; + polling is only used when waiters is None + :type max_retries: Optional[int] + + :param status_retries: number of HTTP retries to get job status, 10; + polling is only used when waiters is None + :type status_retries: Optional[int] + + .. note:: + Several methods use a default random delay to check or poll for job status, i.e. + ``random.uniform(DEFAULT_DELAY_MIN, DEFAULT_DELAY_MAX)`` + Using a random interval helps to avoid AWS API throttle limits + when many concurrent tasks request job-descriptions. + + To modify the global defaults for the range of jitter allowed when a + random delay is used to check batch job status, modify these defaults, e.g.: + .. code-block:: + + AwsBatchClient.DEFAULT_DELAY_MIN = 0 + AwsBatchClient.DEFAULT_DELAY_MAX = 5 + + When explicit delay values are used, a 1 second random jitter is applied to the + delay (e.g. a delay of 0 sec will be a ``random.uniform(0, 1)`` delay. It is + generally recommended that random jitter is added to API requests. A + convenience method is provided for this, e.g. to get a random delay of + 10 sec +/- 5 sec: ``delay = AwsBatchClient.add_jitter(10, width=5, minima=0)`` + + .. 
seealso:: + - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html + - https://docs.aws.amazon.com/general/latest/gr/api-retries.html + - https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ + """ + + MAX_RETRIES = 4200 + STATUS_RETRIES = 10 + + # delays are in seconds + DEFAULT_DELAY_MIN = 1 + DEFAULT_DELAY_MAX = 10 + + def __init__( + self, + *args, + max_retries: Optional[int] = None, + status_retries: Optional[int] = None, + **kwargs, + ) -> None: + # https://github.com/python/mypy/issues/6799 hence type: ignore + super().__init__(client_type="batch", *args, **kwargs) # type: ignore + self.max_retries = max_retries or self.MAX_RETRIES + self.status_retries = status_retries or self.STATUS_RETRIES + + @property + def client( + self, + ) -> Union[AwsBatchProtocol, botocore.client.BaseClient]: # noqa: D402 + """ + An AWS API client for batch services, like ``boto3.client('batch')`` + + :return: a boto3 'batch' client for the ``.region_name`` + :rtype: Union[AwsBatchProtocol, botocore.client.BaseClient] + """ + return self.conn + + def terminate_job(self, job_id: str, reason: str) -> Dict: + """ + Terminate a batch job + + :param job_id: a job ID to terminate + :type job_id: str + + :param reason: a reason to terminate job ID + :type reason: str + + :return: an API response + :rtype: Dict + """ + response = self.get_conn().terminate_job(jobId=job_id, reason=reason) + self.log.info(response) + return response + + def check_job_success(self, job_id: str) -> bool: + """ + Check the final status of the batch job; return True if the job + 'SUCCEEDED', else raise an AirflowException + + :param job_id: a batch job ID + :type job_id: str + + :rtype: bool + + :raises: AirflowException + """ + job = self.get_job_description(job_id) + job_status = job.get("status") + + if job_status == "SUCCEEDED": + self.log.info("AWS batch job (%s) succeeded: %s", job_id, job) + return True + + if job_status == "FAILED": + raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}") + + if job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]: + raise AirflowException(f"AWS Batch job ({job_id}) is not complete: {job}") + + raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}") + + def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None: + """ + Wait for batch job to complete + + :param job_id: a batch job ID + :type job_id: str + + :param delay: a delay before polling for job status + :type delay: Optional[Union[int, float]] + + :raises: AirflowException + """ + self.delay(delay) + self.poll_for_job_running(job_id, delay) + self.poll_for_job_complete(job_id, delay) + self.log.info("AWS Batch job (%s) has completed", job_id) + + def poll_for_job_running( + self, job_id: str, delay: Union[int, float, None] = None + ) -> None: + """ + Poll for job running. The status that indicates a job is running or + already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'. + + So the status options that this will wait for are the transitions from: + 'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'|'SUCCEEDED'|'FAILED' + + The completed status options are included for cases where the status + changes too quickly for polling to detect a RUNNING status that moves + quickly from STARTING to RUNNING to completed (often a failure). 
+ + :param job_id: a batch job ID + :type job_id: str + + :param delay: a delay before polling for job status + :type delay: Optional[Union[int, float]] + + :raises: AirflowException + """ + self.delay(delay) + running_status = ["RUNNING", "SUCCEEDED", "FAILED"] + self.poll_job_status(job_id, running_status) + + def poll_for_job_complete( + self, job_id: str, delay: Union[int, float, None] = None + ) -> None: + """ + Poll for job completion. The status that indicates job completion + are: 'SUCCEEDED'|'FAILED'. + + So the status options that this will wait for are the transitions from: + 'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED' + + :param job_id: a batch job ID + :type job_id: str + + :param delay: a delay before polling for job status + :type delay: Optional[Union[int, float]] + + :raises: AirflowException + """ + self.delay(delay) + complete_status = ["SUCCEEDED", "FAILED"] + self.poll_job_status(job_id, complete_status) + + def poll_job_status(self, job_id: str, match_status: List[str]) -> bool: + """ + Poll for job status using an exponential back-off strategy (with max_retries). + + :param job_id: a batch job ID + :type job_id: str + + :param match_status: a list of job status to match; the batch job status are: + 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED' + :type match_status: List[str] + + :rtype: bool + + :raises: AirflowException + """ + retries = 0 + while True: + + job = self.get_job_description(job_id) + job_status = job.get("status") + self.log.info( + "AWS Batch job (%s) check status (%s) in %s", + job_id, + job_status, + match_status, + ) + + if job_status in match_status: + return True + + if retries >= self.max_retries: + raise AirflowException( + f"AWS Batch job ({job_id}) status checks exceed max_retries" + ) + + retries += 1 + pause = self.exponential_delay(retries) + self.log.info( + "AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds", + job_id, + retries, + self.max_retries, + pause, + ) + self.delay(pause) + + def get_job_description(self, job_id: str) -> Dict: + """ + Get job description (using status_retries). 
+ + :param job_id: a batch job ID + :type job_id: str + + :return: an API response for describe jobs + :rtype: Dict + + :raises: AirflowException + """ + retries = 0 + while True: + try: + response = self.get_conn().describe_jobs(jobs=[job_id]) + return self.parse_job_description(job_id, response) + + except botocore.exceptions.ClientError as err: + error = err.response.get("Error", {}) + if error.get("Code") == "TooManyRequestsException": + pass # allow it to retry, if possible + else: + raise AirflowException( + f"AWS Batch job ({job_id}) description error: {err}" + ) + + retries += 1 + if retries >= self.status_retries: + raise AirflowException( + "AWS Batch job ({}) description error: exceeded " + "status_retries ({})".format(job_id, self.status_retries) + ) + + pause = self.exponential_delay(retries) + self.log.info( + "AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds", + job_id, + retries, + self.status_retries, + pause, + ) + self.delay(pause) + + @staticmethod + def parse_job_description(job_id: str, response: Dict) -> Dict: + """ + Parse job description to extract description for job_id + + :param job_id: a batch job ID + :type job_id: str + + :param response: an API response for describe jobs + :type response: Dict + + :return: an API response to describe job_id + :rtype: Dict + + :raises: AirflowException + """ + jobs = response.get("jobs", []) + matching_jobs = [job for job in jobs if job.get("jobId") == job_id] + if len(matching_jobs) != 1: + raise AirflowException( + f"AWS Batch job ({job_id}) description error: response: {response}" + ) + + return matching_jobs[0] + + @staticmethod + def add_jitter( + delay: Union[int, float], + width: Union[int, float] = 1, + minima: Union[int, float] = 0, + ) -> float: + """ + Use delay +/- width for random jitter + + Adding jitter to status polling can help to avoid + AWS batch API limits for monitoring batch jobs with + a high concurrency in Airflow tasks. + + :param delay: number of seconds to pause; + delay is assumed to be a positive number + :type delay: Union[int, float] + + :param width: delay +/- width for random jitter; + width is assumed to be a positive number + :type width: Union[int, float] + + :param minima: minimum delay allowed; + minima is assumed to be a non-negative number + :type minima: Union[int, float] + + :return: uniform(delay - width, delay + width) jitter + and it is a non-negative number + :rtype: float + """ + delay = abs(delay) + width = abs(width) + minima = abs(minima) + lower = max(minima, delay - width) + upper = delay + width + return uniform(lower, upper) + + @staticmethod + def delay(delay: Union[int, float, None] = None) -> None: + """ + Pause execution for ``delay`` seconds. + + :param delay: a delay to pause execution using ``time.sleep(delay)``; + a small 1 second jitter is applied to the delay. + :type delay: Optional[Union[int, float]] + + .. note:: + This method uses a default random delay, i.e. + ``random.uniform(DEFAULT_DELAY_MIN, DEFAULT_DELAY_MAX)``; + using a random interval helps to avoid AWS API throttle limits + when many concurrent tasks request job-descriptions. + """ + if delay is None: + delay = uniform( + AwsBatchClientHook.DEFAULT_DELAY_MIN, + AwsBatchClientHook.DEFAULT_DELAY_MAX, + ) + else: + delay = AwsBatchClientHook.add_jitter(delay) + sleep(delay) + + @staticmethod + def exponential_delay(tries: int) -> float: + """ + An exponential back-off delay, with random jitter. 
There is a maximum + interval of 10 minutes (with random jitter between 3 and 10 minutes). + This is used in the :py:meth:`.poll_for_job_status` method. + + :param tries: Number of tries + :type tries: int + + :rtype: float + + Examples of behavior: + + .. code-block:: python + + def exp(tries): + max_interval = 600.0 # 10 minutes in seconds + delay = 1 + pow(tries * 0.6, 2) + delay = min(max_interval, delay) + print(delay / 3, delay) + + for tries in range(10): + exp(tries) + + # 0.33 1.0 + # 0.45 1.35 + # 0.81 2.44 + # 1.41 4.23 + # 2.25 6.76 + # 3.33 10.00 + # 4.65 13.95 + # 6.21 18.64 + # 8.01 24.04 + # 10.05 30.15 + + .. seealso:: + + - https://docs.aws.amazon.com/general/latest/gr/api-retries.html + - https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ + """ + max_interval = 600.0 # results in 3 to 10 minute delay + delay = 1 + pow(tries * 0.6, 2) + delay = min(max_interval, delay) + return uniform(delay / 3, delay) diff --git a/reference/providers/amazon/aws/hooks/batch_waiters.json b/reference/providers/amazon/aws/hooks/batch_waiters.json new file mode 100644 index 0000000..2ed9638 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/batch_waiters.json @@ -0,0 +1,105 @@ +{ + "version": 2, + "waiters": { + "JobExists": { + "delay": 2, + "operation": "DescribeJobs", + "maxAttempts": 100, + "acceptors": [ + { + "argument": "jobs[].status", + "expected": "SUBMITTED", + "matcher": "pathAll", + "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "PENDING", + "matcher": "pathAll", + "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "RUNNABLE", + "matcher": "pathAll", + "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "STARTING", + "matcher": "pathAll", + "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "RUNNING", + "matcher": "pathAll", + "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "FAILED", + "matcher": "pathAll", + "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "SUCCEEDED", + "matcher": "pathAll", + "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "", + "matcher": "error", + "state": "failure" + } + ] + }, + "JobRunning": { + "delay": 5, + "operation": "DescribeJobs", + "maxAttempts": 100, + "acceptors": [ + { + "argument": "jobs[].status", + "expected": "RUNNING", + "matcher": "pathAll", + "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "FAILED", + "matcher": "pathAll", + "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "SUCCEEDED", + "matcher": "pathAll", + "state": "success" + } + ] + }, + "JobComplete": { + "delay": 300, + "operation": "DescribeJobs", + "maxAttempts": 288, + "description": "Wait until job status is SUCCEEDED or FAILED", + "acceptors": [ + { + "argument": "jobs[].status", + "expected": "SUCCEEDED", + "matcher": "pathAll", + "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "FAILED", + "matcher": "pathAny", + "state": "failure" + } + ] + } + } +} diff --git a/reference/providers/amazon/aws/hooks/batch_waiters.py b/reference/providers/amazon/aws/hooks/batch_waiters.py new file mode 100644 index 0000000..de09f98 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/batch_waiters.py @@ -0,0 +1,242 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +""" +AWS batch service waiters + +.. seealso:: + + - https://boto3.amazonaws.com/v1/documentation/api/latest/guide/clients.html#waiters + - https://github.com/boto/botocore/blob/develop/botocore/waiter.py +""" + +import json +import sys +from copy import deepcopy +from pathlib import Path +from typing import Dict, List, Optional, Union + +import botocore.client +import botocore.exceptions +import botocore.waiter +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchClientHook + + +class AwsBatchWaitersHook(AwsBatchClientHook): + """ + A utility to manage waiters for AWS batch services + Examples: + + .. code-block:: python + + import random + from airflow.providers.amazon.aws.operators.batch_waiters import AwsBatchWaiters + + # to inspect default waiters + waiters = AwsBatchWaiters() + config = waiters.default_config # type: Dict + waiter_names = waiters.list_waiters() # -> ["JobComplete", "JobExists", "JobRunning"] + + # The default_config is a useful stepping stone to creating custom waiters, e.g. + custom_config = waiters.default_config # this is a deepcopy + # modify custom_config['waiters'] as necessary and get a new instance: + waiters = AwsBatchWaiters(waiter_config=custom_config) + waiters.waiter_config # check the custom configuration (this is a deepcopy) + waiters.list_waiters() # names of custom waiters + + # During the init for AwsBatchWaiters, the waiter_config is used to build a waiter_model; + # and note that this only occurs during the class init, to avoid any accidental mutations + # of waiter_config leaking into the waiter_model. + waiters.waiter_model # -> botocore.waiter.WaiterModel object + + # The waiter_model is combined with the waiters.client to get a specific waiter + # and the details of the config on that waiter can be further modified without any + # accidental impact on the generation of new waiters from the defined waiter_model, e.g. + waiters.get_waiter("JobExists").config.delay # -> 5 + waiter = waiters.get_waiter("JobExists") # -> botocore.waiter.Batch.Waiter.JobExists object + waiter.config.delay = 10 + waiters.get_waiter("JobExists").config.delay # -> 5 as defined by waiter_model + + # To use a specific waiter, update the config and call the `wait()` method for jobId, e.g. + waiter = waiters.get_waiter("JobExists") # -> botocore.waiter.Batch.Waiter.JobExists object + waiter.config.delay = random.uniform(1, 10) # seconds + waiter.config.max_attempts = 10 + waiter.wait(jobs=[jobId]) + + .. 
seealso:: + + - https://www.2ndwatch.com/blog/use-waiters-boto3-write/ + - https://github.com/boto/botocore/blob/develop/botocore/waiter.py + - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html#waiters + - https://github.com/boto/botocore/tree/develop/botocore/data/ec2/2016-11-15 + - https://github.com/boto/botocore/issues/1915 + + :param waiter_config: a custom waiter configuration for AWS batch services + :type waiter_config: Optional[Dict] + + :param aws_conn_id: connection id of AWS credentials / region name. If None, + credential boto3 strategy will be used + (http://boto3.readthedocs.io/en/latest/guide/configuration.html). + :type aws_conn_id: Optional[str] + + :param region_name: region name to use in AWS client. + Override the AWS region in connection (if provided) + :type region_name: Optional[str] + """ + + def __init__(self, *args, waiter_config: Optional[Dict] = None, **kwargs) -> None: + + super().__init__(*args, **kwargs) + + self._default_config = None # type: Optional[Dict] + self._waiter_config = waiter_config or self.default_config + self._waiter_model = botocore.waiter.WaiterModel(self._waiter_config) + + @property + def default_config(self) -> Dict: + """ + An immutable default waiter configuration + + :return: a waiter configuration for AWS batch services + :rtype: Dict + """ + if self._default_config is None: + config_path = Path(__file__).with_name("batch_waiters.json").absolute() + with open(config_path) as config_file: + self._default_config = json.load(config_file) + return deepcopy(self._default_config) # avoid accidental mutation + + @property + def waiter_config(self) -> Dict: + """ + An immutable waiter configuration for this instance; a ``deepcopy`` is returned by this + property. During the init for AwsBatchWaiters, the waiter_config is used to build a + waiter_model and this only occurs during the class init, to avoid any accidental + mutations of waiter_config leaking into the waiter_model. + + :return: a waiter configuration for AWS batch services + :rtype: Dict + """ + return deepcopy(self._waiter_config) # avoid accidental mutation + + @property + def waiter_model(self) -> botocore.waiter.WaiterModel: + """ + A configured waiter model used to generate waiters on AWS batch services. + + :return: a waiter model for AWS batch services + :rtype: botocore.waiter.WaiterModel + """ + return self._waiter_model + + def get_waiter(self, waiter_name: str) -> botocore.waiter.Waiter: + """ + Get an AWS Batch service waiter, using the configured ``.waiter_model``. + + The ``.waiter_model`` is combined with the ``.client`` to get a specific waiter and + the properties of that waiter can be modified without any accidental impact on the + generation of new waiters from the ``.waiter_model``, e.g. + .. code-block:: + + waiters.get_waiter("JobExists").config.delay # -> 5 + waiter = waiters.get_waiter("JobExists") # a new waiter object + waiter.config.delay = 10 + waiters.get_waiter("JobExists").config.delay # -> 5 as defined by waiter_model + + To use a specific waiter, update the config and call the `wait()` method for jobId, e.g. + .. code-block:: + + import random + waiter = waiters.get_waiter("JobExists") # a new waiter object + waiter.config.delay = random.uniform(1, 10) # seconds + waiter.config.max_attempts = 10 + waiter.wait(jobs=[jobId]) + + :param waiter_name: The name of the waiter. 
The name should match + the name (including the casing) of the key name in the waiter + model file (typically this is CamelCasing); see ``.list_waiters``. + :type waiter_name: str + + :return: a waiter object for the named AWS batch service + :rtype: botocore.waiter.Waiter + """ + return botocore.waiter.create_waiter_with_client( + waiter_name, self.waiter_model, self.client + ) + + def list_waiters(self) -> List[str]: + """ + List the waiters in a waiter configuration for AWS Batch services. + + :return: waiter names for AWS batch services + :rtype: List[str] + """ + return self.waiter_model.waiter_names + + def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None: + """ + Wait for batch job to complete. This assumes that the ``.waiter_model`` is configured + using some variation of the ``.default_config`` so that it can generate waiters with the + following names: "JobExists", "JobRunning" and "JobComplete". + + :param job_id: a batch job ID + :type job_id: str + + :param delay: A delay before polling for job status + :type delay: Union[int, float, None] + + :raises: AirflowException + + .. note:: + This method adds a small random jitter to the ``delay`` (+/- 2 sec, >= 1 sec). + Using a random interval helps to avoid AWS API throttle limits when many + concurrent tasks request job-descriptions. + + It also modifies the ``max_attempts`` to use the ``sys.maxsize``, + which allows Airflow to manage the timeout on waiting. + """ + self.delay(delay) + try: + waiter = self.get_waiter("JobExists") + waiter.config.delay = self.add_jitter( + waiter.config.delay, width=2, minima=1 + ) + waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow + waiter.wait(jobs=[job_id]) + + waiter = self.get_waiter("JobRunning") + waiter.config.delay = self.add_jitter( + waiter.config.delay, width=2, minima=1 + ) + waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow + waiter.wait(jobs=[job_id]) + + waiter = self.get_waiter("JobComplete") + waiter.config.delay = self.add_jitter( + waiter.config.delay, width=2, minima=1 + ) + waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow + waiter.wait(jobs=[job_id]) + + except ( + botocore.exceptions.ClientError, + botocore.exceptions.WaiterError, + ) as err: + raise AirflowException(err) diff --git a/reference/providers/amazon/aws/hooks/cloud_formation.py b/reference/providers/amazon/aws/hooks/cloud_formation.py new file mode 100644 index 0000000..cec5a8b --- /dev/null +++ b/reference/providers/amazon/aws/hooks/cloud_formation.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
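
As a usage note for the batch waiters hook above: a minimal sketch of narrowing the packaged waiter configuration to a single, faster ``JobComplete`` waiter. It assumes the hook is importable from ``airflow.providers.amazon.aws.hooks.batch_waiters``, that a working ``aws_default`` connection exists, and it uses a hypothetical job id.

.. code-block:: python

    from airflow.providers.amazon.aws.hooks.batch_waiters import AwsBatchWaitersHook

    hook = AwsBatchWaitersHook(aws_conn_id="aws_default")

    # default_config returns a deepcopy of batch_waiters.json, so it is safe to edit.
    config = hook.default_config
    config["waiters"] = {"JobComplete": config["waiters"]["JobComplete"]}
    config["waiters"]["JobComplete"]["delay"] = 30  # poll every 30s instead of 300s

    custom = AwsBatchWaitersHook(aws_conn_id="aws_default", waiter_config=config)
    custom.get_waiter("JobComplete").wait(jobs=["example-job-id"])  # hypothetical id
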
+ +"""This module contains AWS CloudFormation Hook""" +from typing import Optional, Union + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from boto3 import client, resource +from botocore.exceptions import ClientError + + +class AWSCloudFormationHook(AwsBaseHook): + """ + Interact with AWS CloudFormation. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs): + super().__init__(client_type="cloudformation", *args, **kwargs) + + def get_stack_status(self, stack_name: Union[client, resource]) -> Optional[dict]: + """Get stack status from CloudFormation.""" + self.log.info("Poking for stack %s", stack_name) + + try: + stacks = self.get_conn().describe_stacks(StackName=stack_name)["Stacks"] + return stacks[0]["StackStatus"] + except ClientError as e: + if "does not exist" in str(e): + return None + else: + raise e + + def create_stack(self, stack_name: str, params: dict) -> None: + """ + Create stack in CloudFormation. + + :param stack_name: stack_name. + :type stack_name: str + :param params: parameters to be passed to CloudFormation. + :type params: dict + """ + if "StackName" not in params: + params["StackName"] = stack_name + self.get_conn().create_stack(**params) + + def delete_stack(self, stack_name: str, params: Optional[dict] = None) -> None: + """ + Delete stack in CloudFormation. + + :param stack_name: stack_name. + :type stack_name: str + :param params: parameters to be passed to CloudFormation (optional). + :type params: dict + """ + params = params or {} + if "StackName" not in params: + params["StackName"] = stack_name + self.get_conn().delete_stack(**params) diff --git a/reference/providers/amazon/aws/hooks/datasync.py b/reference/providers/amazon/aws/hooks/datasync.py new file mode 100644 index 0000000..00eff13 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/datasync.py @@ -0,0 +1,334 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Interact with AWS DataSync, using the AWS ``boto3`` library.""" + +import time +from typing import List, Optional + +from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowTaskTimeout +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class AWSDataSyncHook(AwsBaseHook): + """ + Interact with AWS DataSync. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. 
seealso::
+        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+        :class:`~airflow.providers.amazon.aws.operators.datasync.AWSDataSyncOperator`
+
+    :param int wait_interval_seconds: Time to wait between two
+        consecutive calls to check TaskExecution status.
+    :raises ValueError: If wait_interval_seconds is not between 0 and 15*60 seconds.
+    """
+
+    TASK_EXECUTION_INTERMEDIATE_STATES = (
+        "INITIALIZING",
+        "QUEUED",
+        "LAUNCHING",
+        "PREPARING",
+        "TRANSFERRING",
+        "VERIFYING",
+    )
+    TASK_EXECUTION_FAILURE_STATES = ("ERROR",)
+    TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)
+
+    def __init__(self, wait_interval_seconds: int = 5, *args, **kwargs) -> None:
+        super().__init__(client_type="datasync", *args, **kwargs)  # type: ignore[misc]
+        self.locations: list = []
+        self.tasks: list = []
+        # wait_interval_seconds = 0 is used during unit tests
+        if wait_interval_seconds < 0 or wait_interval_seconds > 15 * 60:
+            raise ValueError(f"Invalid wait_interval_seconds {wait_interval_seconds}")
+        self.wait_interval_seconds = wait_interval_seconds
+
+    def create_location(self, location_uri: str, **create_location_kwargs) -> str:
+        r"""Creates a new location.
+
+        :param str location_uri: Location URI used to determine the location type (S3, SMB, NFS, EFS).
+        :param create_location_kwargs: Passed to ``boto.create_location_xyz()``.
+            See AWS boto3 datasync documentation.
+        :return str: LocationArn of the created Location.
+        :raises AirflowException: If location type (prefix from ``location_uri``) is invalid.
+        """
+        typ = location_uri.split(":")[0]
+        if typ == "smb":
+            location = self.get_conn().create_location_smb(**create_location_kwargs)
+        elif typ == "s3":
+            location = self.get_conn().create_location_s3(**create_location_kwargs)
+        elif typ == "nfs":
+            location = self.get_conn().create_location_nfs(**create_location_kwargs)
+        elif typ == "efs":
+            location = self.get_conn().create_location_efs(**create_location_kwargs)
+        else:
+            raise AirflowException(f"Invalid location type: {typ}")
+        self._refresh_locations()
+        return location["LocationArn"]
+
+    def get_location_arns(
+        self,
+        location_uri: str,
+        case_sensitive: bool = False,
+        ignore_trailing_slash: bool = True,
+    ) -> List[str]:
+        """
+        Return all LocationArns which match a LocationUri.
+
+        :param str location_uri: Location URI to search for, eg ``s3://mybucket/mypath``
+        :param bool case_sensitive: Do a case sensitive search for location URI.
+        :param bool ignore_trailing_slash: Ignore / at the end of URI when matching.
+        :return: List of LocationArns.
+ :rtype: list(str) + :raises AirflowBadRequest: if ``location_uri`` is empty + """ + if not location_uri: + raise AirflowBadRequest("location_uri not specified") + if not self.locations: + self._refresh_locations() + result = [] + + if not case_sensitive: + location_uri = location_uri.lower() + if ignore_trailing_slash and location_uri.endswith("/"): + location_uri = location_uri[:-1] + + for location_from_aws in self.locations: + location_uri_from_aws = location_from_aws["LocationUri"] + if not case_sensitive: + location_uri_from_aws = location_uri_from_aws.lower() + if ignore_trailing_slash and location_uri_from_aws.endswith("/"): + location_uri_from_aws = location_uri_from_aws[:-1] + if location_uri == location_uri_from_aws: + result.append(location_from_aws["LocationArn"]) + return result + + def _refresh_locations(self) -> None: + """Refresh the local list of Locations.""" + self.locations = [] + next_token = None + while True: + if next_token: + locations = self.get_conn().list_locations(NextToken=next_token) + else: + locations = self.get_conn().list_locations() + self.locations.extend(locations["Locations"]) + if "NextToken" not in locations: + break + next_token = locations["NextToken"] + + def create_task( + self, + source_location_arn: str, + destination_location_arn: str, + **create_task_kwargs, + ) -> str: + r"""Create a Task between the specified source and destination LocationArns. + + :param str source_location_arn: Source LocationArn. Must exist already. + :param str destination_location_arn: Destination LocationArn. Must exist already. + :param create_task_kwargs: Passed to ``boto.create_task()``. See AWS boto3 datasync documentation. + :return: TaskArn of the created Task + :rtype: str + """ + task = self.get_conn().create_task( + SourceLocationArn=source_location_arn, + DestinationLocationArn=destination_location_arn, + **create_task_kwargs, + ) + self._refresh_tasks() + return task["TaskArn"] + + def update_task(self, task_arn: str, **update_task_kwargs) -> None: + r"""Update a Task. + + :param str task_arn: The TaskArn to update. + :param update_task_kwargs: Passed to ``boto.update_task()``, See AWS boto3 datasync documentation. + """ + self.get_conn().update_task(TaskArn=task_arn, **update_task_kwargs) + + def delete_task(self, task_arn: str) -> None: + r"""Delete a Task. + + :param str task_arn: The TaskArn to delete. + """ + self.get_conn().delete_task(TaskArn=task_arn) + + def _refresh_tasks(self) -> None: + """Refreshes the local list of Tasks""" + self.tasks = [] + next_token = None + while True: + if next_token: + tasks = self.get_conn().list_tasks(NextToken=next_token) + else: + tasks = self.get_conn().list_tasks() + self.tasks.extend(tasks["Tasks"]) + if "NextToken" not in tasks: + break + next_token = tasks["NextToken"] + + def get_task_arns_for_location_arns( + self, + source_location_arns: list, + destination_location_arns: list, + ) -> list: + """ + Return list of TaskArns for which use any one of the specified + source LocationArns and any one of the specified destination LocationArns. + + :param list source_location_arns: List of source LocationArns. + :param list destination_location_arns: List of destination LocationArns. + :return: list + :rtype: list(TaskArns) + :raises AirflowBadRequest: if ``source_location_arns`` or ``destination_location_arns`` are empty. 
+ """ + if not source_location_arns: + raise AirflowBadRequest("source_location_arns not specified") + if not destination_location_arns: + raise AirflowBadRequest("destination_location_arns not specified") + if not self.tasks: + self._refresh_tasks() + + result = [] + for task in self.tasks: + task_arn = task["TaskArn"] + task_description = self.get_task_description(task_arn) + if task_description["SourceLocationArn"] in source_location_arns: + if ( + task_description["DestinationLocationArn"] + in destination_location_arns + ): + result.append(task_arn) + return result + + def start_task_execution(self, task_arn: str, **kwargs) -> str: + r""" + Start a TaskExecution for the specified task_arn. + Each task can have at most one TaskExecution. + + :param str task_arn: TaskArn + :return: TaskExecutionArn + :param kwargs: kwargs sent to ``boto3.start_task_execution()`` + :rtype: str + :raises ClientError: If a TaskExecution is already busy running for this ``task_arn``. + :raises AirflowBadRequest: If ``task_arn`` is empty. + """ + if not task_arn: + raise AirflowBadRequest("task_arn not specified") + task_execution = self.get_conn().start_task_execution( + TaskArn=task_arn, **kwargs + ) + return task_execution["TaskExecutionArn"] + + def cancel_task_execution(self, task_execution_arn: str) -> None: + """ + Cancel a TaskExecution for the specified ``task_execution_arn``. + + :param str task_execution_arn: TaskExecutionArn. + :raises AirflowBadRequest: If ``task_execution_arn`` is empty. + """ + if not task_execution_arn: + raise AirflowBadRequest("task_execution_arn not specified") + self.get_conn().cancel_task_execution(TaskExecutionArn=task_execution_arn) + + def get_task_description(self, task_arn: str) -> dict: + """ + Get description for the specified ``task_arn``. + + :param str task_arn: TaskArn + :return: AWS metadata about a task. + :rtype: dict + :raises AirflowBadRequest: If ``task_arn`` is empty. + """ + if not task_arn: + raise AirflowBadRequest("task_arn not specified") + return self.get_conn().describe_task(TaskArn=task_arn) + + def describe_task_execution(self, task_execution_arn: str) -> dict: + """ + Get description for the specified ``task_execution_arn``. + + :param str task_execution_arn: TaskExecutionArn + :return: AWS metadata about a task execution. + :rtype: dict + :raises AirflowBadRequest: If ``task_execution_arn`` is empty. + """ + return self.get_conn().describe_task_execution( + TaskExecutionArn=task_execution_arn + ) + + def get_current_task_execution_arn(self, task_arn: str) -> Optional[str]: + """ + Get current TaskExecutionArn (if one exists) for the specified ``task_arn``. + + :param str task_arn: TaskArn + :return: CurrentTaskExecutionArn for this ``task_arn`` or None. + :rtype: str + :raises AirflowBadRequest: if ``task_arn`` is empty. + """ + if not task_arn: + raise AirflowBadRequest("task_arn not specified") + task_description = self.get_task_description(task_arn) + if "CurrentTaskExecutionArn" in task_description: + return task_description["CurrentTaskExecutionArn"] + return None + + def wait_for_task_execution( + self, task_execution_arn: str, max_iterations: int = 2 * 180 + ) -> bool: + """ + Wait for Task Execution status to be complete (SUCCESS/ERROR). + The ``task_execution_arn`` must exist, or a boto3 ClientError will be raised. + + :param str task_execution_arn: TaskExecutionArn + :param int max_iterations: Maximum number of iterations before timing out. + :return: Result of task execution. 
+ :rtype: bool + :raises AirflowTaskTimeout: If maximum iterations is exceeded. + :raises AirflowBadRequest: If ``task_execution_arn`` is empty. + """ + if not task_execution_arn: + raise AirflowBadRequest("task_execution_arn not specified") + + status = None + iterations = max_iterations + while status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES: + task_execution = self.get_conn().describe_task_execution( + TaskExecutionArn=task_execution_arn + ) + status = task_execution["Status"] + self.log.info("status=%s", status) + iterations -= 1 + if status in self.TASK_EXECUTION_FAILURE_STATES: + break + if status in self.TASK_EXECUTION_SUCCESS_STATES: + break + if iterations <= 0: + break + time.sleep(self.wait_interval_seconds) + + if status in self.TASK_EXECUTION_SUCCESS_STATES: + return True + if status in self.TASK_EXECUTION_FAILURE_STATES: + return False + if iterations <= 0: + raise AirflowTaskTimeout("Max iterations exceeded!") + raise AirflowException(f"Unknown status: {status}") # Should never happen diff --git a/reference/providers/amazon/aws/hooks/dynamodb.py b/reference/providers/amazon/aws/hooks/dynamodb.py new file mode 100644 index 0000000..c13c327 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/dynamodb.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +"""This module contains the AWS DynamoDB hook""" +from typing import Iterable, List, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class AwsDynamoDBHook(AwsBaseHook): + """ + Interact with AWS DynamoDB. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. 
seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + + :param table_keys: partition key and sort key + :type table_keys: list + :param table_name: target DynamoDB table + :type table_name: str + """ + + def __init__( + self, + *args, + table_keys: Optional[List] = None, + table_name: Optional[str] = None, + **kwargs, + ) -> None: + self.table_keys = table_keys + self.table_name = table_name + kwargs["resource_type"] = "dynamodb" + super().__init__(*args, **kwargs) + + def write_batch_data(self, items: Iterable) -> bool: + """Write batch items to DynamoDB table with provisioned throughout capacity.""" + try: + table = self.get_conn().Table(self.table_name) + + with table.batch_writer(overwrite_by_pkeys=self.table_keys) as batch: + for item in items: + batch.put_item(Item=item) + return True + except Exception as general_error: + raise AirflowException( + f"Failed to insert items in dynamodb, error: {str(general_error)}" + ) diff --git a/reference/providers/amazon/aws/hooks/ec2.py b/reference/providers/amazon/aws/hooks/ec2.py new file mode 100644 index 0000000..fc00d5e --- /dev/null +++ b/reference/providers/amazon/aws/hooks/ec2.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import time + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class EC2Hook(AwsBaseHook): + """ + Interact with AWS EC2 Service. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs) -> None: + kwargs["resource_type"] = "ec2" + super().__init__(*args, **kwargs) + + def get_instance(self, instance_id: str): + """ + Get EC2 instance by id and return it. + + :param instance_id: id of the AWS EC2 instance + :type instance_id: str + :return: Instance object + :rtype: ec2.Instance + """ + return self.conn.Instance(id=instance_id) + + def get_instance_state(self, instance_id: str) -> str: + """ + Get EC2 instance state by id and return it. + + :param instance_id: id of the AWS EC2 instance + :type instance_id: str + :return: current state of the instance + :rtype: str + """ + return self.get_instance(instance_id=instance_id).state["Name"] + + def wait_for_state( + self, instance_id: str, target_state: str, check_interval: float + ) -> None: + """ + Wait EC2 instance until its state is equal to the target_state. 
+ + :param instance_id: id of the AWS EC2 instance + :type instance_id: str + :param target_state: target state of instance + :type target_state: str + :param check_interval: time in seconds that the job should wait in + between each instance state checks until operation is completed + :type check_interval: float + :return: None + :rtype: None + """ + instance_state = self.get_instance_state(instance_id=instance_id) + while instance_state != target_state: + self.log.info("instance state: %s", instance_state) + time.sleep(check_interval) + instance_state = self.get_instance_state(instance_id=instance_id) diff --git a/reference/providers/amazon/aws/hooks/elasticache_replication_group.py b/reference/providers/amazon/aws/hooks/elasticache_replication_group.py new file mode 100644 index 0000000..4ed49d3 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/elasticache_replication_group.py @@ -0,0 +1,325 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from time import sleep +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class ElastiCacheReplicationGroupHook(AwsBaseHook): + """ + Interact with AWS ElastiCache + + :param max_retries: Max retries for checking availability of and deleting replication group + If this is not supplied then this is defaulted to 10 + :type max_retries: int + :param exponential_back_off_factor: Multiplication factor for deciding next sleep time + If this is not supplied then this is defaulted to 1 + :type exponential_back_off_factor: float + :param initial_poke_interval: Initial sleep time in seconds + If this is not supplied then this is defaulted to 60 seconds + :type initial_poke_interval: float + """ + + TERMINAL_STATES = frozenset({"available", "create-failed", "deleting"}) + + def __init__( + self, + max_retries: int = 10, + exponential_back_off_factor: float = 1, + initial_poke_interval: float = 60, + *args, + **kwargs, + ): + self.max_retries = max_retries + self.exponential_back_off_factor = exponential_back_off_factor + self.initial_poke_interval = initial_poke_interval + + kwargs["client_type"] = "elasticache" + super().__init__(*args, **kwargs) + + def create_replication_group(self, config: dict) -> dict: + """ + Call ElastiCache API for creating a replication group + + :param config: Configuration for creating the replication group + :type config: dict + :return: Response from ElastiCache create replication group API + :rtype: dict + """ + return self.conn.create_replication_group(**config) + + def delete_replication_group(self, replication_group_id: str) -> dict: + """ + Call ElastiCache API for deleting a replication group + + :param replication_group_id: ID of replication group to delete + :type 
replication_group_id: str + :return: Response from ElastiCache delete replication group API + :rtype: dict + """ + return self.conn.delete_replication_group( + ReplicationGroupId=replication_group_id + ) + + def describe_replication_group(self, replication_group_id: str) -> dict: + """ + Call ElastiCache API for describing a replication group + + :param replication_group_id: ID of replication group to describe + :type replication_group_id: str + :return: Response from ElastiCache describe replication group API + :rtype: dict + """ + return self.conn.describe_replication_groups( + ReplicationGroupId=replication_group_id + ) + + def get_replication_group_status(self, replication_group_id: str) -> str: + """ + Get current status of replication group + + :param replication_group_id: ID of replication group to check for status + :type replication_group_id: str + :return: Current status of replication group + :rtype: str + """ + return self.describe_replication_group(replication_group_id)[ + "ReplicationGroups" + ][0]["Status"] + + def is_replication_group_available(self, replication_group_id: str) -> bool: + """ + Helper for checking if replication group is available or not + + :param replication_group_id: ID of replication group to check for availability + :type replication_group_id: str + :return: True if available else False + :rtype: bool + """ + return self.get_replication_group_status(replication_group_id) == "available" + + def wait_for_availability( + self, + replication_group_id: str, + initial_sleep_time: Optional[float] = None, + exponential_back_off_factor: Optional[float] = None, + max_retries: Optional[int] = None, + ): + """ + Check if replication group is available or not by performing a describe over it + + :param replication_group_id: ID of replication group to check for availability + :type replication_group_id: str + :param initial_sleep_time: Initial sleep time in seconds + If this is not supplied then this is defaulted to class level value + :type initial_sleep_time: float + :param exponential_back_off_factor: Multiplication factor for deciding next sleep time + If this is not supplied then this is defaulted to class level value + :type exponential_back_off_factor: float + :param max_retries: Max retries for checking availability of replication group + If this is not supplied then this is defaulted to class level value + :type max_retries: int + :return: True if replication is available else False + :rtype: bool + """ + sleep_time = initial_sleep_time or self.initial_poke_interval + exponential_back_off_factor = ( + exponential_back_off_factor or self.exponential_back_off_factor + ) + max_retries = max_retries or self.max_retries + num_tries = 0 + status = "not-found" + stop_poking = False + + while not stop_poking and num_tries <= max_retries: + status = self.get_replication_group_status( + replication_group_id=replication_group_id + ) + stop_poking = status in self.TERMINAL_STATES + + self.log.info( + "Current status of replication group with ID %s is %s", + replication_group_id, + status, + ) + + if not stop_poking: + num_tries += 1 + + # No point in sleeping if all tries have exhausted + if num_tries > max_retries: + break + + self.log.info( + "Poke retry %s. Sleep time %s seconds. Sleeping...", + num_tries, + sleep_time, + ) + + sleep(sleep_time) + + sleep_time *= exponential_back_off_factor + + if status != "available": + self.log.warning( + 'Replication group is not available. 
Current status is "%s"', status + ) + + return False + + return True + + def wait_for_deletion( + self, + replication_group_id: str, + initial_sleep_time: Optional[float] = None, + exponential_back_off_factor: Optional[float] = None, + max_retries: Optional[int] = None, + ): + """ + Helper for deleting a replication group ensuring it is either deleted or can't be deleted + + :param replication_group_id: ID of replication to delete + :type replication_group_id: str + :param initial_sleep_time: Initial sleep time in second + If this is not supplied then this is defaulted to class level value + :type initial_sleep_time: float + :param exponential_back_off_factor: Multiplication factor for deciding next sleep time + If this is not supplied then this is defaulted to class level value + :type exponential_back_off_factor: float + :param max_retries: Max retries for checking availability of replication group + If this is not supplied then this is defaulted to class level value + :type max_retries: int + :return: Response from ElastiCache delete replication group API and flag to identify if deleted or not + :rtype: (dict, bool) + """ + deleted = False + sleep_time = initial_sleep_time or self.initial_poke_interval + exponential_back_off_factor = ( + exponential_back_off_factor or self.exponential_back_off_factor + ) + max_retries = max_retries or self.max_retries + num_tries = 0 + response = None + + while not deleted and num_tries <= max_retries: + try: + status = self.get_replication_group_status( + replication_group_id=replication_group_id + ) + + self.log.info( + "Current status of replication group with ID %s is %s", + replication_group_id, + status, + ) + + # Can only delete if status is `available` + # Status becomes `deleting` after this call so this will only run once + if status == "available": + self.log.info("Initiating delete and then wait for it to finish") + + response = self.delete_replication_group( + replication_group_id=replication_group_id + ) + + except self.conn.exceptions.ReplicationGroupNotFoundFault: + self.log.info( + "Replication group with ID '%s' does not exist", + replication_group_id, + ) + + deleted = True + + # This should never occur as we only issue a delete request when status is `available` + # which is a valid status for deletion. Still handling for safety. + except self.conn.exceptions.InvalidReplicationGroupStateFault as exp: + # status Error Response + # creating - Cache cluster is not in a valid state to be deleted. + # deleting - Replication group has status deleting which is not valid + # for deletion. + # modifying - Replication group has status deleting which is not valid + # for deletion. + + message = exp.response["Error"]["Message"] + + self.log.warning( + "Received error message from AWS ElastiCache API : %s", message + ) + + if not deleted: + num_tries += 1 + + # No point in sleeping if all tries have exhausted + if num_tries > max_retries: + break + + self.log.info( + "Poke retry %s. Sleep time %s seconds. 
Sleeping...", + num_tries, + sleep_time, + ) + + sleep(sleep_time) + + sleep_time *= exponential_back_off_factor + + return response, deleted + + def ensure_delete_replication_group( + self, + replication_group_id: str, + initial_sleep_time: Optional[float] = None, + exponential_back_off_factor: Optional[float] = None, + max_retries: Optional[int] = None, + ): + """ + Delete a replication group ensuring it is either deleted or can't be deleted + + :param replication_group_id: ID of replication to delete + :type replication_group_id: str + :param initial_sleep_time: Initial sleep time in second + If this is not supplied then this is defaulted to class level value + :type initial_sleep_time: float + :param exponential_back_off_factor: Multiplication factor for deciding next sleep time + If this is not supplied then this is defaulted to class level value + :type exponential_back_off_factor: float + :param max_retries: Max retries for checking availability of replication group + If this is not supplied then this is defaulted to class level value + :type max_retries: int + :return: Response from ElastiCache delete replication group API + :rtype: dict + :raises AirflowException: If replication group is not deleted + """ + self.log.info("Deleting replication group with ID %s", replication_group_id) + + response, deleted = self.wait_for_deletion( + replication_group_id=replication_group_id, + initial_sleep_time=initial_sleep_time, + exponential_back_off_factor=exponential_back_off_factor, + max_retries=max_retries, + ) + + if not deleted: + raise AirflowException( + f'Replication group could not be deleted. Response "{response}"' + ) + + return response diff --git a/reference/providers/amazon/aws/hooks/emr.py b/reference/providers/amazon/aws/hooks/emr.py new file mode 100644 index 0000000..e3f534d --- /dev/null +++ b/reference/providers/amazon/aws/hooks/emr.py @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, List, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class EmrHook(AwsBaseHook): + """ + Interact with AWS EMR. emr_conn_id is only necessary for using the + create_job_flow method. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. 
seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + conn_name_attr = "emr_conn_id" + default_conn_name = "emr_default" + conn_type = "emr" + hook_name = "Elastic MapReduce" + + def __init__( + self, emr_conn_id: Optional[str] = default_conn_name, *args, **kwargs + ) -> None: + self.emr_conn_id = emr_conn_id + kwargs["client_type"] = "emr" + super().__init__(*args, **kwargs) + + def get_cluster_id_by_name( + self, emr_cluster_name: str, cluster_states: List[str] + ) -> Optional[str]: + """ + Fetch id of EMR cluster with given name and (optional) states. + Will return only if single id is found. + + :param emr_cluster_name: Name of a cluster to find + :type emr_cluster_name: str + :param cluster_states: State(s) of cluster to find + :type cluster_states: list + :return: id of the EMR cluster + """ + response = self.get_conn().list_clusters(ClusterStates=cluster_states) + + matching_clusters = list( + filter( + lambda cluster: cluster["Name"] == emr_cluster_name, + response["Clusters"], + ) + ) + + if len(matching_clusters) == 1: + cluster_id = matching_clusters[0]["Id"] + self.log.info( + "Found cluster name = %s id = %s", emr_cluster_name, cluster_id + ) + return cluster_id + elif len(matching_clusters) > 1: + raise AirflowException( + f"More than one cluster found for name {emr_cluster_name}" + ) + else: + self.log.info("No cluster found for name %s", emr_cluster_name) + return None + + def create_job_flow(self, job_flow_overrides: Dict[str, Any]) -> Dict[str, Any]: + """ + Creates a job flow using the config from the EMR connection. + Keys of the json extra hash may have the arguments of the boto3 + run_job_flow method. + Overrides for this config may be passed as the job_flow_overrides. + """ + if not self.emr_conn_id: + raise AirflowException("emr_conn_id must be present to use create_job_flow") + + emr_conn = self.get_connection(self.emr_conn_id) + + config = emr_conn.extra_dejson.copy() + config.update(job_flow_overrides) + + response = self.get_conn().run_job_flow(**config) + + return response diff --git a/reference/providers/amazon/aws/hooks/glacier.py b/reference/providers/amazon/aws/hooks/glacier.py new file mode 100644 index 0000000..32375f7 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/glacier.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
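# --- Usage sketch (hypothetical) -------------------------------------------
# A minimal example of driving the GlacierHook defined below from a script.
# It is a sketch only: it assumes the Amazon provider package is installed,
# an Airflow connection named "aws_default" exists, and a Glacier vault
# called "my-vault" is available; all of these names are placeholders.
from airflow.providers.amazon.aws.hooks.glacier import GlacierHook

hook = GlacierHook(aws_conn_id="aws_default")
job = hook.retrieve_inventory(vault_name="my-vault")            # start an inventory-retrieval job
status = hook.describe_job(vault_name="my-vault", job_id=job["jobId"])
if status["StatusCode"] == "Succeeded":                         # Glacier jobs typically take hours
    inventory = hook.retrieve_inventory_results(vault_name="my-vault", job_id=job["jobId"])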
+ + +from typing import Any, Dict + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class GlacierHook(AwsBaseHook): + """Hook for connection with Amazon Glacier""" + + def __init__(self, aws_conn_id: str = "aws_default") -> None: + super().__init__(client_type="glacier") + self.aws_conn_id = aws_conn_id + + def retrieve_inventory(self, vault_name: str) -> Dict[str, Any]: + """ + Initiate an Amazon Glacier inventory-retrieval job + + :param vault_name: the Glacier vault on which job is executed + :type vault_name: str + """ + job_params = {"Type": "inventory-retrieval"} + self.log.info("Retrieving inventory for vault: %s", vault_name) + response = self.get_conn().initiate_job( + vaultName=vault_name, jobParameters=job_params + ) + self.log.info("Initiated inventory-retrieval job for: %s", vault_name) + self.log.info("Retrieval Job ID: %s", response["jobId"]) + return response + + def retrieve_inventory_results( + self, vault_name: str, job_id: str + ) -> Dict[str, Any]: + """ + Retrieve the results of an Amazon Glacier inventory-retrieval job + + :param vault_name: the Glacier vault on which job is executed + :type vault_name: string + :param job_id: the job ID was returned by retrieve_inventory() + :type job_id: str + """ + self.log.info("Retrieving the job results for vault: %s...", vault_name) + response = self.get_conn().get_job_output(vaultName=vault_name, jobId=job_id) + return response + + def describe_job(self, vault_name: str, job_id: str) -> Dict[str, Any]: + """ + Retrieve the status of an Amazon S3 Glacier job, such as an + inventory-retrieval job + + :param vault_name: the Glacier vault on which job is executed + :type vault_name: string + :param job_id: the job ID was returned by retrieve_inventory() + :type job_id: str + """ + self.log.info("Retrieving status for vault: %s and job %s", vault_name, job_id) + response = self.get_conn().describe_job(vaultName=vault_name, jobId=job_id) + self.log.info( + "Job status: %s, code status: %s", + response["Action"], + response["StatusCode"], + ) + return response diff --git a/reference/providers/amazon/aws/hooks/glue.py b/reference/providers/amazon/aws/hooks/glue.py new file mode 100644 index 0000000..969a288 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/glue.py @@ -0,0 +1,205 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
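# --- Usage sketch (hypothetical) -------------------------------------------
# A minimal example of submitting a job through the AwsGlueJobHook defined
# below and waiting for it to finish. The bucket, script path, IAM role and
# connection id are placeholder values, not taken from this repository.
from airflow.providers.amazon.aws.hooks.glue import AwsGlueJobHook

hook = AwsGlueJobHook(
    job_name="example-etl-job",
    desc="Illustrative Glue job",
    s3_bucket="my-etl-bucket",                       # logs land under s3://my-etl-bucket/logs/glue-logs/
    script_location="s3://my-etl-bucket/scripts/etl.py",
    iam_role_name="my-glue-execution-role",
    num_of_dpus=2,
    aws_conn_id="aws_default",
)
job_run = hook.initialize_job(script_arguments={"--source_path": "s3://my-etl-bucket/raw/"})
result = hook.job_completion(hook.job_name, job_run["JobRunId"])   # blocks until SUCCEEDED/STOPPED, raises on FAILED/TIMEOUT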
+ +import time +from typing import Dict, List, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class AwsGlueJobHook(AwsBaseHook): + """ + Interact with AWS Glue - create job, trigger, crawler + + :param s3_bucket: S3 bucket where logs and local etl script will be uploaded + :type s3_bucket: Optional[str] + :param job_name: unique job name per AWS account + :type job_name: Optional[str] + :param desc: job description + :type desc: Optional[str] + :param concurrent_run_limit: The maximum number of concurrent runs allowed for a job + :type concurrent_run_limit: int + :param script_location: path to etl script on s3 + :type script_location: Optional[str] + :param retry_limit: Maximum number of times to retry this job if it fails + :type retry_limit: int + :param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job + :type num_of_dpus: int + :param region_name: aws region name (example: us-east-1) + :type region_name: Optional[str] + :param iam_role_name: AWS IAM Role for Glue Job Execution + :type iam_role_name: Optional[str] + :param create_job_kwargs: Extra arguments for Glue Job Creation + :type create_job_kwargs: Optional[dict] + """ + + JOB_POLL_INTERVAL = 6 # polls job status after every JOB_POLL_INTERVAL seconds + + def __init__( + self, + s3_bucket: Optional[str] = None, + job_name: Optional[str] = None, + desc: Optional[str] = None, + concurrent_run_limit: int = 1, + script_location: Optional[str] = None, + retry_limit: int = 0, + num_of_dpus: int = 10, + region_name: Optional[str] = None, + iam_role_name: Optional[str] = None, + create_job_kwargs: Optional[dict] = None, + *args, + **kwargs, + ): # pylint: disable=too-many-arguments + self.job_name = job_name + self.desc = desc + self.concurrent_run_limit = concurrent_run_limit + self.script_location = script_location + self.retry_limit = retry_limit + self.num_of_dpus = num_of_dpus + self.region_name = region_name + self.s3_bucket = s3_bucket + self.role_name = iam_role_name + self.s3_glue_logs = "logs/glue-logs/" + self.create_job_kwargs = create_job_kwargs or {} + kwargs["client_type"] = "glue" + super().__init__(*args, **kwargs) + + def list_jobs(self) -> List: + """:return: Lists of Jobs""" + conn = self.get_conn() + return conn.get_jobs() + + def get_iam_execution_role(self) -> Dict: + """:return: iam role for job execution""" + iam_client = self.get_client_type("iam", self.region_name) + + try: + glue_execution_role = iam_client.get_role(RoleName=self.role_name) + self.log.info("Iam Role Name: %s", self.role_name) + return glue_execution_role + except Exception as general_error: + self.log.error("Failed to create aws glue job, error: %s", general_error) + raise + + def initialize_job(self, script_arguments: Optional[dict] = None) -> Dict[str, str]: + """ + Initializes connection with AWS Glue + to run job + :return: + """ + glue_client = self.get_conn() + script_arguments = script_arguments or {} + + try: + job_name = self.get_or_create_glue_job() + job_run = glue_client.start_job_run( + JobName=job_name, Arguments=script_arguments + ) + return job_run + except Exception as general_error: + self.log.error("Failed to run aws glue job, error: %s", general_error) + raise + + def get_job_state(self, job_name: str, run_id: str) -> str: + """ + Get state of the Glue job. The job state can be + running, finished, failed, stopped or timeout. 
+ :param job_name: unique job name per AWS account + :type job_name: str + :param run_id: The job-run ID of the predecessor job run + :type run_id: str + :return: State of the Glue job + """ + glue_client = self.get_conn() + job_run = glue_client.get_job_run( + JobName=job_name, RunId=run_id, PredecessorsIncluded=True + ) + job_run_state = job_run["JobRun"]["JobRunState"] + return job_run_state + + def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]: + """ + Waits until Glue job with job_name completes or + fails and return final state if finished. + Raises AirflowException when the job failed + :param job_name: unique job name per AWS account + :type job_name: str + :param run_id: The job-run ID of the predecessor job run + :type run_id: str + :return: Dict of JobRunState and JobRunId + """ + failed_states = ["FAILED", "TIMEOUT"] + finished_states = ["SUCCEEDED", "STOPPED"] + + while True: + job_run_state = self.get_job_state(job_name, run_id) + if job_run_state in finished_states: + self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state) + return {"JobRunState": job_run_state, "JobRunId": run_id} + if job_run_state in failed_states: + job_error_message = ( + "Exiting Job " + run_id + " Run State: " + job_run_state + ) + self.log.info(job_error_message) + raise AirflowException(job_error_message) + else: + self.log.info( + "Polling for AWS Glue Job %s current run state with status %s", + job_name, + job_run_state, + ) + time.sleep(self.JOB_POLL_INTERVAL) + + def get_or_create_glue_job(self) -> str: + """ + Creates(or just returns) and returns the Job name + :return:Name of the Job + """ + glue_client = self.get_conn() + try: + get_job_response = glue_client.get_job(JobName=self.job_name) + self.log.info("Job Already exist. Returning Name of the job") + return get_job_response["Job"]["Name"] + + except glue_client.exceptions.EntityNotFoundException: + self.log.info("Job doesnt exist. Now creating and running AWS Glue Job") + if self.s3_bucket is None: + raise AirflowException( + "Could not initialize glue job, error: Specify Parameter `s3_bucket`" + ) + s3_log_path = f"s3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}" + execution_role = self.get_iam_execution_role() + try: + create_job_response = glue_client.create_job( + Name=self.job_name, + Description=self.desc, + LogUri=s3_log_path, + Role=execution_role["Role"]["RoleName"], + ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit}, + Command={"Name": "glueetl", "ScriptLocation": self.script_location}, + MaxRetries=self.retry_limit, + AllocatedCapacity=self.num_of_dpus, + **self.create_job_kwargs, + ) + return create_job_response["Name"] + except Exception as general_error: + self.log.error( + "Failed to create aws glue job, error: %s", general_error + ) + raise diff --git a/reference/providers/amazon/aws/hooks/glue_catalog.py b/reference/providers/amazon/aws/hooks/glue_catalog.py new file mode 100644 index 0000000..8a17547 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/glue_catalog.py @@ -0,0 +1,142 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains AWS Glue Catalog Hook""" +from typing import Optional, Set + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class AwsGlueCatalogHook(AwsBaseHook): + """ + Interact with AWS Glue Catalog + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs): + super().__init__(client_type="glue", *args, **kwargs) + + def get_partitions( + self, + database_name: str, + table_name: str, + expression: str = "", + page_size: Optional[int] = None, + max_items: Optional[int] = None, + ) -> Set[tuple]: + """ + Retrieves the partition values for a table. + + :param database_name: The name of the catalog database where the partitions reside. + :type database_name: str + :param table_name: The name of the partitions' table. + :type table_name: str + :param expression: An expression filtering the partitions to be returned. + Please see official AWS documentation for further information. + https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions + :type expression: str + :param page_size: pagination size + :type page_size: int + :param max_items: maximum items to return + :type max_items: int + :return: set of partition values where each value is a tuple since + a partition may be composed of multiple columns. 
For example: + ``{('2018-01-01','1'), ('2018-01-01','2')}`` + """ + config = { + "PageSize": page_size, + "MaxItems": max_items, + } + + paginator = self.get_conn().get_paginator("get_partitions") + response = paginator.paginate( + DatabaseName=database_name, + TableName=table_name, + Expression=expression, + PaginationConfig=config, + ) + + partitions = set() + for page in response: + for partition in page["Partitions"]: + partitions.add(tuple(partition["Values"])) + + return partitions + + def check_for_partition( + self, database_name: str, table_name: str, expression: str + ) -> bool: + """ + Checks whether a partition exists + + :param database_name: Name of hive database (schema) @table belongs to + :type database_name: str + :param table_name: Name of hive table @partition belongs to + :type table_name: str + :expression: Expression that matches the partitions to check for + (eg `a = 'b' AND c = 'd'`) + :type expression: str + :rtype: bool + + >>> hook = AwsGlueCatalogHook() + >>> t = 'static_babynames_partitioned' + >>> hook.check_for_partition('airflow', t, "ds='2015-01-01'") + True + """ + partitions = self.get_partitions( + database_name, table_name, expression, max_items=1 + ) + + return bool(partitions) + + def get_table(self, database_name: str, table_name: str) -> dict: + """ + Get the information of the table + + :param database_name: Name of hive database (schema) @table belongs to + :type database_name: str + :param table_name: Name of hive table + :type table_name: str + :rtype: dict + + >>> hook = AwsGlueCatalogHook() + >>> r = hook.get_table('db', 'table_foo') + >>> r['Name'] = 'table_foo' + """ + result = self.get_conn().get_table(DatabaseName=database_name, Name=table_name) + + return result["Table"] + + def get_table_location(self, database_name: str, table_name: str) -> str: + """ + Get the physical location of the table + + :param database_name: Name of hive database (schema) @table belongs to + :type database_name: str + :param table_name: Name of hive table + :type table_name: str + :return: str + """ + table = self.get_table(database_name, table_name) + + return table["StorageDescriptor"]["Location"] diff --git a/reference/providers/amazon/aws/hooks/glue_crawler.py b/reference/providers/amazon/aws/hooks/glue_crawler.py new file mode 100644 index 0000000..f71d040 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/glue_crawler.py @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
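# --- Usage sketch (hypothetical) -------------------------------------------
# A minimal create-or-update-then-run flow for the AwsGlueCrawlerHook defined
# below. The crawler name, role, database and S3 target are illustrative
# placeholders only.
from airflow.providers.amazon.aws.hooks.glue_crawler import AwsGlueCrawlerHook

crawler_config = {
    "Name": "example-crawler",
    "Role": "my-glue-execution-role",
    "DatabaseName": "example_db",
    "Targets": {"S3Targets": [{"Path": "s3://my-etl-bucket/raw/"}]},
}

hook = AwsGlueCrawlerHook(aws_conn_id="aws_default")
if hook.has_crawler(crawler_config["Name"]):
    hook.update_crawler(**crawler_config)            # only calls the API when the stored config differs
else:
    hook.create_crawler(**crawler_config)
hook.start_crawler(crawler_config["Name"])
final_status = hook.wait_for_crawler_completion(crawler_config["Name"], poll_interval=10)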
+ +from time import sleep + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class AwsGlueCrawlerHook(AwsBaseHook): + """ + Interacts with AWS Glue Crawler. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs): + kwargs["client_type"] = "glue" + super().__init__(*args, **kwargs) + + @cached_property + def glue_client(self): + """:return: AWS Glue client""" + return self.get_conn() + + def has_crawler(self, crawler_name) -> bool: + """ + Checks if the crawler already exists + + :param crawler_name: unique crawler name per AWS account + :type crawler_name: str + :return: Returns True if the crawler already exists and False if not. + """ + self.log.info("Checking if crawler already exists: %s", crawler_name) + + try: + self.get_crawler(crawler_name) + return True + except self.glue_client.exceptions.EntityNotFoundException: + return False + + def get_crawler(self, crawler_name: str) -> dict: + """ + Gets crawler configurations + + :param crawler_name: unique crawler name per AWS account + :type crawler_name: str + :return: Nested dictionary of crawler configurations + """ + return self.glue_client.get_crawler(Name=crawler_name)["Crawler"] + + def update_crawler(self, **crawler_kwargs) -> str: + """ + Updates crawler configurations + + :param crawler_kwargs: Keyword args that define the configurations used for the crawler + :type crawler_kwargs: any + :return: True if crawler was updated and false otherwise + """ + crawler_name = crawler_kwargs["Name"] + current_crawler = self.get_crawler(crawler_name) + + update_config = { + key: value + for key, value in crawler_kwargs.items() + if current_crawler[key] != crawler_kwargs[key] + } + if update_config != {}: + self.log.info("Updating crawler: %s", crawler_name) + self.glue_client.update_crawler(**crawler_kwargs) + self.log.info("Updated configurations: %s", update_config) + return True + else: + return False + + def create_crawler(self, **crawler_kwargs) -> str: + """ + Creates an AWS Glue Crawler + + :param crawler_kwargs: Keyword args that define the configurations used to create the crawler + :type crawler_kwargs: any + :return: Name of the crawler + """ + crawler_name = crawler_kwargs["Name"] + self.log.info("Creating crawler: %s", crawler_name) + return self.glue_client.create_crawler(**crawler_kwargs)["Crawler"]["Name"] + + def start_crawler(self, crawler_name: str) -> dict: + """ + Triggers the AWS Glue crawler + + :param crawler_name: unique crawler name per AWS account + :type crawler_name: str + :return: Empty dictionary + """ + self.log.info("Starting crawler %s", crawler_name) + crawler = self.glue_client.start_crawler(Name=crawler_name) + return crawler + + def wait_for_crawler_completion( + self, crawler_name: str, poll_interval: int = 5 + ) -> str: + """ + Waits until Glue crawler completes and + returns the status of the latest crawl run. + Raises AirflowException if the crawler fails or is cancelled. 
+ + :param crawler_name: unique crawler name per AWS account + :type crawler_name: str + :param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status + :type poll_interval: int + :return: Crawler's status + """ + failed_status = ["FAILED", "CANCELLED"] + + while True: + crawler = self.get_crawler(crawler_name) + crawler_state = crawler["State"] + if crawler_state == "READY": + self.log.info("State: %s", crawler_state) + self.log.info("crawler_config: %s", crawler) + crawler_status = crawler["LastCrawl"]["Status"] + if crawler_status in failed_status: + raise AirflowException( + f"Status: {crawler_status}" + ) # pylint: disable=raising-format-tuple + else: + metrics = self.glue_client.get_crawler_metrics( + CrawlerNameList=[crawler_name] + )["CrawlerMetricsList"][0] + self.log.info("Status: %s", crawler_status) + self.log.info( + "Last Runtime Duration (seconds): %s", + metrics["LastRuntimeSeconds"], + ) + self.log.info( + "Median Runtime Duration (seconds): %s", + metrics["MedianRuntimeSeconds"], + ) + self.log.info("Tables Created: %s", metrics["TablesCreated"]) + self.log.info("Tables Updated: %s", metrics["TablesUpdated"]) + self.log.info("Tables Deleted: %s", metrics["TablesDeleted"]) + + return crawler_status + + else: + self.log.info("Polling for AWS Glue crawler: %s ", crawler_name) + self.log.info("State: %s", crawler_state) + + metrics = self.glue_client.get_crawler_metrics( + CrawlerNameList=[crawler_name] + )["CrawlerMetricsList"][0] + time_left = int(metrics["TimeLeftSeconds"]) + + if time_left > 0: + self.log.info("Estimated Time Left (seconds): %s", time_left) + else: + self.log.info("Crawler should finish soon") + + sleep(poll_interval) diff --git a/reference/providers/amazon/aws/hooks/kinesis.py b/reference/providers/amazon/aws/hooks/kinesis.py new file mode 100644 index 0000000..2d3ffc5 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/kinesis.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains AWS Firehose hook""" +from typing import Iterable + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class AwsFirehoseHook(AwsBaseHook): + """ + Interact with AWS Kinesis Firehose. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. 
seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + + :param delivery_stream: Name of the delivery stream + :type delivery_stream: str + """ + + def __init__(self, delivery_stream: str, *args, **kwargs) -> None: + self.delivery_stream = delivery_stream + kwargs["client_type"] = "firehose" + super().__init__(*args, **kwargs) + + def put_records(self, records: Iterable): + """Write batch records to Kinesis Firehose""" + response = self.get_conn().put_record_batch( + DeliveryStreamName=self.delivery_stream, Records=records + ) + + return response diff --git a/reference/providers/amazon/aws/hooks/lambda_function.py b/reference/providers/amazon/aws/hooks/lambda_function.py new file mode 100644 index 0000000..0d5fba0 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/lambda_function.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains AWS Lambda hook""" +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class AwsLambdaHook(AwsBaseHook): + """ + Interact with AWS Lambda + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + + :param function_name: AWS Lambda Function Name + :type function_name: str + :param log_type: Tail Invocation Request + :type log_type: str + :param qualifier: AWS Lambda Function Version or Alias Name + :type qualifier: str + :param invocation_type: AWS Lambda Invocation Type (RequestResponse, Event etc) + :type invocation_type: str + """ + + def __init__( + self, + function_name: str, + log_type: str = "None", + qualifier: str = "$LATEST", + invocation_type: str = "RequestResponse", + *args, + **kwargs, + ) -> None: + self.function_name = function_name + self.log_type = log_type + self.invocation_type = invocation_type + self.qualifier = qualifier + kwargs["client_type"] = "lambda" + super().__init__(*args, **kwargs) + + def invoke_lambda(self, payload: str) -> str: + """Invoke Lambda Function""" + response = self.conn.invoke( + FunctionName=self.function_name, + InvocationType=self.invocation_type, + LogType=self.log_type, + Payload=payload, + Qualifier=self.qualifier, + ) + + return response diff --git a/reference/providers/amazon/aws/hooks/logs.py b/reference/providers/amazon/aws/hooks/logs.py new file mode 100644 index 0000000..4e609f7 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/logs.py @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This module contains a hook (AwsLogsHook) with some very basic +functionality for interacting with AWS CloudWatch. +""" +from typing import Dict, Generator, Optional + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class AwsLogsHook(AwsBaseHook): + """ + Interact with AWS CloudWatch Logs + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = "logs" + super().__init__(*args, **kwargs) + + def get_log_events( + self, + log_group: str, + log_stream_name: str, + start_time: int = 0, + skip: int = 0, + start_from_head: bool = True, + ) -> Generator: + """ + A generator for log items in a single stream. This will yield all the + items that are available at the current moment. + + :param log_group: The name of the log group. + :type log_group: str + :param log_stream_name: The name of the specific stream. + :type log_stream_name: str + :param start_time: The time stamp value to start reading the logs from (default: 0). + :type start_time: int + :param skip: The number of log entries to skip at the start (default: 0). + This is for when there are multiple entries at the same timestamp. + :type skip: int + :param start_from_head: whether to start from the beginning (True) of the log or + at the end of the log (False). + :type start_from_head: bool + :rtype: dict + :return: | A CloudWatch log event with the following key-value pairs: + | 'timestamp' (int): The time in milliseconds of the event. + | 'message' (str): The log event data. + | 'ingestionTime' (int): The time in milliseconds the event was ingested. + """ + next_token = None + + event_count = 1 + while event_count > 0: + if next_token is not None: + token_arg: Optional[Dict[str, str]] = {"nextToken": next_token} + else: + token_arg = {} + + response = self.get_conn().get_log_events( + logGroupName=log_group, + logStreamName=log_stream_name, + startTime=start_time, + startFromHead=start_from_head, + **token_arg, + ) + + events = response["events"] + event_count = len(events) + + if event_count > skip: + events = events[skip:] + skip = 0 + else: + skip -= event_count + events = [] + + yield from events + + if "nextForwardToken" in response: + next_token = response["nextForwardToken"] + else: + return diff --git a/reference/providers/amazon/aws/hooks/redshift.py b/reference/providers/amazon/aws/hooks/redshift.py new file mode 100644 index 0000000..84dea4e --- /dev/null +++ b/reference/providers/amazon/aws/hooks/redshift.py @@ -0,0 +1,131 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Interact with AWS Redshift, using the boto3 library.""" + +from typing import List, Optional + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class RedshiftHook(AwsBaseHook): + """ + Interact with AWS Redshift, using the boto3 library + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = "redshift" + super().__init__(*args, **kwargs) + + # TODO: Wrap create_cluster_snapshot + def cluster_status(self, cluster_identifier: str) -> str: + """ + Return status of a cluster + + :param cluster_identifier: unique identifier of a cluster + :type cluster_identifier: str + """ + try: + response = self.get_conn().describe_clusters( + ClusterIdentifier=cluster_identifier + )["Clusters"] + return response[0]["ClusterStatus"] if response else None + except self.get_conn().exceptions.ClusterNotFoundFault: + return "cluster_not_found" + + def delete_cluster( # pylint: disable=invalid-name + self, + cluster_identifier: str, + skip_final_cluster_snapshot: bool = True, + final_cluster_snapshot_identifier: Optional[str] = None, + ): + """ + Delete a cluster and optionally create a snapshot + + :param cluster_identifier: unique identifier of a cluster + :type cluster_identifier: str + :param skip_final_cluster_snapshot: determines cluster snapshot creation + :type skip_final_cluster_snapshot: bool + :param final_cluster_snapshot_identifier: name of final cluster snapshot + :type final_cluster_snapshot_identifier: str + """ + final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or "" + + response = self.get_conn().delete_cluster( + ClusterIdentifier=cluster_identifier, + SkipFinalClusterSnapshot=skip_final_cluster_snapshot, + FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier, + ) + return response["Cluster"] if response["Cluster"] else None + + def describe_cluster_snapshots( + self, cluster_identifier: str + ) -> Optional[List[str]]: + """ + Gets a list of snapshots for a cluster + + :param cluster_identifier: unique identifier of a cluster + :type cluster_identifier: str + """ + response = self.get_conn().describe_cluster_snapshots( + ClusterIdentifier=cluster_identifier + ) + if "Snapshots" not in response: + return None + snapshots = response["Snapshots"] + snapshots = [snapshot for snapshot in snapshots if snapshot["Status"]] + snapshots.sort(key=lambda x: x["SnapshotCreateTime"], reverse=True) + return snapshots + + def restore_from_cluster_snapshot( + self, cluster_identifier: str, snapshot_identifier: str + ) -> str: + """ + Restores a cluster from its snapshot + + :param cluster_identifier: unique identifier of a cluster + :type cluster_identifier: str + :param snapshot_identifier: unique identifier for a snapshot of a cluster + :type 
snapshot_identifier: str + """ + response = self.get_conn().restore_from_cluster_snapshot( + ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier + ) + return response["Cluster"] if response["Cluster"] else None + + def create_cluster_snapshot( + self, snapshot_identifier: str, cluster_identifier: str + ) -> str: + """ + Creates a snapshot of a cluster + + :param snapshot_identifier: unique identifier for a snapshot of a cluster + :type snapshot_identifier: str + :param cluster_identifier: unique identifier of a cluster + :type cluster_identifier: str + """ + response = self.get_conn().create_cluster_snapshot( + SnapshotIdentifier=snapshot_identifier, + ClusterIdentifier=cluster_identifier, + ) + return response["Snapshot"] if response["Snapshot"] else None diff --git a/reference/providers/amazon/aws/hooks/s3.py b/reference/providers/amazon/aws/hooks/s3.py new file mode 100644 index 0000000..5d88286 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/s3.py @@ -0,0 +1,973 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name +"""Interact with AWS S3, using the boto3 library.""" +import fnmatch +import gzip as gz +import io +import re +import shutil +from functools import wraps +from inspect import signature +from io import BytesIO +from tempfile import NamedTemporaryFile +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast +from urllib.parse import urlparse + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.utils.helpers import chunks +from boto3.s3.transfer import S3Transfer, TransferConfig +from botocore.exceptions import ClientError + +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def provide_bucket_name(func: T) -> T: + """ + Function decorator that provides a bucket name taken from the connection + in case no bucket name has been passed to the function. + """ + function_signature = signature(func) + + @wraps(func) + def wrapper(*args, **kwargs) -> T: + bound_args = function_signature.bind(*args, **kwargs) + + if "bucket_name" not in bound_args.arguments: + self = args[0] + if self.aws_conn_id: + connection = self.get_connection(self.aws_conn_id) + if connection.schema: + bound_args.arguments["bucket_name"] = connection.schema + + return func(*bound_args.args, **bound_args.kwargs) + + return cast(T, wrapper) + + +def unify_bucket_name_and_key(func: T) -> T: + """ + Function decorator that unifies bucket name and key taken from the key + in case no bucket name and at least a key has been passed to the function. 
+ """ + function_signature = signature(func) + + @wraps(func) + def wrapper(*args, **kwargs) -> T: + bound_args = function_signature.bind(*args, **kwargs) + + def get_key_name() -> Optional[str]: + if "wildcard_key" in bound_args.arguments: + return "wildcard_key" + if "key" in bound_args.arguments: + return "key" + raise ValueError("Missing key parameter!") + + key_name = get_key_name() + if key_name and "bucket_name" not in bound_args.arguments: + ( + bound_args.arguments["bucket_name"], + bound_args.arguments[key_name], + ) = S3Hook.parse_s3_url(bound_args.arguments[key_name]) + + return func(*bound_args.args, **bound_args.kwargs) + + return cast(T, wrapper) + + +class S3Hook(AwsBaseHook): + """ + Interact with AWS S3, using the boto3 library. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + conn_type = "s3" + hook_name = "S3" + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = "s3" + + self.extra_args = {} + if "extra_args" in kwargs: + self.extra_args = kwargs["extra_args"] + if not isinstance(self.extra_args, dict): + raise ValueError( + f"extra_args '{self.extra_args!r}' must be of type {dict}" + ) + del kwargs["extra_args"] + + self.transfer_config = TransferConfig() + if "transfer_config_args" in kwargs: + transport_config_args = kwargs["transfer_config_args"] + if not isinstance(transport_config_args, dict): + raise ValueError( + f"transfer_config_args '{transport_config_args!r} must be of type {dict}" + ) + self.transfer_config = TransferConfig(**transport_config_args) + del kwargs["transfer_config_args"] + + super().__init__(*args, **kwargs) + + @staticmethod + def parse_s3_url(s3url: str) -> Tuple[str, str]: + """ + Parses the S3 Url into a bucket name and key. + + :param s3url: The S3 Url to parse. + :rtype s3url: str + :return: the parsed bucket name and key + :rtype: tuple of str + """ + parsed_url = urlparse(s3url) + + if not parsed_url.netloc: + raise AirflowException(f'Please provide a bucket_name instead of "{s3url}"') + + bucket_name = parsed_url.netloc + key = parsed_url.path.strip("/") + + return bucket_name, key + + @provide_bucket_name + def check_for_bucket(self, bucket_name: Optional[str] = None) -> bool: + """ + Check if bucket_name exists. + + :param bucket_name: the name of the bucket + :type bucket_name: str + :return: True if it exists and False if not. + :rtype: bool + """ + try: + self.get_conn().head_bucket(Bucket=bucket_name) + return True + except ClientError as e: + self.log.error(e.response["Error"]["Message"]) + return False + + @provide_bucket_name + def get_bucket(self, bucket_name: Optional[str] = None) -> str: + """ + Returns a boto3.S3.Bucket object + + :param bucket_name: the name of the bucket + :type bucket_name: str + :return: the bucket object to the bucket name. + :rtype: boto3.S3.Bucket + """ + s3_resource = self.get_resource_type("s3") + return s3_resource.Bucket(bucket_name) + + @provide_bucket_name + def create_bucket( + self, bucket_name: Optional[str] = None, region_name: Optional[str] = None + ) -> None: + """ + Creates an Amazon S3 bucket. + + :param bucket_name: The name of the bucket + :type bucket_name: str + :param region_name: The name of the aws region in which to create the bucket. 
+ :type region_name: str + """ + if not region_name: + region_name = self.get_conn().meta.region_name + if region_name == "us-east-1": + self.get_conn().create_bucket(Bucket=bucket_name) + else: + self.get_conn().create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": region_name}, + ) + + @provide_bucket_name + def check_for_prefix( + self, prefix: str, delimiter: str, bucket_name: Optional[str] = None + ) -> bool: + """ + Checks that a prefix exists in a bucket + + :param bucket_name: the name of the bucket + :type bucket_name: str + :param prefix: a key prefix + :type prefix: str + :param delimiter: the delimiter marks key hierarchy. + :type delimiter: str + :return: False if the prefix does not exist in the bucket and True if it does. + :rtype: bool + """ + prefix = prefix + delimiter if prefix[-1] != delimiter else prefix + prefix_split = re.split(fr"(\w+[{delimiter}])$", prefix, 1) + previous_level = prefix_split[0] + plist = self.list_prefixes(bucket_name, previous_level, delimiter) + return prefix in plist + + @provide_bucket_name + def list_prefixes( + self, + bucket_name: Optional[str] = None, + prefix: Optional[str] = None, + delimiter: Optional[str] = None, + page_size: Optional[int] = None, + max_items: Optional[int] = None, + ) -> list: + """ + Lists prefixes in a bucket under prefix + + :param bucket_name: the name of the bucket + :type bucket_name: str + :param prefix: a key prefix + :type prefix: str + :param delimiter: the delimiter marks key hierarchy. + :type delimiter: str + :param page_size: pagination size + :type page_size: int + :param max_items: maximum items to return + :type max_items: int + :return: a list of matched prefixes + :rtype: list + """ + prefix = prefix or "" + delimiter = delimiter or "" + config = { + "PageSize": page_size, + "MaxItems": max_items, + } + + paginator = self.get_conn().get_paginator("list_objects_v2") + response = paginator.paginate( + Bucket=bucket_name, + Prefix=prefix, + Delimiter=delimiter, + PaginationConfig=config, + ) + + prefixes = [] + for page in response: + if "CommonPrefixes" in page: + for common_prefix in page["CommonPrefixes"]: + prefixes.append(common_prefix["Prefix"]) + + return prefixes + + @provide_bucket_name + def list_keys( + self, + bucket_name: Optional[str] = None, + prefix: Optional[str] = None, + delimiter: Optional[str] = None, + page_size: Optional[int] = None, + max_items: Optional[int] = None, + ) -> list: + """ + Lists keys in a bucket under prefix and not containing delimiter + + :param bucket_name: the name of the bucket + :type bucket_name: str + :param prefix: a key prefix + :type prefix: str + :param delimiter: the delimiter marks key hierarchy. 
+ :type delimiter: str + :param page_size: pagination size + :type page_size: int + :param max_items: maximum items to return + :type max_items: int + :return: a list of matched keys + :rtype: list + """ + prefix = prefix or "" + delimiter = delimiter or "" + config = { + "PageSize": page_size, + "MaxItems": max_items, + } + + paginator = self.get_conn().get_paginator("list_objects_v2") + response = paginator.paginate( + Bucket=bucket_name, + Prefix=prefix, + Delimiter=delimiter, + PaginationConfig=config, + ) + + keys = [] + for page in response: + if "Contents" in page: + for k in page["Contents"]: + keys.append(k["Key"]) + + return keys + + @provide_bucket_name + @unify_bucket_name_and_key + def check_for_key(self, key: str, bucket_name: Optional[str] = None) -> bool: + """ + Checks if a key exists in a bucket + + :param key: S3 key that will point to the file + :type key: str + :param bucket_name: Name of the bucket in which the file is stored + :type bucket_name: str + :return: True if the key exists and False if not. + :rtype: bool + """ + try: + self.get_conn().head_object(Bucket=bucket_name, Key=key) + return True + except ClientError as e: + if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: + return False + else: + raise e + + @provide_bucket_name + @unify_bucket_name_and_key + def get_key(self, key: str, bucket_name: Optional[str] = None) -> S3Transfer: + """ + Returns a boto3.s3.Object + + :param key: the path to the key + :type key: str + :param bucket_name: the name of the bucket + :type bucket_name: str + :return: the key object from the bucket + :rtype: boto3.s3.Object + """ + obj = self.get_resource_type("s3").Object(bucket_name, key) + obj.load() + return obj + + @provide_bucket_name + @unify_bucket_name_and_key + def read_key(self, key: str, bucket_name: Optional[str] = None) -> str: + """ + Reads a key from S3 + + :param key: S3 key that will point to the file + :type key: str + :param bucket_name: Name of the bucket in which the file is stored + :type bucket_name: str + :return: the content of the key + :rtype: str + """ + obj = self.get_key(key, bucket_name) + return obj.get()["Body"].read().decode("utf-8") + + @provide_bucket_name + @unify_bucket_name_and_key + def select_key( + self, + key: str, + bucket_name: Optional[str] = None, + expression: Optional[str] = None, + expression_type: Optional[str] = None, + input_serialization: Optional[Dict[str, Any]] = None, + output_serialization: Optional[Dict[str, Any]] = None, + ) -> str: + """ + Reads a key with S3 Select. + + :param key: S3 key that will point to the file + :type key: str + :param bucket_name: Name of the bucket in which the file is stored + :type bucket_name: str + :param expression: S3 Select expression + :type expression: str + :param expression_type: S3 Select expression type + :type expression_type: str + :param input_serialization: S3 Select input data serialization format + :type input_serialization: dict + :param output_serialization: S3 Select output data serialization format + :type output_serialization: dict + :return: retrieved subset of original data by S3 Select + :rtype: str + + .. 
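The key-level helpers above combine naturally; a sketch with placeholder names:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

keys = hook.list_keys(bucket_name="example-bucket", prefix="reports/", page_size=100)

if hook.check_for_key("reports/latest.json", bucket_name="example-bucket"):
    body = hook.read_key("reports/latest.json", bucket_name="example-bucket")

# The unify_bucket_name_and_key decorator also accepts a full S3 URL as the key.
obj = hook.get_key("s3://example-bucket/reports/latest.json")
print(obj.content_length)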
seealso:: + For more details about S3 Select parameters: + http://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.select_object_content + """ + expression = expression or "SELECT * FROM S3Object" + expression_type = expression_type or "SQL" + + if input_serialization is None: + input_serialization = {"CSV": {}} + if output_serialization is None: + output_serialization = {"CSV": {}} + + response = self.get_conn().select_object_content( + Bucket=bucket_name, + Key=key, + Expression=expression, + ExpressionType=expression_type, + InputSerialization=input_serialization, + OutputSerialization=output_serialization, + ) + + return "".join( + event["Records"]["Payload"].decode("utf-8") + for event in response["Payload"] + if "Records" in event + ) + + @provide_bucket_name + @unify_bucket_name_and_key + def check_for_wildcard_key( + self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = "" + ) -> bool: + """ + Checks that a key matching a wildcard expression exists in a bucket + + :param wildcard_key: the path to the key + :type wildcard_key: str + :param bucket_name: the name of the bucket + :type bucket_name: str + :param delimiter: the delimiter marks key hierarchy + :type delimiter: str + :return: True if a key exists and False if not. + :rtype: bool + """ + return ( + self.get_wildcard_key( + wildcard_key=wildcard_key, bucket_name=bucket_name, delimiter=delimiter + ) + is not None + ) + + @provide_bucket_name + @unify_bucket_name_and_key + def get_wildcard_key( + self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = "" + ) -> S3Transfer: + """ + Returns a boto3.s3.Object object matching the wildcard expression + + :param wildcard_key: the path to the key + :type wildcard_key: str + :param bucket_name: the name of the bucket + :type bucket_name: str + :param delimiter: the delimiter marks key hierarchy + :type delimiter: str + :return: the key object from the bucket or None if none has been found. + :rtype: boto3.s3.Object + """ + prefix = re.split(r"[*]", wildcard_key, 1)[0] + key_list = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter) + key_matches = [k for k in key_list if fnmatch.fnmatch(k, wildcard_key)] + if key_matches: + return self.get_key(key_matches[0], bucket_name) + return None + + @provide_bucket_name + @unify_bucket_name_and_key + def load_file( + self, + filename: str, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + gzip: bool = False, + acl_policy: Optional[str] = None, + ) -> None: + """ + Loads a local file to S3 + + :param filename: name of the file to load. + :type filename: str + :param key: S3 key that will point to the file + :type key: str + :param bucket_name: Name of the bucket in which to store the file + :type bucket_name: str + :param replace: A flag to decide whether or not to overwrite the key + if it already exists. If replace is False and the key exists, an + error will be raised. + :type replace: bool + :param encrypt: If True, the file will be encrypted on the server-side + by S3 and will be stored in an encrypted form while at rest in S3. + :type encrypt: bool + :param gzip: If True, the file will be compressed locally + :type gzip: bool + :param acl_policy: String specifying the canned ACL policy for the file being + uploaded to the S3 bucket. 
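select_key and the wildcard helpers are sketched below; the bucket, keys, and the S3 Select expression are illustrative only:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# Pull two columns out of a headerless CSV object without downloading it.
subset = hook.select_key(
    key="s3://example-bucket/raw/events.csv",
    expression="SELECT s._1, s._3 FROM S3Object s",
    input_serialization={"CSV": {"FileHeaderInfo": "NONE"}},
    output_serialization={"CSV": {}},
)

# get_wildcard_key lists keys under the prefix before the first '*' and fnmatch-filters them.
latest = hook.get_wildcard_key("raw/2022-*/events.csv", bucket_name="example-bucket")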
+ :type acl_policy: str + """ + if not replace and self.check_for_key(key, bucket_name): + raise ValueError(f"The key {key} already exists.") + + extra_args = self.extra_args + if encrypt: + extra_args["ServerSideEncryption"] = "AES256" + if gzip: + with open(filename, "rb") as f_in: + filename_gz = f_in.name + ".gz" + with gz.open(filename_gz, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + filename = filename_gz + if acl_policy: + extra_args["ACL"] = acl_policy + + client = self.get_conn() + client.upload_file( + filename, + bucket_name, + key, + ExtraArgs=extra_args, + Config=self.transfer_config, + ) + + @provide_bucket_name + @unify_bucket_name_and_key + def load_string( + self, + string_data: str, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + encoding: Optional[str] = None, + acl_policy: Optional[str] = None, + compression: Optional[str] = None, + ) -> None: + """ + Loads a string to S3 + + This is provided as a convenience to drop a string in S3. It uses the + boto infrastructure to ship a file to s3. + + :param string_data: str to set as content for the key. + :type string_data: str + :param key: S3 key that will point to the file + :type key: str + :param bucket_name: Name of the bucket in which to store the file + :type bucket_name: str + :param replace: A flag to decide whether or not to overwrite the key + if it already exists + :type replace: bool + :param encrypt: If True, the file will be encrypted on the server-side + by S3 and will be stored in an encrypted form while at rest in S3. + :type encrypt: bool + :param encoding: The string to byte encoding + :type encoding: str + :param acl_policy: The string to specify the canned ACL policy for the + object to be uploaded + :type acl_policy: str + :param compression: Type of compression to use, currently only gzip is supported. + :type compression: str + """ + encoding = encoding or "utf-8" + + bytes_data = string_data.encode(encoding) + + # Compress string + available_compressions = ["gzip"] + if compression is not None and compression not in available_compressions: + raise NotImplementedError( + "Received {} compression type. String " + "can currently be compressed in {} " + "only.".format(compression, available_compressions) + ) + if compression == "gzip": + bytes_data = gz.compress(bytes_data) + + file_obj = io.BytesIO(bytes_data) + + self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy) + file_obj.close() + + @provide_bucket_name + @unify_bucket_name_and_key + def load_bytes( + self, + bytes_data: bytes, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + acl_policy: Optional[str] = None, + ) -> None: + """ + Loads bytes to S3 + + This is provided as a convenience to drop a string in S3. It uses the + boto infrastructure to ship a file to s3. + + :param bytes_data: bytes to set as content for the key. + :type bytes_data: bytes + :param key: S3 key that will point to the file + :type key: str + :param bucket_name: Name of the bucket in which to store the file + :type bucket_name: str + :param replace: A flag to decide whether or not to overwrite the key + if it already exists + :type replace: bool + :param encrypt: If True, the file will be encrypted on the server-side + by S3 and will be stored in an encrypted form while at rest in S3. 
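A sketch of the upload helpers above; the local path, keys, and bucket are placeholders:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

hook.load_file(
    filename="/tmp/report.csv",
    key="reports/report.csv",
    bucket_name="example-bucket",
    replace=True,
    gzip=True,      # compress locally before uploading
    encrypt=True,   # adds ServerSideEncryption=AES256
)

hook.load_string(
    '{"status": "ok"}',
    key="reports/status.json",
    bucket_name="example-bucket",
    replace=True,
    compression="gzip",
)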
+ :type encrypt: bool + :param acl_policy: The string to specify the canned ACL policy for the + object to be uploaded + :type acl_policy: str + """ + file_obj = io.BytesIO(bytes_data) + self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy) + file_obj.close() + + @provide_bucket_name + @unify_bucket_name_and_key + def load_file_obj( + self, + file_obj: BytesIO, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + acl_policy: Optional[str] = None, + ) -> None: + """ + Loads a file object to S3 + + :param file_obj: The file-like object to set as the content for the S3 key. + :type file_obj: file-like object + :param key: S3 key that will point to the file + :type key: str + :param bucket_name: Name of the bucket in which to store the file + :type bucket_name: str + :param replace: A flag that indicates whether to overwrite the key + if it already exists. + :type replace: bool + :param encrypt: If True, S3 encrypts the file on the server, + and the file is stored in encrypted form at rest in S3. + :type encrypt: bool + :param acl_policy: The string to specify the canned ACL policy for the + object to be uploaded + :type acl_policy: str + """ + self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy) + + def _upload_file_obj( + self, + file_obj: BytesIO, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + acl_policy: Optional[str] = None, + ) -> None: + if not replace and self.check_for_key(key, bucket_name): + raise ValueError(f"The key {key} already exists.") + + extra_args = self.extra_args + if encrypt: + extra_args["ServerSideEncryption"] = "AES256" + if acl_policy: + extra_args["ACL"] = acl_policy + + client = self.get_conn() + client.upload_fileobj( + file_obj, + bucket_name, + key, + ExtraArgs=extra_args, + Config=self.transfer_config, + ) + + def copy_object( + self, + source_bucket_key: str, + dest_bucket_key: str, + source_bucket_name: Optional[str] = None, + dest_bucket_name: Optional[str] = None, + source_version_id: Optional[str] = None, + acl_policy: Optional[str] = None, + ) -> None: + """ + Creates a copy of an object that is already stored in S3. + + Note: the S3 connection used here needs to have access to both + source and destination bucket/key. + + :param source_bucket_key: The key of the source object. + + It can be either full s3:// style url or relative path from root level. + + When it's specified as a full s3:// url, please omit source_bucket_name. + :type source_bucket_key: str + :param dest_bucket_key: The key of the object to copy to. + + The convention to specify `dest_bucket_key` is the same + as `source_bucket_key`. + :type dest_bucket_key: str + :param source_bucket_name: Name of the S3 bucket where the source object is in. + + It should be omitted when `source_bucket_key` is provided as a full s3:// url. + :type source_bucket_name: str + :param dest_bucket_name: Name of the S3 bucket to where the object is copied. + + It should be omitted when `dest_bucket_key` is provided as a full s3:// url. + :type dest_bucket_name: str + :param source_version_id: Version ID of the source object (OPTIONAL) + :type source_version_id: str + :param acl_policy: The string to specify the canned ACL policy for the + object to be copied which is private by default. 
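copy_object accepts either full s3:// URLs or relative keys plus explicit bucket names, as the docstring above describes. A sketch with placeholder buckets:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# Full URLs: bucket names are parsed out of the keys.
hook.copy_object(
    source_bucket_key="s3://example-src-bucket/raw/file.csv",
    dest_bucket_key="s3://example-dest-bucket/archive/file.csv",
)

# Equivalent call with relative keys and explicit bucket names.
hook.copy_object(
    source_bucket_key="raw/file.csv",
    dest_bucket_key="archive/file.csv",
    source_bucket_name="example-src-bucket",
    dest_bucket_name="example-dest-bucket",
)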
+ :type acl_policy: str + """ + acl_policy = acl_policy or "private" + + if dest_bucket_name is None: + dest_bucket_name, dest_bucket_key = self.parse_s3_url(dest_bucket_key) + else: + parsed_url = urlparse(dest_bucket_key) + if parsed_url.scheme != "" or parsed_url.netloc != "": + raise AirflowException( + "If dest_bucket_name is provided, " + + "dest_bucket_key should be relative path " + + "from root level, rather than a full s3:// url" + ) + + if source_bucket_name is None: + source_bucket_name, source_bucket_key = self.parse_s3_url(source_bucket_key) + else: + parsed_url = urlparse(source_bucket_key) + if parsed_url.scheme != "" or parsed_url.netloc != "": + raise AirflowException( + "If source_bucket_name is provided, " + + "source_bucket_key should be relative path " + + "from root level, rather than a full s3:// url" + ) + + copy_source = { + "Bucket": source_bucket_name, + "Key": source_bucket_key, + "VersionId": source_version_id, + } + response = self.get_conn().copy_object( + Bucket=dest_bucket_name, + Key=dest_bucket_key, + CopySource=copy_source, + ACL=acl_policy, + ) + return response + + @provide_bucket_name + def delete_bucket(self, bucket_name: str, force_delete: bool = False) -> None: + """ + To delete s3 bucket, delete all s3 bucket objects and then delete the bucket. + + :param bucket_name: Bucket name + :type bucket_name: str + :param force_delete: Enable this to delete bucket even if not empty + :type force_delete: bool + :return: None + :rtype: None + """ + if force_delete: + bucket_keys = self.list_keys(bucket_name=bucket_name) + if bucket_keys: + self.delete_objects(bucket=bucket_name, keys=bucket_keys) + self.conn.delete_bucket(Bucket=bucket_name) + + def delete_objects(self, bucket: str, keys: Union[str, list]) -> None: + """ + Delete keys from the bucket. + + :param bucket: Name of the bucket in which you are going to delete object(s) + :type bucket: str + :param keys: The key(s) to delete from S3 bucket. + + When ``keys`` is a string, it's supposed to be the key name of + the single object to delete. + + When ``keys`` is a list, it's supposed to be the list of the + keys to delete. + :type keys: str or list + """ + if isinstance(keys, str): + keys = [keys] + + s3 = self.get_conn() + + # We can only send a maximum of 1000 keys per request. + # For details see: + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects + for chunk in chunks(keys, chunk_size=1000): + response = s3.delete_objects( + Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]} + ) + deleted_keys = [x["Key"] for x in response.get("Deleted", [])] + self.log.info("Deleted: %s", deleted_keys) + if "Errors" in response: + errors_keys = [x["Key"] for x in response.get("Errors", [])] + raise AirflowException(f"Errors when deleting: {errors_keys}") + + @provide_bucket_name + @unify_bucket_name_and_key + def download_file( + self, + key: str, + bucket_name: Optional[str] = None, + local_path: Optional[str] = None, + ) -> str: + """ + Downloads a file from the S3 location to the local file system. + + :param key: The key path in S3. + :type key: str + :param bucket_name: The specific bucket to use. + :type bucket_name: Optional[str] + :param local_path: The local path to the downloaded file. If no path is provided it will use the + system's temporary directory. + :type local_path: Optional[str] + :return: the file name. 
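The deletion helpers, sketched with placeholder names; delete_objects batches requests at 1000 keys, and delete_bucket(force_delete=True) empties the bucket first:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

stale_keys = hook.list_keys(bucket_name="example-bucket", prefix="tmp/")
if stale_keys:
    hook.delete_objects(bucket="example-bucket", keys=stale_keys)

hook.delete_bucket(bucket_name="example-scratch-bucket", force_delete=True)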
+ :rtype: str + """ + self.log.info( + "Downloading source S3 file from Bucket %s with path %s", bucket_name, key + ) + + if not self.check_for_key(key, bucket_name): + raise AirflowException( + f"The source file in Bucket {bucket_name} with path {key} does not exist" + ) + + s3_obj = self.get_key(key, bucket_name) + + with NamedTemporaryFile( + dir=local_path, prefix="airflow_tmp_", delete=False + ) as local_tmp_file: + s3_obj.download_fileobj(local_tmp_file) + + return local_tmp_file.name + + def generate_presigned_url( + self, + client_method: str, + params: Optional[dict] = None, + expires_in: int = 3600, + http_method: Optional[str] = None, + ) -> Optional[str]: + """ + Generate a presigned url given a client, its method, and arguments + + :param client_method: The client method to presign for. + :type client_method: str + :param params: The parameters normally passed to ClientMethod. + :type params: dict + :param expires_in: The number of seconds the presigned url is valid for. + By default it expires in an hour (3600 seconds). + :type expires_in: int + :param http_method: The http method to use on the generated url. + By default, the http method is whatever is used in the method's model. + :type http_method: str + :return: The presigned url. + :rtype: str + """ + s3_client = self.get_conn() + try: + return s3_client.generate_presigned_url( + ClientMethod=client_method, + Params=params, + ExpiresIn=expires_in, + HttpMethod=http_method, + ) + + except ClientError as e: + self.log.error(e.response["Error"]["Message"]) + return None + + @provide_bucket_name + def get_bucket_tagging( + self, bucket_name: Optional[str] = None + ) -> Optional[List[Dict[str, str]]]: + """ + Gets a List of tags from a bucket. + + :param bucket_name: The name of the bucket. + :type bucket_name: str + :return: A List containing the key/value pairs for the tags + :rtype: Optional[List[Dict[str, str]]] + """ + try: + s3_client = self.get_conn() + result = s3_client.get_bucket_tagging(Bucket=bucket_name)["TagSet"] + self.log.info("S3 Bucket Tag Info: %s", result) + return result + except ClientError as e: + self.log.error(e) + raise e + + @provide_bucket_name + def put_bucket_tagging( + self, + tag_set: Optional[List[Dict[str, str]]] = None, + key: Optional[str] = None, + value: Optional[str] = None, + bucket_name: Optional[str] = None, + ) -> None: + """ + Overwrites the existing TagSet with provided tags. Must provide either a TagSet or a key/value pair. + + :param tag_set: A List containing the key/value pairs for the tags. + :type tag_set: List[Dict[str, str]] + :param key: The Key for the new TagSet entry. + :type key: str + :param value: The Value for the new TagSet entry. + :type value: str + :param bucket_name: The name of the bucket. + :type bucket_name: str + :return: None + :rtype: None + """ + self.log.info( + "S3 Bucket Tag Info:\tKey: %s\tValue: %s\tSet: %s", key, value, tag_set + ) + if not tag_set: + tag_set = [] + if key and value: + tag_set.append({"Key": key, "Value": value}) + elif not tag_set or (key or value): + message = "put_bucket_tagging() requires either a predefined TagSet or a key/value pair." + self.log.error(message) + raise ValueError(message) + + try: + s3_client = self.get_conn() + s3_client.put_bucket_tagging( + Bucket=bucket_name, Tagging={"TagSet": tag_set} + ) + except ClientError as e: + self.log.error(e) + raise e + + @provide_bucket_name + def delete_bucket_tagging(self, bucket_name: Optional[str] = None) -> None: + """ + Deletes all tags from a bucket. 
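download_file and generate_presigned_url from above, sketched with placeholder values:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# Returns the path of a NamedTemporaryFile (prefix "airflow_tmp_") holding the object.
local_path = hook.download_file(key="reports/report.csv", bucket_name="example-bucket")

# Ten-minute download link for the same object.
url = hook.generate_presigned_url(
    client_method="get_object",
    params={"Bucket": "example-bucket", "Key": "reports/report.csv"},
    expires_in=600,
)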
+ + :param bucket_name: The name of the bucket. + :type bucket_name: str + :return: None + :rtype: None + """ + s3_client = self.get_conn() + s3_client.delete_bucket_tagging(Bucket=bucket_name) diff --git a/reference/providers/amazon/aws/hooks/sagemaker.py b/reference/providers/amazon/aws/hooks/sagemaker.py new file mode 100644 index 0000000..2755b79 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/sagemaker.py @@ -0,0 +1,1039 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import collections +import os +import tarfile +import tempfile +import time +import warnings +from functools import partial +from typing import Any, Callable, Dict, Generator, List, Optional, Set + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.utils import timezone +from botocore.exceptions import ClientError + + +class LogState: + """ + Enum-style class holding all possible states of CloudWatch log streams. + https://sagemaker.readthedocs.io/en/stable/session.html#sagemaker.session.LogState + """ + + STARTING = 1 + WAIT_IN_PROGRESS = 2 + TAILING = 3 + JOB_COMPLETE = 4 + COMPLETE = 5 + + +# Position is a tuple that includes the last read timestamp and the number of items that were read +# at that time. This is used to figure out which event to start with on the next read. +Position = collections.namedtuple("Position", ["timestamp", "skip"]) + + +def argmin(arr, f: Callable) -> Optional[int]: + """Return the index, i, in arr that minimizes f(arr[i])""" + min_value = None + min_idx = None + for idx, item in enumerate(arr): + if item is not None: + if min_value is None or f(item) < min_value: + min_value = f(item) + min_idx = idx + return min_idx + + +def secondary_training_status_changed( + current_job_description: dict, prev_job_description: dict +) -> bool: + """ + Returns true if training job's secondary status message has changed. + + :param current_job_description: Current job description, returned from DescribeTrainingJob call. + :type current_job_description: dict + :param prev_job_description: Previous job description, returned from DescribeTrainingJob call. + :type prev_job_description: dict + + :return: Whether the secondary status message of a training job changed or not. 
+ """ + current_secondary_status_transitions = current_job_description.get( + "SecondaryStatusTransitions" + ) + if ( + current_secondary_status_transitions is None + or len(current_secondary_status_transitions) == 0 + ): + return False + + prev_job_secondary_status_transitions = ( + prev_job_description.get("SecondaryStatusTransitions") + if prev_job_description is not None + else None + ) + + last_message = ( + prev_job_secondary_status_transitions[-1]["StatusMessage"] + if prev_job_secondary_status_transitions is not None + and len(prev_job_secondary_status_transitions) > 0 + else "" + ) + + message = current_job_description["SecondaryStatusTransitions"][-1]["StatusMessage"] + + return message != last_message + + +def secondary_training_status_message( + job_description: Dict[str, List[dict]], prev_description: Optional[dict] +) -> str: + """ + Returns a string contains start time and the secondary training job status message. + + :param job_description: Returned response from DescribeTrainingJob call + :type job_description: dict + :param prev_description: Previous job description from DescribeTrainingJob call + :type prev_description: dict + + :return: Job status string to be printed. + """ + current_transitions = job_description.get("SecondaryStatusTransitions") + if current_transitions is None or len(current_transitions) == 0: + return "" + + prev_transitions_num = 0 + if prev_description is not None: + if prev_description.get("SecondaryStatusTransitions") is not None: + prev_transitions_num = len(prev_description["SecondaryStatusTransitions"]) + + transitions_to_print = ( + current_transitions[-1:] + if len(current_transitions) == prev_transitions_num + else current_transitions[prev_transitions_num - len(current_transitions) :] + ) + + status_strs = [] + for transition in transitions_to_print: + message = transition["StatusMessage"] + time_str = timezone.convert_to_utc( + job_description["LastModifiedTime"] + ).strftime("%Y-%m-%d %H:%M:%S") + status_strs.append(f"{time_str} {transition['Status']} - {message}") + + return "\n".join(status_strs) + + +class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods + """ + Interact with Amazon SageMaker. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. 
seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + non_terminal_states = {"InProgress", "Stopping"} + endpoint_non_terminal_states = { + "Creating", + "Updating", + "SystemUpdating", + "RollingBack", + "Deleting", + } + failed_states = {"Failed"} + + def __init__(self, *args, **kwargs): + super().__init__(client_type="sagemaker", *args, **kwargs) + self.s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id) + + def tar_and_s3_upload(self, path: str, key: str, bucket: str) -> None: + """ + Tar the local file or directory and upload to s3 + + :param path: local file or directory + :type path: str + :param key: s3 key + :type key: str + :param bucket: s3 bucket + :type bucket: str + :return: None + """ + with tempfile.TemporaryFile() as temp_file: + if os.path.isdir(path): + files = [os.path.join(path, name) for name in os.listdir(path)] + else: + files = [path] + with tarfile.open(mode="w:gz", fileobj=temp_file) as tar_file: + for f in files: + tar_file.add(f, arcname=os.path.basename(f)) + temp_file.seek(0) + self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True) + + def configure_s3_resources(self, config: dict) -> None: + """ + Extract the S3 operations from the configuration and execute them. + + :param config: config of SageMaker operation + :type config: dict + :rtype: dict + """ + s3_operations = config.pop("S3Operations", None) + + if s3_operations is not None: + create_bucket_ops = s3_operations.get("S3CreateBucket", []) + upload_ops = s3_operations.get("S3Upload", []) + for op in create_bucket_ops: + self.s3_hook.create_bucket(bucket_name=op["Bucket"]) + for op in upload_ops: + if op["Tar"]: + self.tar_and_s3_upload(op["Path"], op["Key"], op["Bucket"]) + else: + self.s3_hook.load_file(op["Path"], op["Key"], op["Bucket"]) + + def check_s3_url(self, s3url: str) -> bool: + """ + Check if an S3 URL exists + + :param s3url: S3 url + :type s3url: str + :rtype: bool + """ + bucket, key = S3Hook.parse_s3_url(s3url) + if not self.s3_hook.check_for_bucket(bucket_name=bucket): + raise AirflowException(f"The input S3 Bucket {bucket} does not exist ") + if ( + key + and not self.s3_hook.check_for_key(key=key, bucket_name=bucket) + and not self.s3_hook.check_for_prefix( + prefix=key, bucket_name=bucket, delimiter="/" + ) + ): + # check if s3 key exists in the case user provides a single file + # or if s3 prefix exists in the case user provides multiple files in + # a prefix + raise AirflowException( + f"The input S3 Key or Prefix {s3url} does not exist in the Bucket {bucket}" + ) + return True + + def check_training_config(self, training_config: dict) -> None: + """ + Check if a training configuration is valid + + :param training_config: training_config + :type training_config: dict + :return: None + """ + if "InputDataConfig" in training_config: + for channel in training_config["InputDataConfig"]: + if "S3DataSource" in channel["DataSource"]: + self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"]) + + def check_tuning_config(self, tuning_config: dict) -> None: + """ + Check if a tuning configuration is valid + + :param tuning_config: tuning_config + :type tuning_config: dict + :return: None + """ + for channel in tuning_config["TrainingJobDefinition"]["InputDataConfig"]: + if "S3DataSource" in channel["DataSource"]: + self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"]) + + def get_log_conn(self): + """ + This method is deprecated. 
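The S3 plumbing in SageMakerHook is sketched below; the bucket, paths, and the S3Operations block are made up for illustration:

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")

# Raises AirflowException if the bucket or the key/prefix is missing.
hook.check_s3_url("s3://example-bucket/training/input/")

# S3Operations block in the shape consumed by configure_s3_resources().
config = {
    "S3Operations": {
        "S3CreateBucket": [{"Bucket": "example-sagemaker-bucket"}],
        "S3Upload": [
            {
                "Path": "/tmp/model_src",
                "Key": "code/model_src.tar.gz",
                "Bucket": "example-sagemaker-bucket",
                "Tar": True,
            }
        ],
    }
}
hook.configure_s3_resources(config)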
+ Please use :py:meth:`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead. + """ + warnings.warn( + "Method `get_log_conn` has been deprecated. " + "Please use `airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + return self.logs_hook.get_conn() + + def log_stream(self, log_group, stream_name, start_time=0, skip=0): + """ + This method is deprecated. + Please use + :py:meth:`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead. + """ + warnings.warn( + "Method `log_stream` has been deprecated. " + "Please use " + "`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip) + + def multi_stream_iter( + self, log_group: str, streams: list, positions=None + ) -> Generator: + """ + Iterate over the available events coming from a set of log streams in a single log group + interleaving the events from each stream so they're yielded in timestamp order. + + :param log_group: The name of the log group. + :type log_group: str + :param streams: A list of the log stream names. The position of the stream in this list is + the stream number. + :type streams: list + :param positions: A list of pairs of (timestamp, skip) which represents the last record + read from each stream. + :type positions: list + :return: A tuple of (stream number, cloudwatch log event). + """ + positions = positions or {s: Position(timestamp=0, skip=0) for s in streams} + event_iters = [ + self.logs_hook.get_log_events( + log_group, s, positions[s].timestamp, positions[s].skip + ) + for s in streams + ] + events: List[Optional[Any]] = [] + for event_stream in event_iters: + if not event_stream: + events.append(None) + continue + try: + events.append(next(event_stream)) + except StopIteration: + events.append(None) + + while any(events): + i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0 + yield i, events[i] + try: + events[i] = next(event_iters[i]) + except StopIteration: + events[i] = None + + def create_training_job( + self, + config: dict, + wait_for_completion: bool = True, + print_log: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + ): + """ + Create a training job + + :param config: the config for training + :type config: dict + :param wait_for_completion: if the program should keep running until job finishes + :type wait_for_completion: bool + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. 
+ :type max_ingestion_time: int + :return: A response to training job creation + """ + self.check_training_config(config) + + response = self.get_conn().create_training_job(**config) + if print_log: + self.check_training_status_with_log( + config["TrainingJobName"], + self.non_terminal_states, + self.failed_states, + wait_for_completion, + check_interval, + max_ingestion_time, + ) + elif wait_for_completion: + describe_response = self.check_status( + config["TrainingJobName"], + "TrainingJobStatus", + self.describe_training_job, + check_interval, + max_ingestion_time, + ) + + billable_time = ( + describe_response["TrainingEndTime"] + - describe_response["TrainingStartTime"] + ) * describe_response["ResourceConfig"]["InstanceCount"] + self.log.info( + "Billable seconds: %d", int(billable_time.total_seconds()) + 1 + ) + + return response + + def create_tuning_job( + self, + config: dict, + wait_for_completion: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + ): + """ + Create a tuning job + + :param config: the config for tuning + :type config: dict + :param wait_for_completion: if the program should keep running until job finishes + :type wait_for_completion: bool + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. + :type max_ingestion_time: int + :return: A response to tuning job creation + """ + self.check_tuning_config(config) + + response = self.get_conn().create_hyper_parameter_tuning_job(**config) + if wait_for_completion: + self.check_status( + config["HyperParameterTuningJobName"], + "HyperParameterTuningJobStatus", + self.describe_tuning_job, + check_interval, + max_ingestion_time, + ) + return response + + def create_transform_job( + self, + config: dict, + wait_for_completion: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + ): + """ + Create a transform job + + :param config: the config for transform job + :type config: dict + :param wait_for_completion: if the program should keep running until job finishes + :type wait_for_completion: bool + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. 
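create_training_job takes the raw CreateTrainingJob request dict. A heavily truncated sketch: the job name, image URI, role ARN, and S3 paths are placeholders, and a real config may need additional fields:

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")

training_config = {
    "TrainingJobName": "example-training-job",
    "AlgorithmSpecification": {
        "TrainingImage": "123456789012.dkr.ecr.us-east-1.amazonaws.com/example:latest",
        "TrainingInputMode": "File",
    },
    "RoleArn": "arn:aws:iam::123456789012:role/example-sagemaker-role",
    "OutputDataConfig": {"S3OutputPath": "s3://example-bucket/output/"},
    "ResourceConfig": {"InstanceType": "ml.m5.xlarge", "InstanceCount": 1, "VolumeSizeInGB": 30},
    "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
}

response = hook.create_training_job(
    training_config, wait_for_completion=True, print_log=True, check_interval=60
)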
+ :type max_ingestion_time: int + :return: A response to transform job creation + """ + if "S3DataSource" in config["TransformInput"]["DataSource"]: + self.check_s3_url( + config["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] + ) + + response = self.get_conn().create_transform_job(**config) + if wait_for_completion: + self.check_status( + config["TransformJobName"], + "TransformJobStatus", + self.describe_transform_job, + check_interval, + max_ingestion_time, + ) + return response + + def create_processing_job( + self, + config: dict, + wait_for_completion: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + ): + """ + Create a processing job + + :param config: the config for processing job + :type config: dict + :param wait_for_completion: if the program should keep running until job finishes + :type wait_for_completion: bool + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. + :type max_ingestion_time: int + :return: A response to transform job creation + """ + response = self.get_conn().create_processing_job(**config) + if wait_for_completion: + self.check_status( + config["ProcessingJobName"], + "ProcessingJobStatus", + self.describe_processing_job, + check_interval, + max_ingestion_time, + ) + return response + + def create_model(self, config: dict): + """ + Create a model job + + :param config: the config for model + :type config: dict + :return: A response to model creation + """ + return self.get_conn().create_model(**config) + + def create_endpoint_config(self, config: dict): + """ + Create an endpoint config + + :param config: the config for endpoint-config + :type config: dict + :return: A response to endpoint config creation + """ + return self.get_conn().create_endpoint_config(**config) + + def create_endpoint( + self, + config: dict, + wait_for_completion: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + ): + """ + Create an endpoint + + :param config: the config for endpoint + :type config: dict + :param wait_for_completion: if the program should keep running until job finishes + :type wait_for_completion: bool + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. 
+ :type max_ingestion_time: int + :return: A response to endpoint creation + """ + response = self.get_conn().create_endpoint(**config) + if wait_for_completion: + self.check_status( + config["EndpointName"], + "EndpointStatus", + self.describe_endpoint, + check_interval, + max_ingestion_time, + non_terminal_states=self.endpoint_non_terminal_states, + ) + return response + + def update_endpoint( + self, + config: dict, + wait_for_completion: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + ): + """ + Update an endpoint + + :param config: the config for endpoint + :type config: dict + :param wait_for_completion: if the program should keep running until job finishes + :type wait_for_completion: bool + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. + :type max_ingestion_time: int + :return: A response to endpoint update + """ + response = self.get_conn().update_endpoint(**config) + if wait_for_completion: + self.check_status( + config["EndpointName"], + "EndpointStatus", + self.describe_endpoint, + check_interval, + max_ingestion_time, + non_terminal_states=self.endpoint_non_terminal_states, + ) + return response + + def describe_training_job(self, name: str): + """ + Return the training job info associated with the name + + :param name: the name of the training job + :type name: str + :return: A dict contains all the training job info + """ + return self.get_conn().describe_training_job(TrainingJobName=name) + + def describe_training_job_with_log( + self, + job_name: str, + positions, + stream_names: list, + instance_count: int, + state: int, + last_description: dict, + last_describe_job_call: float, + ): + """Return the training job info associated with job_name and print CloudWatch logs""" + log_group = "/aws/sagemaker/TrainingJobs" + + if len(stream_names) < instance_count: + # Log streams are created whenever a container starts writing to stdout/err, so this list + # may be dynamic until we have a stream for every instance. 
+ logs_conn = self.logs_hook.get_conn() + try: + streams = logs_conn.describe_log_streams( + logGroupName=log_group, + logStreamNamePrefix=job_name + "/", + orderBy="LogStreamName", + limit=instance_count, + ) + stream_names = [s["logStreamName"] for s in streams["logStreams"]] + positions.update( + [ + (s, Position(timestamp=0, skip=0)) + for s in stream_names + if s not in positions + ] + ) + except logs_conn.exceptions.ResourceNotFoundException: + # On the very first training job run on an account, there's no log group until + # the container starts logging, so ignore any errors thrown about that + pass + + if len(stream_names) > 0: + for idx, event in self.multi_stream_iter( + log_group, stream_names, positions + ): + self.log.info(event["message"]) + ts, count = positions[stream_names[idx]] + if event["timestamp"] == ts: + positions[stream_names[idx]] = Position( + timestamp=ts, skip=count + 1 + ) + else: + positions[stream_names[idx]] = Position( + timestamp=event["timestamp"], skip=1 + ) + + if state == LogState.COMPLETE: + return state, last_description, last_describe_job_call + + if state == LogState.JOB_COMPLETE: + state = LogState.COMPLETE + elif time.monotonic() - last_describe_job_call >= 30: + description = self.describe_training_job(job_name) + last_describe_job_call = time.monotonic() + + if secondary_training_status_changed(description, last_description): + self.log.info( + secondary_training_status_message(description, last_description) + ) + last_description = description + + status = description["TrainingJobStatus"] + + if status not in self.non_terminal_states: + state = LogState.JOB_COMPLETE + return state, last_description, last_describe_job_call + + def describe_tuning_job(self, name: str) -> dict: + """ + Return the tuning job info associated with the name + + :param name: the name of the tuning job + :type name: str + :return: A dict contains all the tuning job info + """ + return self.get_conn().describe_hyper_parameter_tuning_job( + HyperParameterTuningJobName=name + ) + + def describe_model(self, name: str) -> dict: + """ + Return the SageMaker model info associated with the name + + :param name: the name of the SageMaker model + :type name: str + :return: A dict contains all the model info + """ + return self.get_conn().describe_model(ModelName=name) + + def describe_transform_job(self, name: str) -> dict: + """ + Return the transform job info associated with the name + + :param name: the name of the transform job + :type name: str + :return: A dict contains all the transform job info + """ + return self.get_conn().describe_transform_job(TransformJobName=name) + + def describe_processing_job(self, name: str) -> dict: + """ + Return the processing job info associated with the name + + :param name: the name of the processing job + :type name: str + :return: A dict contains all the processing job info + """ + return self.get_conn().describe_processing_job(ProcessingJobName=name) + + def describe_endpoint_config(self, name: str) -> dict: + """ + Return the endpoint config info associated with the name + + :param name: the name of the endpoint config + :type name: str + :return: A dict contains all the endpoint config info + """ + return self.get_conn().describe_endpoint_config(EndpointConfigName=name) + + def describe_endpoint(self, name: str) -> dict: + """ + :param name: the name of the endpoint + :type name: str + :return: A dict contains all the endpoint info + """ + return self.get_conn().describe_endpoint(EndpointName=name) + + def check_status( + self, + 
job_name: str, + key: str, + describe_function: Callable, + check_interval: int, + max_ingestion_time: Optional[int] = None, + non_terminal_states: Optional[Set] = None, + ): + """ + Check status of a SageMaker job + + :param job_name: name of the job to check status + :type job_name: str + :param key: the key of the response dict + that points to the state + :type key: str + :param describe_function: the function used to retrieve the status + :type describe_function: python callable + :param args: the arguments for the function + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. + :type max_ingestion_time: int + :param non_terminal_states: the set of nonterminal states + :type non_terminal_states: set + :return: response of describe call after job is done + """ + if not non_terminal_states: + non_terminal_states = self.non_terminal_states + + sec = 0 + running = True + + while running: + time.sleep(check_interval) + sec += check_interval + + try: + response = describe_function(job_name) + status = response[key] + self.log.info( + "Job still running for %s seconds... current status is %s", + sec, + status, + ) + except KeyError: + raise AirflowException("Could not get status of the SageMaker job") + except ClientError: + raise AirflowException("AWS request failed, check logs for more info") + + if status in non_terminal_states: + running = True + elif status in self.failed_states: + raise AirflowException( + f"SageMaker job failed because {response['FailureReason']}" + ) + else: + running = False + + if max_ingestion_time and sec > max_ingestion_time: + # ensure that the job gets killed if the max ingestion time is exceeded + raise AirflowException( + f"SageMaker job took more than {max_ingestion_time} seconds" + ) + + self.log.info("SageMaker Job completed") + response = describe_function(job_name) + return response + + def check_training_status_with_log( + self, + job_name: str, + non_terminal_states: set, + failed_states: set, + wait_for_completion: bool, + check_interval: int, + max_ingestion_time: Optional[int] = None, + ): + """ + Display the logs for a given training job, optionally tailing them until the + job is complete. + + :param job_name: name of the training job to check status and display logs for + :type job_name: str + :param non_terminal_states: the set of non_terminal states + :type non_terminal_states: set + :param failed_states: the set of failed states + :type failed_states: set + :param wait_for_completion: Whether to keep looking for new log entries + until the job completes + :type wait_for_completion: bool + :param check_interval: The interval in seconds between polling for new log entries and job completion + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. 
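check_status is also usable on its own, paired with any of the describe_* methods above; the job name and timeout are placeholders:

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")

response = hook.check_status(
    job_name="example-transform-job",
    key="TransformJobStatus",
    describe_function=hook.describe_transform_job,
    check_interval=30,
    max_ingestion_time=3600,
)
print(response["TransformJobStatus"])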
+ :type max_ingestion_time: int + :return: None + """ + sec = 0 + description = self.describe_training_job(job_name) + self.log.info(secondary_training_status_message(description, None)) + instance_count = description["ResourceConfig"]["InstanceCount"] + status = description["TrainingJobStatus"] + + stream_names: list = [] # The list of log streams + positions: dict = ( + {} + ) # The current position in each stream, map of stream name -> position + + job_already_completed = status not in non_terminal_states + + state = ( + LogState.TAILING + if wait_for_completion and not job_already_completed + else LogState.COMPLETE + ) + + # The loop below implements a state machine that alternates between checking the job status and + # reading whatever is available in the logs at this point. Note, that if we were called with + # wait_for_completion == False, we never check the job status. + # + # If wait_for_completion == TRUE and job is not completed, the initial state is TAILING + # If wait_for_completion == FALSE, the initial state is COMPLETE + # (doesn't matter if the job really is complete). + # + # The state table: + # + # STATE ACTIONS CONDITION NEW STATE + # ---------------- ---------------- ----------------- ---------------- + # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE + # Else TAILING + # JOB_COMPLETE Read logs, Pause Any COMPLETE + # COMPLETE Read logs, Exit N/A + # + # Notes: + # - The JOB_COMPLETE state forces us to do an extra pause and read any items that + # got to Cloudwatch after the job was marked complete. + last_describe_job_call = time.monotonic() + last_description = description + + while True: + time.sleep(check_interval) + sec += check_interval + + ( + state, + last_description, + last_describe_job_call, + ) = self.describe_training_job_with_log( + job_name, + positions, + stream_names, + instance_count, + state, + last_description, + last_describe_job_call, + ) + if state == LogState.COMPLETE: + break + + if max_ingestion_time and sec > max_ingestion_time: + # ensure that the job gets killed if the max ingestion time is exceeded + raise AirflowException( + f"SageMaker job took more than {max_ingestion_time} seconds" + ) + + if wait_for_completion: + status = last_description["TrainingJobStatus"] + if status in failed_states: + reason = last_description.get("FailureReason", "(No reason provided)") + raise AirflowException( + f"Error training {job_name}: {status} Reason: {reason}" + ) + billable_time = ( + last_description["TrainingEndTime"] + - last_description["TrainingStartTime"] + ) * instance_count + self.log.info( + "Billable seconds: %d", int(billable_time.total_seconds()) + 1 + ) + + def list_training_jobs( + self, + name_contains: Optional[str] = None, + max_results: Optional[int] = None, + **kwargs, + ) -> List[Dict]: # noqa: D402 + """ + This method wraps boto3's list_training_jobs(). The training job name and max results are configurable + via arguments. Other arguments are not, and should be provided via kwargs. Note boto3 expects these in + CamelCase format, for example: + + .. code-block:: python + + list_training_jobs(name_contains="myjob", StatusEquals="Failed") + + .. seealso:: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_training_jobs + + :param name_contains: (optional) partial name to match + :param max_results: (optional) maximum number of results to return. 
None returns infinite results + :param kwargs: (optional) kwargs to boto3's list_training_jobs method + :return: results of the list_training_jobs request + """ + config = {} + + if name_contains: + if "NameContains" in kwargs: + raise AirflowException( + "Either name_contains or NameContains can be provided, not both." + ) + config["NameContains"] = name_contains + + if "MaxResults" in kwargs and kwargs["MaxResults"] is not None: + if max_results: + raise AirflowException( + "Either max_results or MaxResults can be provided, not both." + ) + # Unset MaxResults, we'll use the SageMakerHook's internal method for iteratively fetching results + max_results = kwargs["MaxResults"] + del kwargs["MaxResults"] + + config.update(kwargs) + list_training_jobs_request = partial( + self.get_conn().list_training_jobs, **config + ) + results = self._list_request( + list_training_jobs_request, "TrainingJobSummaries", max_results=max_results + ) + return results + + def list_processing_jobs(self, **kwargs) -> List[Dict]: # noqa: D402 + """ + This method wraps boto3's list_processing_jobs(). All arguments should be provided via kwargs. + Note boto3 expects these in CamelCase format, for example: + + .. code-block:: python + + list_processing_jobs(NameContains="myjob", StatusEquals="Failed") + + .. seealso:: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_processing_jobs + + :param kwargs: (optional) kwargs to boto3's list_training_jobs method + :return: results of the list_processing_jobs request + """ + list_processing_jobs_request = partial( + self.get_conn().list_processing_jobs, **kwargs + ) + results = self._list_request( + list_processing_jobs_request, + "ProcessingJobSummaries", + max_results=kwargs.get("MaxResults"), + ) + return results + + def _list_request( + self, partial_func: Callable, result_key: str, max_results: Optional[int] = None + ) -> List[Dict]: + """ + All AWS boto3 list_* requests return results in batches (if the key "NextToken" is contained in the + result, there are more results to fetch). The default AWS batch size is 10, and configurable up to + 100. This function iteratively loads all results (or up to a given maximum). + + Each boto3 list_* function returns the results in a list with a different name. The key of this + structure must be given to iterate over the results, e.g. "TransformJobSummaries" for + list_transform_jobs(). + + :param partial_func: boto3 function with arguments + :param result_key: the result key to iterate over + :param max_results: maximum number of results to return (None = infinite) + :return: Results of the list_* request + """ + sagemaker_max_results = 100 # Fixed number set by AWS + + results: List[Dict] = [] + next_token = None + + while True: + kwargs = {} + if next_token is not None: + kwargs["NextToken"] = next_token + + if max_results is None: + kwargs["MaxResults"] = sagemaker_max_results + else: + kwargs["MaxResults"] = min( + max_results - len(results), sagemaker_max_results + ) + + response = partial_func(**kwargs) + self.log.debug("Fetched %s results.", len(response[result_key])) + results.extend(response[result_key]) + + if "NextToken" not in response or ( + max_results is not None and len(results) == max_results + ): + # Return when there are no results left (no NextToken) or when we've reached max_results. 
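A sketch of the list_* wrappers described above; the partial name and the CamelCase boto3 filter are illustrative:

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")

# Snake-case arguments are translated; extra boto3 filters stay CamelCase.
failed_jobs = hook.list_training_jobs(
    name_contains="example-job", max_results=50, StatusEquals="Failed"
)

processing_jobs = hook.list_processing_jobs(NameContains="example-job")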
+ return results + else: + next_token = response["NextToken"] diff --git a/reference/providers/amazon/aws/hooks/secrets_manager.py b/reference/providers/amazon/aws/hooks/secrets_manager.py new file mode 100644 index 0000000..73559a1 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/secrets_manager.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import base64 +import json +from typing import Union + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class SecretsManagerHook(AwsBaseHook): + """ + Interact with Amazon SecretsManager Service. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. see also:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs): + super().__init__(client_type="secretsmanager", *args, **kwargs) + + def get_secret(self, secret_name: str) -> Union[str, bytes]: + """ + Retrieve secret value from AWS Secrets Manager as a str or bytes + reflecting format it stored in the AWS Secrets Manager + + :param secret_name: name of the secrets. + :type secret_name: str + :return: Union[str, bytes] with the information about the secrets + :rtype: Union[str, bytes] + """ + # Depending on whether the secret is a string or binary, one of + # these fields will be populated. + get_secret_value_response = self.get_conn().get_secret_value( + SecretId=secret_name + ) + if "SecretString" in get_secret_value_response: + secret = get_secret_value_response["SecretString"] + else: + secret = base64.b64decode(get_secret_value_response["SecretBinary"]) + return secret + + def get_secret_as_dict(self, secret_name: str) -> dict: + """ + Retrieve secret value from AWS Secrets Manager in a dict representation + + :param secret_name: name of the secrets. + :type secret_name: str + :return: dict with the information about the secrets + :rtype: dict + """ + return json.loads(self.get_secret(secret_name)) diff --git a/reference/providers/amazon/aws/hooks/ses.py b/reference/providers/amazon/aws/hooks/ses.py new file mode 100644 index 0000000..0fd74c6 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/ses.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
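SecretsManagerHook usage, sketched with a hypothetical secret name that is assumed to hold JSON:

from airflow.providers.amazon.aws.hooks.secrets_manager import SecretsManagerHook

hook = SecretsManagerHook(aws_conn_id="aws_default")

raw = hook.get_secret("example/airflow/db-credentials")            # str or bytes
creds = hook.get_secret_as_dict("example/airflow/db-credentials")  # JSON-backed secret
password = creds.get("password")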
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains AWS SES Hook""" +from typing import Any, Dict, Iterable, List, Optional, Union + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.utils.email import build_mime_message + + +class SESHook(AwsBaseHook): + """ + Interact with Amazon Simple Email Service. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = "ses" + super().__init__(*args, **kwargs) + + def send_email( # pylint: disable=too-many-arguments + self, + mail_from: str, + to: Union[str, Iterable[str]], + subject: str, + html_content: str, + files: Optional[List[str]] = None, + cc: Optional[Union[str, Iterable[str]]] = None, + bcc: Optional[Union[str, Iterable[str]]] = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + reply_to: Optional[str] = None, + return_path: Optional[str] = None, + custom_headers: Optional[Dict[str, Any]] = None, + ) -> dict: + """ + Send email using Amazon Simple Email Service + + :param mail_from: Email address to set as email's from + :param to: List of email addresses to set as email's to + :param subject: Email's subject + :param html_content: Content of email in HTML format + :param files: List of paths of files to be attached + :param cc: List of email addresses to set as email's CC + :param bcc: List of email addresses to set as email's BCC + :param mime_subtype: Can be used to specify the sub-type of the message. Default = mixed + :param mime_charset: Email's charset. Default = UTF-8. + :param return_path: The email address to which replies will be sent. By default, replies + are sent to the original sender's email address. + :param reply_to: The email address to which message bounces and complaints should be sent. + "Return-Path" is sometimes called "envelope from," "envelope sender," or "MAIL FROM." + :param custom_headers: Additional headers to add to the MIME message. + No validations are run on these values and they should be able to be encoded. + :return: Response from Amazon SES service with unique message identifier. 
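
        A minimal usage sketch (the addresses, attachment path and connection id
        below are illustrative placeholders, not values from this module):

        .. code-block:: python

            hook = SESHook(aws_conn_id="aws_default")
            response = hook.send_email(
                mail_from="sender@example.com",
                to=["team@example.com"],
                subject="Nightly report",
                html_content="<h3>All tasks succeeded</h3>",
                files=["/tmp/report.csv"],
                reply_to="no-reply@example.com",
            )
            message_id = response["MessageId"]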
+ """ + ses_client = self.get_conn() + + custom_headers = custom_headers or {} + if reply_to: + custom_headers["Reply-To"] = reply_to + if return_path: + custom_headers["Return-Path"] = return_path + + message, recipients = build_mime_message( + mail_from=mail_from, + to=to, + subject=subject, + html_content=html_content, + files=files, + cc=cc, + bcc=bcc, + mime_subtype=mime_subtype, + mime_charset=mime_charset, + custom_headers=custom_headers, + ) + + return ses_client.send_raw_email( + Source=mail_from, + Destinations=recipients, + RawMessage={"Data": message.as_string()}, + ) diff --git a/reference/providers/amazon/aws/hooks/sns.py b/reference/providers/amazon/aws/hooks/sns.py new file mode 100644 index 0000000..448806a --- /dev/null +++ b/reference/providers/amazon/aws/hooks/sns.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains AWS SNS hook""" +import json +from typing import Dict, Optional, Union + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +def _get_message_attribute(o): + if isinstance(o, bytes): + return {"DataType": "Binary", "BinaryValue": o} + if isinstance(o, str): + return {"DataType": "String", "StringValue": o} + if isinstance(o, (int, float)): + return {"DataType": "Number", "StringValue": str(o)} + if hasattr(o, "__iter__"): + return {"DataType": "String.Array", "StringValue": json.dumps(o)} + raise TypeError( + "Values in MessageAttributes must be one of bytes, str, int, float, or iterable; " + f"got {type(o)}" + ) + + +class AwsSnsHook(AwsBaseHook): + """ + Interact with Amazon Simple Notification Service. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs): + super().__init__(client_type="sns", *args, **kwargs) + + def publish_to_target( + self, + target_arn: str, + message: str, + subject: Optional[str] = None, + message_attributes: Optional[dict] = None, + ): + """ + Publish a message to a topic or an endpoint. + + :param target_arn: either a TopicArn or an EndpointArn + :type target_arn: str + :param message: the default message you want to send + :param message: str + :param subject: subject of message + :type subject: str + :param message_attributes: additional attributes to publish for message filtering. 
This should be + a flat dict; the DataType to be sent depends on the type of the value: + + - bytes = Binary + - str = String + - int, float = Number + - iterable = String.Array + + :type message_attributes: dict + """ + publish_kwargs: Dict[str, Union[str, dict]] = { + "TargetArn": target_arn, + "MessageStructure": "json", + "Message": json.dumps({"default": message}), + } + + # Construct args this way because boto3 distinguishes from missing args and those set to None + if subject: + publish_kwargs["Subject"] = subject + if message_attributes: + publish_kwargs["MessageAttributes"] = { + key: _get_message_attribute(val) + for key, val in message_attributes.items() + } + + return self.get_conn().publish(**publish_kwargs) diff --git a/reference/providers/amazon/aws/hooks/sqs.py b/reference/providers/amazon/aws/hooks/sqs.py new file mode 100644 index 0000000..8e46519 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/sqs.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains AWS SQS hook""" +from typing import Dict, Optional + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class SQSHook(AwsBaseHook): + """ + Interact with Amazon Simple Queue Service. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = "sqs" + super().__init__(*args, **kwargs) + + def create_queue(self, queue_name: str, attributes: Optional[Dict] = None) -> Dict: + """ + Create queue using connection object + + :param queue_name: name of the queue. 
+ :type queue_name: str + :param attributes: additional attributes for the queue (default: None) + For details of the attributes parameter see :py:meth:`SQS.create_queue` + :type attributes: dict + + :return: dict with the information about the queue + For details of the returned value see :py:meth:`SQS.create_queue` + :rtype: dict + """ + return self.get_conn().create_queue( + QueueName=queue_name, Attributes=attributes or {} + ) + + def send_message( + self, + queue_url: str, + message_body: str, + delay_seconds: int = 0, + message_attributes: Optional[Dict] = None, + ) -> Dict: + """ + Send message to the queue + + :param queue_url: queue url + :type queue_url: str + :param message_body: the contents of the message + :type message_body: str + :param delay_seconds: seconds to delay the message + :type delay_seconds: int + :param message_attributes: additional attributes for the message (default: None) + For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message` + :type message_attributes: dict + + :return: dict with the information about the message sent + For details of the returned value see :py:meth:`botocore.client.SQS.send_message` + :rtype: dict + """ + return self.get_conn().send_message( + QueueUrl=queue_url, + MessageBody=message_body, + DelaySeconds=delay_seconds, + MessageAttributes=message_attributes or {}, + ) diff --git a/reference/providers/amazon/aws/hooks/step_function.py b/reference/providers/amazon/aws/hooks/step_function.py new file mode 100644 index 0000000..0652f69 --- /dev/null +++ b/reference/providers/amazon/aws/hooks/step_function.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from typing import Optional, Union + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class StepFunctionHook(AwsBaseHook): + """ + Interact with an AWS Step Functions State Machine. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, region_name: Optional[str] = None, *args, **kwargs) -> None: + kwargs["client_type"] = "stepfunctions" + super().__init__(*args, **kwargs) + + def start_execution( + self, + state_machine_arn: str, + name: Optional[str] = None, + state_machine_input: Union[dict, str, None] = None, + ) -> str: + """ + Start Execution of the State Machine. + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution + + :param state_machine_arn: AWS Step Function State Machine ARN + :type state_machine_arn: str + :param name: The name of the execution. 
+ :type name: Optional[str] + :param state_machine_input: JSON data input to pass to the State Machine + :type state_machine_input: Union[Dict[str, any], str, None] + :return: Execution ARN + :rtype: str + """ + execution_args = {"stateMachineArn": state_machine_arn} + if name is not None: + execution_args["name"] = name + if state_machine_input is not None: + if isinstance(state_machine_input, str): + execution_args["input"] = state_machine_input + elif isinstance(state_machine_input, dict): + execution_args["input"] = json.dumps(state_machine_input) + + self.log.info("Executing Step Function State Machine: %s", state_machine_arn) + + response = self.conn.start_execution(**execution_args) + return response.get("executionArn") + + def describe_execution(self, execution_arn: str) -> dict: + """ + Describes a State Machine Execution + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.describe_execution + + :param execution_arn: ARN of the State Machine Execution + :type execution_arn: str + :return: Dict with Execution details + :rtype: dict + """ + return self.get_conn().describe_execution(executionArn=execution_arn) diff --git a/reference/providers/amazon/aws/log/__init__.py b/reference/providers/amazon/aws/log/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/amazon/aws/log/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/amazon/aws/log/cloudwatch_task_handler.py b/reference/providers/amazon/aws/log/cloudwatch_task_handler.py new file mode 100644 index 0000000..c1599ea --- /dev/null +++ b/reference/providers/amazon/aws/log/cloudwatch_task_handler.py @@ -0,0 +1,135 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
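
# A minimal usage sketch for the StepFunctionHook defined in step_function.py
# above; the state machine ARN, connection id, input payload, and this helper's
# name are illustrative placeholders, not part of the original provider code.
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook


def _example_start_state_machine() -> str:
    # Start an execution with a dict input (serialized to JSON by the hook) and
    # return the new execution's ARN; describe_execution() can then be polled
    # until the reported status reaches a terminal state.
    hook = StepFunctionHook(aws_conn_id="aws_default")
    execution_arn = hook.start_execution(
        state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:example",
        state_machine_input={"run_date": "2022-07-28"},
    )
    status = hook.describe_execution(execution_arn)["status"]
    print(f"Execution {execution_arn} is currently {status}")
    return execution_arn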
+ +import watchtower + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.configuration import conf +from airflow.utils.log.file_task_handler import FileTaskHandler +from airflow.utils.log.logging_mixin import LoggingMixin + + +class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin): + """ + CloudwatchTaskHandler is a python log handler that handles and reads task instance logs. + + It extends airflow FileTaskHandler and uploads to and reads from Cloudwatch. + + :param base_log_folder: base folder to store logs locally + :type base_log_folder: str + :param log_group_arn: ARN of the Cloudwatch log group for remote log storage + with format ``arn:aws:logs:{region name}:{account id}:log-group:{group name}`` + :type log_group_arn: str + :param filename_template: template for file name (local storage) or log stream name (remote) + :type filename_template: str + """ + + def __init__( + self, base_log_folder: str, log_group_arn: str, filename_template: str + ): + super().__init__(base_log_folder, filename_template) + split_arn = log_group_arn.split(":") + + self.handler = None + self.log_group = split_arn[6] + self.region_name = split_arn[3] + self.closed = False + + @cached_property + def hook(self): + """Returns AwsLogsHook.""" + remote_conn_id = conf.get("logging", "REMOTE_LOG_CONN_ID") + try: + from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook + + return AwsLogsHook(aws_conn_id=remote_conn_id, region_name=self.region_name) + except Exception as e: # pylint: disable=broad-except + self.log.error( + 'Could not create an AwsLogsHook with connection id "%s". ' + "Please make sure that airflow[aws] is installed and " + 'the Cloudwatch logs connection exists. Exception: "%s"', + remote_conn_id, + e, + ) + return None + + def _render_filename(self, ti, try_number): + # Replace unsupported log group name characters + return super()._render_filename(ti, try_number).replace(":", "_") + + def set_context(self, ti): + super().set_context(ti) + self.handler = watchtower.CloudWatchLogHandler( + log_group=self.log_group, + stream_name=self._render_filename(ti, ti.try_number), + boto3_session=self.hook.get_session(self.region_name), + ) + + def close(self): + """Close the handler responsible for the upload of the local log file to Cloudwatch.""" + # When application exit, system shuts down all handlers by + # calling close method. Here we check if logger is already + # closed to prevent uploading the log to remote storage multiple + # times when `logging.shutdown` is called. + if self.closed: + return + + if self.handler is not None: + self.handler.close() + # Mark closed so we don't double write if close is called twice + self.closed = True + + def _read(self, task_instance, try_number, metadata=None): + stream_name = self._render_filename(task_instance, try_number) + return ( + "*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}\n".format( + self.log_group, + stream_name, + self.get_cloudwatch_logs(stream_name=stream_name), + ), + {"end_of_log": True}, + ) + + def get_cloudwatch_logs(self, stream_name: str) -> str: + """ + Return all logs from the given log stream. 
+ + :param stream_name: name of the Cloudwatch log stream to get all logs from + :return: string of all logs from the given log stream + """ + try: + events = list( + self.hook.get_log_events( + log_group=self.log_group, + log_stream_name=stream_name, + start_from_head=True, + ) + ) + return "\n".join([event["message"] for event in events]) + except Exception: # pylint: disable=broad-except + msg = ( + "Could not read remote logs from log_group: {} log_stream: {}.".format( + self.log_group, stream_name + ) + ) + self.log.exception(msg) + return msg diff --git a/reference/providers/amazon/aws/log/s3_task_handler.py b/reference/providers/amazon/aws/log/s3_task_handler.py new file mode 100644 index 0000000..f604452 --- /dev/null +++ b/reference/providers/amazon/aws/log/s3_task_handler.py @@ -0,0 +1,204 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.configuration import conf +from airflow.utils.log.file_task_handler import FileTaskHandler +from airflow.utils.log.logging_mixin import LoggingMixin + + +class S3TaskHandler(FileTaskHandler, LoggingMixin): + """ + S3TaskHandler is a python log handler that handles and reads + task instance logs. It extends airflow FileTaskHandler and + uploads to and reads from S3 remote storage. + """ + + def __init__( + self, base_log_folder: str, s3_log_folder: str, filename_template: str + ): + super().__init__(base_log_folder, filename_template) + self.remote_base = s3_log_folder + self.log_relative_path = "" + self._hook = None + self.closed = False + self.upload_on_close = True + + @cached_property + def hook(self): + """Returns S3Hook.""" + remote_conn_id = conf.get("logging", "REMOTE_LOG_CONN_ID") + try: + from airflow.providers.amazon.aws.hooks.s3 import S3Hook + + return S3Hook(remote_conn_id, transfer_config_args={"use_threads": False}) + except Exception as e: # pylint: disable=broad-except + self.log.exception( + 'Could not create an S3Hook with connection id "%s". ' + "Please make sure that airflow[aws] is installed and " + 'the S3 connection exists. Exception : "%s"', + remote_conn_id, + e, + ) + return None + + def set_context(self, ti): + super().set_context(ti) + # Local location and remote location is needed to open and + # upload local log file to S3 remote storage. + self.log_relative_path = self._render_filename(ti, ti.try_number) + self.upload_on_close = not ti.raw + + # Clear the file first so that duplicate data is not uploaded + # when re-using the same path (e.g. 
with rescheduled sensors) + if self.upload_on_close: + with open(self.handler.baseFilename, "w"): + pass + + def close(self): + """Close and upload local log file to remote storage S3.""" + # When application exit, system shuts down all handlers by + # calling close method. Here we check if logger is already + # closed to prevent uploading the log to remote storage multiple + # times when `logging.shutdown` is called. + if self.closed: + return + + super().close() + + if not self.upload_on_close: + return + + local_loc = os.path.join(self.local_base, self.log_relative_path) + remote_loc = os.path.join(self.remote_base, self.log_relative_path) + if os.path.exists(local_loc): + # read log and remove old logs to get just the latest additions + with open(local_loc) as logfile: + log = logfile.read() + self.s3_write(log, remote_loc) + + # Mark closed so we don't double write if close is called twice + self.closed = True + + def _read(self, ti, try_number, metadata=None): + """ + Read logs of given task instance and try_number from S3 remote storage. + If failed, read the log from task instance host machine. + + :param ti: task instance object + :param try_number: task instance try_number to read logs from + :param metadata: log metadata, + can be used for steaming log reading and auto-tailing. + """ + # Explicitly getting log relative path is necessary as the given + # task instance might be different than task instance passed in + # in set_context method. + log_relative_path = self._render_filename(ti, try_number) + remote_loc = os.path.join(self.remote_base, log_relative_path) + + log_exists = False + log = "" + + try: + log_exists = self.s3_log_exists(remote_loc) + except Exception as error: # pylint: disable=broad-except + self.log.exception(error) + log = ( + f"*** Failed to verify remote log exists {remote_loc}.\n{str(error)}\n" + ) + + if log_exists: + # If S3 remote file exists, we do not fetch logs from task instance + # local machine even if there are errors reading remote logs, as + # returned remote_log will contain error messages. + remote_log = self.s3_read(remote_loc, return_error=True) + log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n" + return log, {"end_of_log": True} + else: + log += "*** Falling back to local log\n" + local_log, metadata = super()._read(ti, try_number) + return log + local_log, metadata + + def s3_log_exists(self, remote_log_location: str) -> bool: + """ + Check if remote_log_location exists in remote storage + + :param remote_log_location: log's location in remote storage + :type remote_log_location: str + :return: True if location exists else False + """ + return self.hook.check_for_key(remote_log_location) + + def s3_read(self, remote_log_location: str, return_error: bool = False) -> str: + """ + Returns the log found at the remote_log_location. Returns '' if no + logs are found or there is an error. + + :param remote_log_location: the log's location in remote storage + :type remote_log_location: str (path) + :param return_error: if True, returns a string error message if an + error occurs. Otherwise returns '' when an error occurs. 
+ :type return_error: bool + :return: the log found at the remote_log_location + """ + try: + return self.hook.read_key(remote_log_location) + except Exception as error: # pylint: disable=broad-except + msg = f"Could not read logs from {remote_log_location} with error: {error}" + self.log.exception(msg) + # return error if needed + if return_error: + return msg + return "" + + def s3_write(self, log: str, remote_log_location: str, append: bool = True): + """ + Writes the log to the remote_log_location. Fails silently if no hook + was created. + + :param log: the log to write to the remote_log_location + :type log: str + :param remote_log_location: the log's location in remote storage + :type remote_log_location: str (path) + :param append: if False, any existing log file is overwritten. If True, + the new log is appended to any existing logs. + :type append: bool + """ + try: + if append and self.s3_log_exists(remote_log_location): + old_log = self.s3_read(remote_log_location) + log = "\n".join([old_log, log]) if old_log else log + except Exception as error: # pylint: disable=broad-except + self.log.exception( + "Could not verify previous log to append: %s", str(error) + ) + + try: + self.hook.load_string( + log, + key=remote_log_location, + replace=True, + encrypt=conf.getboolean("logging", "ENCRYPT_S3_LOGS"), + ) + except Exception: # pylint: disable=broad-except + self.log.exception("Could not write logs to %s", remote_log_location) diff --git a/reference/providers/amazon/aws/operators/__init__.py b/reference/providers/amazon/aws/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/amazon/aws/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/amazon/aws/operators/athena.py b/reference/providers/amazon/aws/operators/athena.py new file mode 100644 index 0000000..c95a648 --- /dev/null +++ b/reference/providers/amazon/aws/operators/athena.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Any, Dict, Optional +from uuid import uuid4 + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook +from airflow.utils.decorators import apply_defaults + + +class AWSAthenaOperator(BaseOperator): + """ + An operator that submits a presto query to athena. + + :param query: Presto to be run on athena. (templated) + :type query: str + :param database: Database to select. (templated) + :type database: str + :param output_location: s3 path to write the query results into. (templated) + :type output_location: str + :param aws_conn_id: aws connection to use + :type aws_conn_id: str + :param client_request_token: Unique token created by user to avoid multiple executions of same query + :type client_request_token: str + :param workgroup: Athena workgroup in which query will be run + :type workgroup: str + :param query_execution_context: Context in which query need to be run + :type query_execution_context: dict + :param result_configuration: Dict with path to store results in and config related to encryption + :type result_configuration: dict + :param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena + :type sleep_time: int + :param max_tries: Number of times to poll for query state before function exits + :type max_tries: int + """ + + ui_color = "#44b5e2" + template_fields = ("query", "database", "output_location") + template_ext = (".sql",) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + query: str, + database: str, + output_location: str, + aws_conn_id: str = "aws_default", + client_request_token: Optional[str] = None, + workgroup: str = "primary", + query_execution_context: Optional[Dict[str, str]] = None, + result_configuration: Optional[Dict[str, Any]] = None, + sleep_time: int = 30, + max_tries: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.query = query + self.database = database + self.output_location = output_location + self.aws_conn_id = aws_conn_id + self.client_request_token = client_request_token or str(uuid4()) + self.workgroup = workgroup + self.query_execution_context = query_execution_context or {} + self.result_configuration = result_configuration or {} + self.sleep_time = sleep_time + self.max_tries = max_tries + self.query_execution_id = None # type: Optional[str] + + @cached_property + def hook(self) -> AWSAthenaHook: + """Create and return an AWSAthenaHook.""" + return AWSAthenaHook(self.aws_conn_id, sleep_time=self.sleep_time) + + def execute(self, context: dict) -> Optional[str]: + """Run Presto Query on Athena""" + self.query_execution_context["Database"] = self.database + self.result_configuration["OutputLocation"] = self.output_location + self.query_execution_id = self.hook.run_query( + self.query, + self.query_execution_context, + self.result_configuration, + self.client_request_token, + self.workgroup, + ) + query_status = self.hook.poll_query_status( + self.query_execution_id, self.max_tries + ) + + if query_status in AWSAthenaHook.FAILURE_STATES: + error_message = self.hook.get_state_change_reason(self.query_execution_id) + raise Exception( + "Final state of Athena job is {}, query_execution_id is {}. 
Error: {}".format( + query_status, self.query_execution_id, error_message + ) + ) + elif not query_status or query_status in AWSAthenaHook.INTERMEDIATE_STATES: + raise Exception( + "Final state of Athena job is {}. " + "Max tries of poll status exceeded, query_execution_id is {}.".format( + query_status, self.query_execution_id + ) + ) + + return self.query_execution_id + + def on_kill(self) -> None: + """Cancel the submitted athena query""" + if self.query_execution_id: + self.log.info("Received a kill signal.") + self.log.info( + "Stopping Query with executionId - %s", self.query_execution_id + ) + response = self.hook.stop_query(self.query_execution_id) + http_status_code = None + try: + http_status_code = response["ResponseMetadata"]["HTTPStatusCode"] + except Exception as ex: # pylint: disable=broad-except + self.log.error("Exception while cancelling query: %s", ex) + finally: + if http_status_code is None or http_status_code != 200: + self.log.error("Unable to request query cancel on athena. Exiting") + else: + self.log.info( + "Polling Athena for query with id %s to reach final state", + self.query_execution_id, + ) + self.hook.poll_query_status(self.query_execution_id) diff --git a/reference/providers/amazon/aws/operators/batch.py b/reference/providers/amazon/aws/operators/batch.py new file mode 100644 index 0000000..87bedfe --- /dev/null +++ b/reference/providers/amazon/aws/operators/batch.py @@ -0,0 +1,207 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +""" +An Airflow operator for AWS Batch services + +.. 
seealso:: + + - http://boto3.readthedocs.io/en/latest/guide/configuration.html + - http://boto3.readthedocs.io/en/latest/reference/services/batch.html + - https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html +""" +from typing import Any, Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchClientHook +from airflow.utils.decorators import apply_defaults + + +class AwsBatchOperator(BaseOperator): + """ + Execute a job on AWS Batch + + :param job_name: the name for the job that will run on AWS Batch (templated) + :type job_name: str + + :param job_definition: the job definition name on AWS Batch + :type job_definition: str + + :param job_queue: the queue name on AWS Batch + :type job_queue: str + + :param overrides: the `containerOverrides` parameter for boto3 (templated) + :type overrides: Optional[dict] + + :param array_properties: the `arrayProperties` parameter for boto3 + :type array_properties: Optional[dict] + + :param parameters: the `parameters` for boto3 (templated) + :type parameters: Optional[dict] + + :param job_id: the job ID, usually unknown (None) until the + submit_job operation gets the jobId defined by AWS Batch + :type job_id: Optional[str] + + :param waiters: an :py:class:`.AwsBatchWaiters` object (see note below); + if None, polling is used with max_retries and status_retries. + :type waiters: Optional[AwsBatchWaiters] + + :param max_retries: exponential back-off retries, 4200 = 48 hours; + polling is only used when waiters is None + :type max_retries: int + + :param status_retries: number of HTTP retries to get job status, 10; + polling is only used when waiters is None + :type status_retries: int + + :param aws_conn_id: connection id of AWS credentials / region name. If None, + credential boto3 strategy will be used. + :type aws_conn_id: str + + :param region_name: region name to use in AWS Hook. + Override the region_name in connection (if provided) + :type region_name: str + + :param tags: collection of tags to apply to the AWS Batch job submission + if None, no tags are submitted + :type tags: dict + + .. note:: + Any custom waiters must return a waiter for these calls: + .. 
code-block:: python + + waiter = waiters.get_waiter("JobExists") + waiter = waiters.get_waiter("JobRunning") + waiter = waiters.get_waiter("JobComplete") + """ + + ui_color = "#c3dae0" + arn = None # type: Optional[str] + template_fields = ( + "job_name", + "overrides", + "parameters", + ) + + @apply_defaults + def __init__( + self, + *, + job_name: str, + job_definition: str, + job_queue: str, + overrides: dict, + array_properties: Optional[dict] = None, + parameters: Optional[dict] = None, + job_id: Optional[str] = None, + waiters: Optional[Any] = None, + max_retries: Optional[int] = None, + status_retries: Optional[int] = None, + aws_conn_id: Optional[str] = None, + region_name: Optional[str] = None, + tags: Optional[dict] = None, + **kwargs, + ): # pylint: disable=too-many-arguments + + BaseOperator.__init__(self, **kwargs) + self.job_id = job_id + self.job_name = job_name + self.job_definition = job_definition + self.job_queue = job_queue + self.overrides = overrides or {} + self.array_properties = array_properties or {} + self.parameters = parameters or {} + self.waiters = waiters + self.tags = tags or {} + self.hook = AwsBatchClientHook( + max_retries=max_retries, + status_retries=status_retries, + aws_conn_id=aws_conn_id, + region_name=region_name, + ) + + def execute(self, context: Dict): + """ + Submit and monitor an AWS Batch job + + :raises: AirflowException + """ + self.submit_job(context) + self.monitor_job(context) + + def on_kill(self): + response = self.hook.client.terminate_job( + jobId=self.job_id, reason="Task killed by the user" + ) + self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response) + + def submit_job(self, context: Dict): # pylint: disable=unused-argument + """ + Submit an AWS Batch job + + :raises: AirflowException + """ + self.log.info( + "Running AWS Batch job - job definition: %s - on queue %s", + self.job_definition, + self.job_queue, + ) + self.log.info("AWS Batch job - container overrides: %s", self.overrides) + + try: + response = self.hook.client.submit_job( + jobName=self.job_name, + jobQueue=self.job_queue, + jobDefinition=self.job_definition, + arrayProperties=self.array_properties, + parameters=self.parameters, + containerOverrides=self.overrides, + tags=self.tags, + ) + self.job_id = response["jobId"] + + self.log.info("AWS Batch job (%s) started: %s", self.job_id, response) + + except Exception as e: + self.log.error("AWS Batch job (%s) failed submission", self.job_id) + raise AirflowException(e) + + def monitor_job(self, context: Dict): # pylint: disable=unused-argument + """ + Monitor an AWS Batch job + + :raises: AirflowException + """ + if not self.job_id: + raise AirflowException("AWS Batch job - job_id was not found") + + try: + if self.waiters: + self.waiters.wait_for_job(self.job_id) + else: + self.hook.wait_for_job(self.job_id) + + self.hook.check_job_success(self.job_id) + self.log.info("AWS Batch job (%s) succeeded", self.job_id) + + except Exception as e: + self.log.error("AWS Batch job (%s) failed monitoring", self.job_id) + raise AirflowException(e) diff --git a/reference/providers/amazon/aws/operators/cloud_formation.py b/reference/providers/amazon/aws/operators/cloud_formation.py new file mode 100644 index 0000000..d8c0047 --- /dev/null +++ b/reference/providers/amazon/aws/operators/cloud_formation.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains CloudFormation create/delete stack operators.""" +from typing import List, Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.cloud_formation import AWSCloudFormationHook +from airflow.utils.decorators import apply_defaults + + +class CloudFormationCreateStackOperator(BaseOperator): + """ + An operator that creates a CloudFormation stack. + + :param stack_name: stack name (templated) + :type stack_name: str + :param params: parameters to be passed to CloudFormation. + + .. seealso:: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/cloudformation.html#CloudFormation.Client.create_stack + :type params: dict + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + """ + + template_fields: List[str] = ["stack_name"] + template_ext = () + ui_color = "#6b9659" + + @apply_defaults + def __init__( + self, + *, + stack_name: str, + params: dict, + aws_conn_id: str = "aws_default", + **kwargs + ): + super().__init__(**kwargs) + self.stack_name = stack_name + self.params = params + self.aws_conn_id = aws_conn_id + + def execute(self, context): + self.log.info("Parameters: %s", self.params) + + cloudformation_hook = AWSCloudFormationHook(aws_conn_id=self.aws_conn_id) + cloudformation_hook.create_stack(self.stack_name, self.params) + + +class CloudFormationDeleteStackOperator(BaseOperator): + """ + An operator that deletes a CloudFormation stack. + + :param stack_name: stack name (templated) + :type stack_name: str + :param params: parameters to be passed to CloudFormation. + + .. seealso:: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/cloudformation.html#CloudFormation.Client.delete_stack + :type params: dict + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + """ + + template_fields: List[str] = ["stack_name"] + template_ext = () + ui_color = "#1d472b" + ui_fgcolor = "#FFF" + + @apply_defaults + def __init__( + self, + *, + stack_name: str, + params: Optional[dict] = None, + aws_conn_id: str = "aws_default", + **kwargs + ): + super().__init__(**kwargs) + self.params = params or {} + self.stack_name = stack_name + self.aws_conn_id = aws_conn_id + + def execute(self, context): + self.log.info("Parameters: %s", self.params) + + cloudformation_hook = AWSCloudFormationHook(aws_conn_id=self.aws_conn_id) + cloudformation_hook.delete_stack(self.stack_name, self.params) diff --git a/reference/providers/amazon/aws/operators/datasync.py b/reference/providers/amazon/aws/operators/datasync.py new file mode 100644 index 0000000..b08fe33 --- /dev/null +++ b/reference/providers/amazon/aws/operators/datasync.py @@ -0,0 +1,402 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Create, get, update, execute and delete an AWS DataSync Task.""" + +import logging +import random +from typing import List, Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook +from airflow.utils.decorators import apply_defaults + + +# pylint: disable=too-many-instance-attributes, too-many-arguments +class AWSDataSyncOperator(BaseOperator): + r"""Find, Create, Update, Execute and Delete AWS DataSync Tasks. + + If ``do_xcom_push`` is True, then the DataSync TaskArn and TaskExecutionArn + which were executed will be pushed to an XCom. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AWSDataSyncOperator` + + .. note:: There may be 0, 1, or many existing DataSync Tasks defined in your AWS + environment. The default behavior is to create a new Task if there are 0, or + execute the Task if there was 1 Task, or fail if there were many Tasks. + + :param str aws_conn_id: AWS connection to use. + :param int wait_interval_seconds: Time to wait between two + consecutive calls to check TaskExecution status. + :param str task_arn: AWS DataSync TaskArn to use. If None, then this operator will + attempt to either search for an existing Task or attempt to create a new Task. + :param str source_location_uri: Source location URI to search for. All DataSync + Tasks with a LocationArn with this URI will be considered. + Example: ``smb://server/subdir`` + :param str destination_location_uri: Destination location URI to search for. + All DataSync Tasks with a LocationArn with this URI will be considered. + Example: ``s3://airflow_bucket/stuff`` + :param bool allow_random_task_choice: If multiple Tasks match, one must be chosen to + execute. If allow_random_task_choice is True then a random one is chosen. + :param bool allow_random_location_choice: If multiple Locations match, one must be chosen + when creating a task. If allow_random_location_choice is True then a random one is chosen. + :param dict create_task_kwargs: If no suitable TaskArn is identified, + it will be created if ``create_task_kwargs`` is defined. + ``create_task_kwargs`` is then used internally like this: + ``boto3.create_task(**create_task_kwargs)`` + Example: ``{'Name': 'xyz', 'Options': ..., 'Excludes': ..., 'Tags': ...}`` + :param dict create_source_location_kwargs: If no suitable LocationArn is found, + a Location will be created if ``create_source_location_kwargs`` is defined. 
+ ``create_source_location_kwargs`` is then used internally like this: + ``boto3.create_location_xyz(**create_source_location_kwargs)`` + The xyz is determined from the prefix of source_location_uri, eg ``smb:/...`` or ``s3:/...`` + Example: ``{'Subdirectory': ..., 'ServerHostname': ..., ...}`` + :param dict create_destination_location_kwargs: If no suitable LocationArn is found, + a Location will be created if ``create_destination_location_kwargs`` is defined. + ``create_destination_location_kwargs`` is used internally like this: + ``boto3.create_location_xyz(**create_destination_location_kwargs)`` + The xyz is determined from the prefix of destination_location_uri, eg ``smb:/...` or ``s3:/...`` + Example: ``{'S3BucketArn': ..., 'S3Config': {'BucketAccessRoleArn': ...}, ...}`` + :param dict update_task_kwargs: If a suitable TaskArn is found or created, + it will be updated if ``update_task_kwargs`` is defined. + ``update_task_kwargs`` is used internally like this: + ``boto3.update_task(TaskArn=task_arn, **update_task_kwargs)`` + Example: ``{'Name': 'xyz', 'Options': ..., 'Excludes': ...}`` + :param dict task_execution_kwargs: Additional kwargs passed directly when starting the + Task execution, used internally like this: + ``boto3.start_task_execution(TaskArn=task_arn, **task_execution_kwargs)`` + :param bool delete_task_after_execution: If True then the TaskArn which was executed + will be deleted from AWS DataSync on successful completion. + :raises AirflowException: If ``task_arn`` was not specified, or if + either ``source_location_uri`` or ``destination_location_uri`` were + not specified. + :raises AirflowException: If source or destination Location were not found + and could not be created. + :raises AirflowException: If ``choose_task`` or ``choose_location`` fails. + :raises AirflowException: If Task creation, update, execution or delete fails. 
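
    A minimal task definition sketch (the ``task_id`` and ``create_task_kwargs``
    values are illustrative placeholders; the location URIs reuse the examples
    given for the parameters above):

    .. code-block:: python

        datasync_task = AWSDataSyncOperator(
            task_id="copy_smb_share_to_s3",
            source_location_uri="smb://server/subdir",
            destination_location_uri="s3://airflow_bucket/stuff",
            create_task_kwargs={"Name": "created_by_airflow"},
        )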
+ """ + + template_fields = ( + "task_arn", + "source_location_uri", + "destination_location_uri", + "create_task_kwargs", + "create_source_location_kwargs", + "create_destination_location_kwargs", + "update_task_kwargs", + "task_execution_kwargs", + ) + ui_color = "#44b5e2" + + @apply_defaults + def __init__( + self, + *, + aws_conn_id: str = "aws_default", + wait_interval_seconds: int = 5, + task_arn: Optional[str] = None, + source_location_uri: Optional[str] = None, + destination_location_uri: Optional[str] = None, + allow_random_task_choice: bool = False, + allow_random_location_choice: bool = False, + create_task_kwargs: Optional[dict] = None, + create_source_location_kwargs: Optional[dict] = None, + create_destination_location_kwargs: Optional[dict] = None, + update_task_kwargs: Optional[dict] = None, + task_execution_kwargs: Optional[dict] = None, + delete_task_after_execution: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + + # Assignments + self.aws_conn_id = aws_conn_id + self.wait_interval_seconds = wait_interval_seconds + + self.task_arn = task_arn + + self.source_location_uri = source_location_uri + self.destination_location_uri = destination_location_uri + self.allow_random_task_choice = allow_random_task_choice + self.allow_random_location_choice = allow_random_location_choice + + self.create_task_kwargs = create_task_kwargs if create_task_kwargs else {} + self.create_source_location_kwargs = {} + if create_source_location_kwargs: + self.create_source_location_kwargs = create_source_location_kwargs + self.create_destination_location_kwargs = {} + if create_destination_location_kwargs: + self.create_destination_location_kwargs = create_destination_location_kwargs + + self.update_task_kwargs = update_task_kwargs if update_task_kwargs else {} + self.task_execution_kwargs = ( + task_execution_kwargs if task_execution_kwargs else {} + ) + self.delete_task_after_execution = delete_task_after_execution + + # Validations + valid = False + if self.task_arn: + valid = True + if self.source_location_uri and self.destination_location_uri: + valid = True + if not valid: + raise AirflowException( + "Either specify task_arn or both source_location_uri and destination_location_uri. " + "task_arn={} source_location_uri={} destination_location_uri={}".format( + task_arn, source_location_uri, destination_location_uri + ) + ) + + # Others + self.hook: Optional[AWSDataSyncHook] = None + # Candidates - these are found in AWS as possible things + # for us to use + self.candidate_source_location_arns: Optional[List[str]] = None + self.candidate_destination_location_arns: Optional[List[str]] = None + self.candidate_task_arns: Optional[List[str]] = None + # Actuals + self.source_location_arn: Optional[str] = None + self.destination_location_arn: Optional[str] = None + self.task_execution_arn: Optional[str] = None + + def get_hook(self) -> AWSDataSyncHook: + """Create and return AWSDataSyncHook. + + :return AWSDataSyncHook: An AWSDataSyncHook instance. 
+ """ + if self.hook: + return self.hook + + self.hook = AWSDataSyncHook( + aws_conn_id=self.aws_conn_id, + wait_interval_seconds=self.wait_interval_seconds, + ) + return self.hook + + def execute(self, context): + # If task_arn was not specified then try to + # find 0, 1 or many candidate DataSync Tasks to run + if not self.task_arn: + self._get_tasks_and_locations() + + # If some were found, identify which one to run + if self.candidate_task_arns: + self.task_arn = self.choose_task(self.candidate_task_arns) + + # If we could not find one then try to create one + if not self.task_arn and self.create_task_kwargs: + self._create_datasync_task() + + if not self.task_arn: + raise AirflowException( + "DataSync TaskArn could not be identified or created." + ) + + self.log.info("Using DataSync TaskArn %s", self.task_arn) + + # Update the DataSync Task + if self.update_task_kwargs: + self._update_datasync_task() + + # Execute the DataSync Task + self._execute_datasync_task() + + if not self.task_execution_arn: + raise AirflowException("Nothing was executed") + + # Delete the DataSyncTask + if self.delete_task_after_execution: + self._delete_datasync_task() + + return {"TaskArn": self.task_arn, "TaskExecutionArn": self.task_execution_arn} + + def _get_tasks_and_locations(self) -> None: + """Find existing DataSync Task based on source and dest Locations.""" + hook = self.get_hook() + + self.candidate_source_location_arns = self._get_location_arns( + self.source_location_uri + ) + + self.candidate_destination_location_arns = self._get_location_arns( + self.destination_location_uri + ) + + if not self.candidate_source_location_arns: + self.log.info("No matching source Locations") + return + + if not self.candidate_destination_location_arns: + self.log.info("No matching destination Locations") + return + + self.log.info("Finding DataSync TaskArns that have these LocationArns") + self.candidate_task_arns = hook.get_task_arns_for_location_arns( + self.candidate_source_location_arns, + self.candidate_destination_location_arns, + ) + self.log.info("Found candidate DataSync TaskArns %s", self.candidate_task_arns) + + def choose_task(self, task_arn_list: list) -> Optional[str]: + """Select 1 DataSync TaskArn from a list""" + if not task_arn_list: + return None + if len(task_arn_list) == 1: + return task_arn_list[0] + if self.allow_random_task_choice: + # Items are unordered so we don't want to just take + # the [0] one as it implies ordered items were received + # from AWS and might lead to confusion. Rather explicitly + # choose a random one + return random.choice(task_arn_list) + raise AirflowException(f"Unable to choose a Task from {task_arn_list}") + + def choose_location(self, location_arn_list: Optional[List[str]]) -> Optional[str]: + """Select 1 DataSync LocationArn from a list""" + if not location_arn_list: + return None + if len(location_arn_list) == 1: + return location_arn_list[0] + if self.allow_random_location_choice: + # Items are unordered so we don't want to just take + # the [0] one as it implies ordered items were received + # from AWS and might lead to confusion. 
Rather explicitly + # choose a random one + return random.choice(location_arn_list) + raise AirflowException(f"Unable to choose a Location from {location_arn_list}") + + def _create_datasync_task(self) -> None: + """Create a AWS DataSyncTask.""" + hook = self.get_hook() + + self.source_location_arn = self.choose_location( + self.candidate_source_location_arns + ) + if ( + not self.source_location_arn + and self.source_location_uri + and self.create_source_location_kwargs + ): + self.log.info("Attempting to create source Location") + self.source_location_arn = hook.create_location( + self.source_location_uri, **self.create_source_location_kwargs + ) + if not self.source_location_arn: + raise AirflowException( + "Unable to determine source LocationArn. Does a suitable DataSync Location exist?" + ) + + self.destination_location_arn = self.choose_location( + self.candidate_destination_location_arns + ) + if ( + not self.destination_location_arn + and self.destination_location_uri + and self.create_destination_location_kwargs + ): + self.log.info("Attempting to create destination Location") + self.destination_location_arn = hook.create_location( + self.destination_location_uri, **self.create_destination_location_kwargs + ) + if not self.destination_location_arn: + raise AirflowException( + "Unable to determine destination LocationArn. Does a suitable DataSync Location exist?" + ) + + self.log.info("Creating a Task.") + self.task_arn = hook.create_task( + self.source_location_arn, + self.destination_location_arn, + **self.create_task_kwargs, + ) + if not self.task_arn: + raise AirflowException("Task could not be created") + self.log.info("Created a Task with TaskArn %s", self.task_arn) + + def _update_datasync_task(self) -> None: + """Update a AWS DataSyncTask.""" + if not self.task_arn: + return + + hook = self.get_hook() + self.log.info("Updating TaskArn %s", self.task_arn) + hook.update_task(self.task_arn, **self.update_task_kwargs) + self.log.info("Updated TaskArn %s", self.task_arn) + + def _execute_datasync_task(self) -> None: + """Create and monitor an AWSDataSync TaskExecution for a Task.""" + if not self.task_arn: + raise AirflowException("Missing TaskArn") + + hook = self.get_hook() + + # Create a task execution: + self.log.info("Starting execution for TaskArn %s", self.task_arn) + self.task_execution_arn = hook.start_task_execution( + self.task_arn, **self.task_execution_kwargs + ) + self.log.info("Started TaskExecutionArn %s", self.task_execution_arn) + + # Wait for task execution to complete + self.log.info("Waiting for TaskExecutionArn %s", self.task_execution_arn) + result = hook.wait_for_task_execution(self.task_execution_arn) + self.log.info("Completed TaskExecutionArn %s", self.task_execution_arn) + task_execution_description = hook.describe_task_execution( + task_execution_arn=self.task_execution_arn + ) + self.log.info("task_execution_description=%s", task_execution_description) + + # Log some meaningful statuses + level = logging.ERROR if not result else logging.INFO + self.log.log(level, "Status=%s", task_execution_description["Status"]) + if "Result" in task_execution_description: + for k, v in task_execution_description["Result"].items(): + if "Status" in k or "Error" in k: + self.log.log(level, "%s=%s", k, v) + + if not result: + raise AirflowException(f"Failed TaskExecutionArn {self.task_execution_arn}") + + def on_kill(self) -> None: + """Cancel the submitted DataSync task.""" + hook = self.get_hook() + if self.task_execution_arn: + self.log.info("Cancelling 
TaskExecutionArn %s", self.task_execution_arn) + hook.cancel_task_execution(task_execution_arn=self.task_execution_arn) + self.log.info("Cancelled TaskExecutionArn %s", self.task_execution_arn) + + def _delete_datasync_task(self) -> None: + """Deletes an AWS DataSync Task.""" + if not self.task_arn: + return + + hook = self.get_hook() + # Delete task: + self.log.info("Deleting Task with TaskArn %s", self.task_arn) + hook.delete_task(self.task_arn) + self.log.info("Task Deleted") + + def _get_location_arns(self, location_uri) -> List[str]: + location_arns = self.get_hook().get_location_arns(location_uri) + self.log.info( + "Found LocationArns %s for LocationUri %s", location_arns, location_uri + ) + return location_arns diff --git a/reference/providers/amazon/aws/operators/ec2_start_instance.py b/reference/providers/amazon/aws/operators/ec2_start_instance.py new file mode 100644 index 0000000..157fdd8 --- /dev/null +++ b/reference/providers/amazon/aws/operators/ec2_start_instance.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook +from airflow.utils.decorators import apply_defaults + + +class EC2StartInstanceOperator(BaseOperator): + """ + Start AWS EC2 instance using boto3. 
+ + :param instance_id: id of the AWS EC2 instance + :type instance_id: str + :param aws_conn_id: aws connection to use + :type aws_conn_id: str + :param region_name: (optional) aws region name associated with the client + :type region_name: Optional[str] + :param check_interval: time in seconds that the job should wait in + between each instance state checks until operation is completed + :type check_interval: float + """ + + template_fields = ("instance_id", "region_name") + ui_color = "#eeaa11" + ui_fgcolor = "#ffffff" + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + check_interval: float = 15, + **kwargs, + ): + super().__init__(**kwargs) + self.instance_id = instance_id + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.check_interval = check_interval + + def execute(self, context): + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + self.log.info("Starting EC2 instance %s", self.instance_id) + instance = ec2_hook.get_instance(instance_id=self.instance_id) + instance.start() + ec2_hook.wait_for_state( + instance_id=self.instance_id, + target_state="running", + check_interval=self.check_interval, + ) diff --git a/reference/providers/amazon/aws/operators/ec2_stop_instance.py b/reference/providers/amazon/aws/operators/ec2_stop_instance.py new file mode 100644 index 0000000..fba4652 --- /dev/null +++ b/reference/providers/amazon/aws/operators/ec2_stop_instance.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook +from airflow.utils.decorators import apply_defaults + + +class EC2StopInstanceOperator(BaseOperator): + """ + Stop AWS EC2 instance using boto3. 
+ + :param instance_id: id of the AWS EC2 instance + :type instance_id: str + :param aws_conn_id: aws connection to use + :type aws_conn_id: str + :param region_name: (optional) aws region name associated with the client + :type region_name: Optional[str] + :param check_interval: time in seconds that the job should wait in + between each instance state checks until operation is completed + :type check_interval: float + """ + + template_fields = ("instance_id", "region_name") + ui_color = "#eeaa11" + ui_fgcolor = "#ffffff" + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + check_interval: float = 15, + **kwargs, + ): + super().__init__(**kwargs) + self.instance_id = instance_id + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.check_interval = check_interval + + def execute(self, context): + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + self.log.info("Stopping EC2 instance %s", self.instance_id) + instance = ec2_hook.get_instance(instance_id=self.instance_id) + instance.stop() + ec2_hook.wait_for_state( + instance_id=self.instance_id, + target_state="stopped", + check_interval=self.check_interval, + ) diff --git a/reference/providers/amazon/aws/operators/ecs.py b/reference/providers/amazon/aws/operators/ecs.py new file mode 100644 index 0000000..3127264 --- /dev/null +++ b/reference/providers/amazon/aws/operators/ecs.py @@ -0,0 +1,367 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import re +import sys +from collections import deque +from datetime import datetime +from typing import Dict, Generator, Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook +from airflow.typing_compat import Protocol, runtime_checkable +from airflow.utils.decorators import apply_defaults +from botocore.waiter import Waiter + + +@runtime_checkable +class ECSProtocol(Protocol): + """ + A structured Protocol for ``boto3.client('ecs')``. This is used for type hints on + :py:meth:`.ECSOperator.client`. + + .. seealso:: + + - https://mypy.readthedocs.io/en/latest/protocols.html + - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html + """ + + # pylint: disable=C0103, line-too-long + def run_task(self, **kwargs) -> Dict: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task""" # noqa: E501 + ... 
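    # NOTE (editor): the stub methods in this Protocol deliberately have `...`
    # bodies. They exist only so mypy can type-check calls made through the
    # dynamically created ``boto3.client("ecs")``; no behaviour is implemented here.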
+ + def get_waiter(self, x: str) -> Waiter: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.get_waiter""" # noqa: E501 + ... + + def describe_tasks(self, cluster: str, tasks) -> Dict: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_tasks""" # noqa: E501 + ... + + def stop_task(self, cluster, task, reason: str) -> Dict: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.stop_task""" # noqa: E501 + ... + + def describe_task_definition(self, taskDefinition: str) -> Dict: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_task_definition""" # noqa: E501 + ... + + def list_tasks( + self, cluster: str, launchType: str, desiredStatus: str, family: str + ) -> Dict: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.list_tasks""" # noqa: E501 + ... + + # pylint: enable=C0103, line-too-long + + +class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes + """ + Execute a task on AWS ECS (Elastic Container Service) + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ECSOperator` + + :param task_definition: the task definition name on Elastic Container Service + :type task_definition: str + :param cluster: the cluster name on Elastic Container Service + :type cluster: str + :param overrides: the same parameter that boto3 will receive (templated): + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task + :type overrides: dict + :param aws_conn_id: connection id of AWS credentials / region name. If None, + credential boto3 strategy will be used + (http://boto3.readthedocs.io/en/latest/guide/configuration.html). + :type aws_conn_id: str + :param region_name: region name to use in AWS Hook. + Override the region_name in connection (if provided) + :type region_name: str + :param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE') + :type launch_type: str + :param group: the name of the task group associated with the task + :type group: str + :param placement_constraints: an array of placement constraint objects to use for + the task + :type placement_constraints: list + :param placement_strategy: an array of placement strategy objects to use for + the task + :type placement_strategy: list + :param platform_version: the platform version on which your task is running + :type platform_version: str + :param network_configuration: the network configuration for the task + :type network_configuration: dict + :param tags: a dictionary of tags in the form of {'tagKey': 'tagValue'}. + :type tags: dict + :param awslogs_group: the CloudWatch group where your ECS container logs are stored. + Only required if you want logs to be shown in the Airflow UI after your job has + finished. + :type awslogs_group: str + :param awslogs_region: the region in which your CloudWatch logs are stored. + If None, this is the same as the `region_name` parameter. If that is also None, + this is the default AWS region based on your connection settings. + :type awslogs_region: str + :param awslogs_stream_prefix: the stream prefix that is used for the CloudWatch logs. + This is usually based on some custom name combined with the name of the container. 
+ Only required if you want logs to be shown in the Airflow UI after your job has + finished. + :type awslogs_stream_prefix: str + :param reattach: If set to True, will check if a task from the same family is already running. + If so, the operator will attach to it instead of starting a new task. + :type reattach: bool + """ + + ui_color = "#f0ede4" + template_fields = ("overrides",) + + @apply_defaults + def __init__( + self, + *, + task_definition: str, + cluster: str, + overrides: dict, # pylint: disable=too-many-arguments + aws_conn_id: Optional[str] = None, + region_name: Optional[str] = None, + launch_type: str = "EC2", + group: Optional[str] = None, + placement_constraints: Optional[list] = None, + placement_strategy: Optional[list] = None, + platform_version: str = "LATEST", + network_configuration: Optional[dict] = None, + tags: Optional[dict] = None, + awslogs_group: Optional[str] = None, + awslogs_region: Optional[str] = None, + awslogs_stream_prefix: Optional[str] = None, + propagate_tags: Optional[str] = None, + reattach: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.task_definition = task_definition + self.cluster = cluster + self.overrides = overrides + self.launch_type = launch_type + self.group = group + self.placement_constraints = placement_constraints + self.placement_strategy = placement_strategy + self.platform_version = platform_version + self.network_configuration = network_configuration + + self.tags = tags + self.awslogs_group = awslogs_group + self.awslogs_stream_prefix = awslogs_stream_prefix + self.awslogs_region = awslogs_region + self.propagate_tags = propagate_tags + self.reattach = reattach + + if self.awslogs_region is None: + self.awslogs_region = region_name + + self.hook: Optional[AwsBaseHook] = None + self.client: Optional[ECSProtocol] = None + self.arn: Optional[str] = None + + def execute(self, context): + self.log.info( + "Running ECS Task - Task definition: %s - on cluster %s", + self.task_definition, + self.cluster, + ) + self.log.info("ECSOperator overrides: %s", self.overrides) + + self.client = self.get_hook().get_conn() + + if self.reattach: + self._try_reattach_task() + + if not self.arn: + self._start_task() + + self._wait_for_task_ended() + + self._check_success_task() + + self.log.info("ECS Task has been successfully executed") + + if self.do_xcom_push: + return self._last_log_message() + + return None + + def _start_task(self): + run_opts = { + "cluster": self.cluster, + "taskDefinition": self.task_definition, + "overrides": self.overrides, + "startedBy": self.owner, + } + + if self.launch_type: + run_opts["launchType"] = self.launch_type + if self.launch_type == "FARGATE": + run_opts["platformVersion"] = self.platform_version + if self.group is not None: + run_opts["group"] = self.group + if self.placement_constraints is not None: + run_opts["placementConstraints"] = self.placement_constraints + if self.placement_strategy is not None: + run_opts["placementStrategy"] = self.placement_strategy + if self.network_configuration is not None: + run_opts["networkConfiguration"] = self.network_configuration + if self.tags is not None: + run_opts["tags"] = [{"key": k, "value": v} for (k, v) in self.tags.items()] + if self.propagate_tags is not None: + run_opts["propagateTags"] = self.propagate_tags + + response = self.client.run_task(**run_opts) + + failures = response["failures"] + if len(failures) > 0: + raise AirflowException(response) + self.log.info("ECS 
Task started: %s", response) + + self.arn = response["tasks"][0]["taskArn"] + + def _try_reattach_task(self): + task_def_resp = self.client.describe_task_definition(self.task_definition) + ecs_task_family = task_def_resp["taskDefinition"]["family"] + + list_tasks_resp = self.client.list_tasks( + cluster=self.cluster, + launchType=self.launch_type, + desiredStatus="RUNNING", + family=ecs_task_family, + ) + running_tasks = list_tasks_resp["taskArns"] + + running_tasks_count = len(running_tasks) + if running_tasks_count > 1: + self.arn = running_tasks[0] + self.log.warning("More than 1 ECS Task found. Reattaching to %s", self.arn) + elif running_tasks_count == 1: + self.arn = running_tasks[0] + self.log.info("Reattaching task: %s", self.arn) + else: + self.log.info("No active tasks found to reattach") + + def _wait_for_task_ended(self) -> None: + if not self.client or not self.arn: + return + + waiter = self.client.get_waiter("tasks_stopped") + waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow + waiter.wait(cluster=self.cluster, tasks=[self.arn]) + + return + + def _cloudwatch_log_events(self) -> Generator: + if self._aws_logs_enabled(): + task_id = self.arn.split("/")[-1] + stream_name = f"{self.awslogs_stream_prefix}/{task_id}" + yield from self.get_logs_hook().get_log_events( + self.awslogs_group, stream_name + ) + else: + yield from () + + def _aws_logs_enabled(self): + return self.awslogs_group and self.awslogs_stream_prefix + + def _last_log_message(self): + try: + return deque(self._cloudwatch_log_events(), maxlen=1).pop()["message"] + except IndexError: + return None + + def _check_success_task(self) -> None: + if not self.client or not self.arn: + return + + response = self.client.describe_tasks(cluster=self.cluster, tasks=[self.arn]) + self.log.info("ECS Task stopped, check status: %s", response) + + # Get logs from CloudWatch if the awslogs log driver was used + for event in self._cloudwatch_log_events(): + event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0) + self.log.info("[%s] %s", event_dt.isoformat(), event["message"]) + + if len(response.get("failures", [])) > 0: + raise AirflowException(response) + + for task in response["tasks"]: + # This is a `stoppedReason` that indicates a task has not + # successfully finished, but there is no other indication of failure + # in the response. 
+ # https://docs.aws.amazon.com/AmazonECS/latest/developerguide/stopped-task-errors.html + if re.match( + r"Host EC2 \(instance .+?\) (stopped|terminated)\.", + task.get("stoppedReason", ""), + ): + raise AirflowException( + "The task was stopped because the host instance terminated: {}".format( + task.get("stoppedReason", "") + ) + ) + containers = task["containers"] + for container in containers: + if ( + container.get("lastStatus") == "STOPPED" + and container["exitCode"] != 0 + ): + raise AirflowException(f"This task is not in success state {task}") + elif container.get("lastStatus") == "PENDING": + raise AirflowException(f"This task is still pending {task}") + elif "error" in container.get("reason", "").lower(): + raise AirflowException( + "This containers encounter an error during launching : {}".format( + container.get("reason", "").lower() + ) + ) + + def get_hook(self) -> AwsBaseHook: + """Create and return an AwsHook.""" + if self.hook: + return self.hook + + self.hook = AwsBaseHook( + aws_conn_id=self.aws_conn_id, + client_type="ecs", + region_name=self.region_name, + ) + return self.hook + + def get_logs_hook(self) -> AwsLogsHook: + """Create and return an AwsLogsHook.""" + return AwsLogsHook( + aws_conn_id=self.aws_conn_id, region_name=self.awslogs_region + ) + + def on_kill(self) -> None: + if not self.client or not self.arn: + return + + response = self.client.stop_task( + cluster=self.cluster, task=self.arn, reason="Task killed by the user" + ) + self.log.info(response) diff --git a/reference/providers/amazon/aws/operators/emr_add_steps.py b/reference/providers/amazon/aws/operators/emr_add_steps.py new file mode 100644 index 0000000..54ed73c --- /dev/null +++ b/reference/providers/amazon/aws/operators/emr_add_steps.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import ast +from typing import Any, Dict, List, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.utils.decorators import apply_defaults + + +class EmrAddStepsOperator(BaseOperator): + """ + An operator that adds steps to an existing EMR job_flow. + + :param job_flow_id: id of the JobFlow to add steps to. (templated) + :type job_flow_id: Optional[str] + :param job_flow_name: name of the JobFlow to add steps to. Use as an alternative to passing + job_flow_id. will search for id of JobFlow with matching name in one of the states in + param cluster_states. Exactly one cluster like this should exist or will fail. (templated) + :type job_flow_name: Optional[str] + :param cluster_states: Acceptable cluster states when searching for JobFlow id by job_flow_name. 
+ (templated) + :type cluster_states: list + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + :param steps: boto3 style steps or reference to a steps file (must be '.json') to + be added to the jobflow. (templated) + :type steps: list|str + :param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id. + :type do_xcom_push: bool + """ + + template_fields = ["job_flow_id", "job_flow_name", "cluster_states", "steps"] + template_ext = (".json",) + ui_color = "#f9c915" + + @apply_defaults + def __init__( + self, + *, + job_flow_id: Optional[str] = None, + job_flow_name: Optional[str] = None, + cluster_states: Optional[List[str]] = None, + aws_conn_id: str = "aws_default", + steps: Optional[Union[List[dict], str]] = None, + **kwargs, + ): + if kwargs.get("xcom_push") is not None: + raise AirflowException( + "'xcom_push' was deprecated, use 'do_xcom_push' instead" + ) + if not (job_flow_id is None) ^ (job_flow_name is None): + raise AirflowException( + "Exactly one of job_flow_id or job_flow_name must be specified." + ) + super().__init__(**kwargs) + cluster_states = cluster_states or [] + steps = steps or [] + self.aws_conn_id = aws_conn_id + self.job_flow_id = job_flow_id + self.job_flow_name = job_flow_name + self.cluster_states = cluster_states + self.steps = steps + + def execute(self, context: Dict[str, Any]) -> List[str]: + emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) + + emr = emr_hook.get_conn() + + job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name( + str(self.job_flow_name), self.cluster_states + ) + + if not job_flow_id: + raise AirflowException(f"No cluster found for name: {self.job_flow_name}") + + if self.do_xcom_push: + context["ti"].xcom_push(key="job_flow_id", value=job_flow_id) + + self.log.info("Adding steps to %s", job_flow_id) + + # steps may arrive as a string representing a list + # e.g. if we used XCom or a file then: steps="[{ step1 }, { step2 }]" + steps = self.steps + if isinstance(steps, str): + steps = ast.literal_eval(steps) + + response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps) + + if not response["ResponseMetadata"]["HTTPStatusCode"] == 200: + raise AirflowException(f"Adding steps failed: {response}") + else: + self.log.info("Steps %s added to JobFlow", response["StepIds"]) + return response["StepIds"] diff --git a/reference/providers/amazon/aws/operators/emr_create_job_flow.py b/reference/providers/amazon/aws/operators/emr_create_job_flow.py new file mode 100644 index 0000000..7f144a5 --- /dev/null +++ b/reference/providers/amazon/aws/operators/emr_create_job_flow.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
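# ---------------------------------------------------------------------------
# Editor's note: the helper below is NOT part of the vendored provider code.
# It is a minimal, hypothetical sketch of how EmrCreateJobFlowOperator (this
# module) is typically chained with EmrAddStepsOperator (previous module) in
# a DAG. The DAG id, dates, job-flow overrides and the Spark step are
# placeholders; the connection defaults (aws_default / emr_default) are the
# ones shown in the operator signatures above.
# ---------------------------------------------------------------------------
def _editorial_example_emr_create_and_add_steps():
    """Build a throwaway DAG wiring cluster creation to step submission."""
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator
    from airflow.providers.amazon.aws.operators.emr_create_job_flow import (
        EmrCreateJobFlowOperator,
    )

    spark_step = [
        {
            "Name": "example_step",
            "ActionOnFailure": "CONTINUE",
            "HadoopJarStep": {
                "Jar": "command-runner.jar",
                "Args": ["spark-submit", "s3://example-bucket/job.py"],
            },
        }
    ]

    with DAG("emr_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
        create_cluster = EmrCreateJobFlowOperator(
            task_id="create_cluster",
            # Merged on top of the JSON stored in the emr_default connection extras.
            job_flow_overrides={"Name": "editorial-example-cluster"},
        )
        add_steps = EmrAddStepsOperator(
            task_id="add_steps",
            # EmrCreateJobFlowOperator returns the JobFlowId, so it can be pulled from XCom.
            job_flow_id="{{ task_instance.xcom_pull(task_ids='create_cluster', key='return_value') }}",
            steps=spark_step,
        )
        create_cluster >> add_steps

    return dag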
+import ast +from typing import Any, Dict, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.utils.decorators import apply_defaults + + +class EmrCreateJobFlowOperator(BaseOperator): + """ + Creates an EMR JobFlow, reading the config from the EMR connection. + A dictionary of JobFlow overrides can be passed that override + the config from the connection. + + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + :param emr_conn_id: emr connection to use + :type emr_conn_id: str + :param job_flow_overrides: boto3 style arguments or reference to an arguments file + (must be '.json') to override emr_connection extra. (templated) + :type job_flow_overrides: dict|str + """ + + template_fields = ["job_flow_overrides"] + template_ext = (".json",) + ui_color = "#f9c915" + + @apply_defaults + def __init__( + self, + *, + aws_conn_id: str = "aws_default", + emr_conn_id: str = "emr_default", + job_flow_overrides: Optional[Union[str, Dict[str, Any]]] = None, + region_name: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.emr_conn_id = emr_conn_id + if job_flow_overrides is None: + job_flow_overrides = {} + self.job_flow_overrides = job_flow_overrides + self.region_name = region_name + + def execute(self, context: Dict[str, Any]) -> str: + emr = EmrHook( + aws_conn_id=self.aws_conn_id, + emr_conn_id=self.emr_conn_id, + region_name=self.region_name, + ) + + self.log.info( + "Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s", + self.aws_conn_id, + self.emr_conn_id, + ) + + if isinstance(self.job_flow_overrides, str): + job_flow_overrides: Dict[str, Any] = ast.literal_eval( + self.job_flow_overrides + ) + self.job_flow_overrides = job_flow_overrides + else: + job_flow_overrides = self.job_flow_overrides + response = emr.create_job_flow(job_flow_overrides) + + if not response["ResponseMetadata"]["HTTPStatusCode"] == 200: + raise AirflowException(f"JobFlow creation failed: {response}") + else: + self.log.info("JobFlow with id %s created", response["JobFlowId"]) + return response["JobFlowId"] diff --git a/reference/providers/amazon/aws/operators/emr_modify_cluster.py b/reference/providers/amazon/aws/operators/emr_modify_cluster.py new file mode 100644 index 0000000..a96386f --- /dev/null +++ b/reference/providers/amazon/aws/operators/emr_modify_cluster.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
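# ---------------------------------------------------------------------------
# Editor's note: the snippet below is NOT part of the vendored provider code.
# It is a hypothetical sketch showing how EmrModifyClusterOperator (defined in
# this module) can raise the step concurrency of an existing cluster; the DAG
# id, dates and cluster id are placeholders.
# ---------------------------------------------------------------------------
def _editorial_example_emr_modify_cluster():
    """Build a throwaway DAG that bumps StepConcurrencyLevel on a cluster."""
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.emr_modify_cluster import (
        EmrModifyClusterOperator,
    )

    with DAG("emr_modify_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
        EmrModifyClusterOperator(
            task_id="increase_step_concurrency",
            cluster_id="j-EXAMPLE123",  # placeholder cluster id
            step_concurrency_level=5,
        )

    return dag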
+ +from typing import Any, Dict + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.utils.decorators import apply_defaults + + +class EmrModifyClusterOperator(BaseOperator): + """ + An operator that modifies an existing EMR cluster. + :param cluster_id: cluster identifier + :type cluster_id: str + :param step_concurrency_level: Concurrency of the cluster + :type step_concurrency_level: int + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + :param do_xcom_push: if True, cluster_id is pushed to XCom with key cluster_id. + :type do_xcom_push: bool + """ + + template_fields = ["cluster_id", "step_concurrency_level"] + template_ext = () + ui_color = "#f9c915" + + @apply_defaults + def __init__( + self, + *, + cluster_id: str, + step_concurrency_level: int, + aws_conn_id: str = "aws_default", + **kwargs, + ): + if kwargs.get("xcom_push") is not None: + raise AirflowException( + "'xcom_push' was deprecated, use 'do_xcom_push' instead" + ) + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.cluster_id = cluster_id + self.step_concurrency_level = step_concurrency_level + + def execute(self, context: Dict[str, Any]) -> int: + emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) + + emr = emr_hook.get_conn() + + if self.do_xcom_push: + context["ti"].xcom_push(key="cluster_id", value=self.cluster_id) + + self.log.info("Modifying cluster %s", self.cluster_id) + response = emr.modify_cluster( + ClusterId=self.cluster_id, StepConcurrencyLevel=self.step_concurrency_level + ) + + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Modify cluster failed: {response}") + else: + self.log.info( + "Steps concurrency level %d", response["StepConcurrencyLevel"] + ) + return response["StepConcurrencyLevel"] diff --git a/reference/providers/amazon/aws/operators/emr_terminate_job_flow.py b/reference/providers/amazon/aws/operators/emr_terminate_job_flow.py new file mode 100644 index 0000000..5bd4b7a --- /dev/null +++ b/reference/providers/amazon/aws/operators/emr_terminate_job_flow.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.utils.decorators import apply_defaults + + +class EmrTerminateJobFlowOperator(BaseOperator): + """ + Operator to terminate EMR JobFlows. + + :param job_flow_id: id of the JobFlow to terminate. 
(templated) + :type job_flow_id: str + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + """ + + template_fields = ["job_flow_id"] + template_ext = () + ui_color = "#f9c915" + + @apply_defaults + def __init__(self, *, job_flow_id: str, aws_conn_id: str = "aws_default", **kwargs): + super().__init__(**kwargs) + self.job_flow_id = job_flow_id + self.aws_conn_id = aws_conn_id + + def execute(self, context: Dict[str, Any]) -> None: + emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn() + + self.log.info("Terminating JobFlow %s", self.job_flow_id) + response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id]) + + if not response["ResponseMetadata"]["HTTPStatusCode"] == 200: + raise AirflowException(f"JobFlow termination failed: {response}") + else: + self.log.info("JobFlow with id %s terminated", self.job_flow_id) diff --git a/reference/providers/amazon/aws/operators/glacier.py b/reference/providers/amazon/aws/operators/glacier.py new file mode 100644 index 0000000..c88d652 --- /dev/null +++ b/reference/providers/amazon/aws/operators/glacier.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.glacier import GlacierHook +from airflow.utils.decorators import apply_defaults + + +class GlacierCreateJobOperator(BaseOperator): + """ + Initiate an Amazon Glacier inventory-retrieval job + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GlacierCreateJobOperator` + + :param aws_conn_id: The reference to the AWS connection details + :type aws_conn_id: str + :param vault_name: the Glacier vault on which job is executed + :type vault_name: str + """ + + template_fields = ("vault_name",) + + @apply_defaults + def __init__( + self, + *, + aws_conn_id="aws_default", + vault_name: str, + **kwargs, + ): + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.vault_name = vault_name + + def execute(self, context): + hook = GlacierHook(aws_conn_id=self.aws_conn_id) + response = hook.retrieve_inventory(vault_name=self.vault_name) + return response diff --git a/reference/providers/amazon/aws/operators/glue.py b/reference/providers/amazon/aws/operators/glue.py new file mode 100644 index 0000000..541871d --- /dev/null +++ b/reference/providers/amazon/aws/operators/glue.py @@ -0,0 +1,134 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os.path +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.glue import AwsGlueJobHook +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.utils.decorators import apply_defaults + + +class AwsGlueJobOperator(BaseOperator): + """ + Creates an AWS Glue Job. AWS Glue is a serverless Spark + ETL service for running Spark Jobs on the AWS cloud. + Language support: Python and Scala + + :param job_name: unique job name per AWS Account + :type job_name: Optional[str] + :param script_location: location of ETL script. Must be a local or S3 path + :type script_location: Optional[str] + :param job_desc: job description details + :type job_desc: Optional[str] + :param concurrent_run_limit: The maximum number of concurrent runs allowed for a job + :type concurrent_run_limit: Optional[int] + :param script_args: etl script arguments and AWS Glue arguments (templated) + :type script_args: dict + :param retry_limit: The maximum number of times to retry this job if it fails + :type retry_limit: Optional[int] + :param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job. + :type num_of_dpus: int + :param region_name: aws region name (example: us-east-1) + :type region_name: str + :param s3_bucket: S3 bucket where logs and local etl script will be uploaded + :type s3_bucket: Optional[str] + :param iam_role_name: AWS IAM Role for Glue Job Execution + :type iam_role_name: Optional[str] + :param create_job_kwargs: Extra arguments for Glue Job Creation + :type create_job_kwargs: Optional[dict] + """ + + template_fields = ("script_args",) + template_ext = () + ui_color = "#ededed" + + @apply_defaults + def __init__( + self, + *, + job_name: str = "aws_glue_default_job", + job_desc: str = "AWS Glue Job with Airflow", + script_location: Optional[str] = None, + concurrent_run_limit: Optional[int] = None, + script_args: Optional[dict] = None, + retry_limit: Optional[int] = None, + num_of_dpus: int = 6, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + s3_bucket: Optional[str] = None, + iam_role_name: Optional[str] = None, + create_job_kwargs: Optional[dict] = None, + **kwargs, + ): # pylint: disable=too-many-arguments + super().__init__(**kwargs) + self.job_name = job_name + self.job_desc = job_desc + self.script_location = script_location + self.concurrent_run_limit = concurrent_run_limit or 1 + self.script_args = script_args or {} + self.retry_limit = retry_limit + self.num_of_dpus = num_of_dpus + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.s3_bucket = s3_bucket + self.iam_role_name = iam_role_name + self.s3_protocol = "s3://" + self.s3_artifacts_prefix = "artifacts/glue-scripts/" + self.create_job_kwargs = create_job_kwargs + + def execute(self, context): + """ + Executes AWS Glue Job from Airflow + + :return: the id of the current glue job. 
+ """ + if self.script_location and not self.script_location.startswith( + self.s3_protocol + ): + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + script_name = os.path.basename(self.script_location) + s3_hook.load_file( + self.script_location, + self.s3_bucket, + self.s3_artifacts_prefix + script_name, + ) + glue_job = AwsGlueJobHook( + job_name=self.job_name, + desc=self.job_desc, + concurrent_run_limit=self.concurrent_run_limit, + script_location=self.script_location, + retry_limit=self.retry_limit, + num_of_dpus=self.num_of_dpus, + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + s3_bucket=self.s3_bucket, + iam_role_name=self.iam_role_name, + create_job_kwargs=self.create_job_kwargs, + ) + self.log.info("Initializing AWS Glue Job: %s", self.job_name) + glue_job_run = glue_job.initialize_job(self.script_args) + glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"]) + self.log.info( + "AWS Glue Job: %s status: %s. Run Id: %s", + self.job_name, + glue_job_run["JobRunState"], + glue_job_run["JobRunId"], + ) + return glue_job_run["JobRunId"] diff --git a/reference/providers/amazon/aws/operators/glue_crawler.py b/reference/providers/amazon/aws/operators/glue_crawler.py new file mode 100644 index 0000000..0d1f42c --- /dev/null +++ b/reference/providers/amazon/aws/operators/glue_crawler.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.glue_crawler import AwsGlueCrawlerHook +from airflow.utils.decorators import apply_defaults + + +class AwsGlueCrawlerOperator(BaseOperator): + """ + Creates, updates and triggers an AWS Glue Crawler. AWS Glue Crawler is a serverless + service that manages a catalog of metadata tables that contain the inferred + schema, format and data types of data stores within the AWS cloud. 
+ + :param config: Configurations for the AWS Glue crawler + :type config: dict + :param aws_conn_id: aws connection to use + :type aws_conn_id: Optional[str] + :param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status + :type poll_interval: Optional[int] + """ + + ui_color = "#ededed" + + @apply_defaults + def __init__( + self, + config, + aws_conn_id="aws_default", + poll_interval: int = 5, + **kwargs, + ): + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.poll_interval = poll_interval + self.config = config + + @cached_property + def hook(self) -> AwsGlueCrawlerHook: + """Create and return an AwsGlueCrawlerHook.""" + return AwsGlueCrawlerHook(self.aws_conn_id) + + def execute(self, context): + """ + Executes AWS Glue Crawler from Airflow + + :return: the name of the current glue crawler. + """ + crawler_name = self.config["Name"] + if self.hook.has_crawler(crawler_name): + self.hook.update_crawler(**self.config) + else: + self.hook.create_crawler(**self.config) + + self.log.info("Triggering AWS Glue Crawler") + self.hook.start_crawler(crawler_name) + self.log.info("Waiting for AWS Glue Crawler") + self.hook.wait_for_crawler_completion( + crawler_name=crawler_name, poll_interval=self.poll_interval + ) + + return crawler_name diff --git a/reference/providers/amazon/aws/operators/s3_bucket.py b/reference/providers/amazon/aws/operators/s3_bucket.py new file mode 100644 index 0000000..2dc9387 --- /dev/null +++ b/reference/providers/amazon/aws/operators/s3_bucket.py @@ -0,0 +1,116 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains AWS S3 operators.""" +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.utils.decorators import apply_defaults + + +class S3CreateBucketOperator(BaseOperator): + """ + This operator creates an S3 bucket + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:S3CreateBucketOperator` + + :param bucket_name: This is bucket name you want to create + :type bucket_name: str + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is None or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :type aws_conn_id: Optional[str] + :param region_name: AWS region_name. If not specified fetched from connection. 
+ :type region_name: Optional[str] + """ + + template_fields = ("bucket_name",) + + @apply_defaults + def __init__( + self, + *, + bucket_name: str, + aws_conn_id: Optional[str] = "aws_default", + region_name: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.bucket_name = bucket_name + self.region_name = region_name + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + def execute(self, context): + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + if not s3_hook.check_for_bucket(self.bucket_name): + s3_hook.create_bucket( + bucket_name=self.bucket_name, region_name=self.region_name + ) + self.log.info("Created bucket with name: %s", self.bucket_name) + else: + self.log.info("Bucket with name: %s already exists", self.bucket_name) + + +class S3DeleteBucketOperator(BaseOperator): + """ + This operator deletes an S3 bucket + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:S3DeleteBucketOperator` + + :param bucket_name: This is bucket name you want to delete + :type bucket_name: str + :param force_delete: Forcibly delete all objects in the bucket before deleting the bucket + :type force_delete: bool + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is None or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :type aws_conn_id: Optional[str] + """ + + template_fields = ("bucket_name",) + + def __init__( + self, + bucket_name: str, + force_delete: bool = False, + aws_conn_id: Optional[str] = "aws_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.bucket_name = bucket_name + self.force_delete = force_delete + self.aws_conn_id = aws_conn_id + + def execute(self, context): + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + if s3_hook.check_for_bucket(self.bucket_name): + s3_hook.delete_bucket( + bucket_name=self.bucket_name, force_delete=self.force_delete + ) + self.log.info("Deleted bucket with name: %s", self.bucket_name) + else: + self.log.info("Bucket with name: %s doesn't exist", self.bucket_name) diff --git a/reference/providers/amazon/aws/operators/s3_bucket_tagging.py b/reference/providers/amazon/aws/operators/s3_bucket_tagging.py new file mode 100644 index 0000000..81f5481 --- /dev/null +++ b/reference/providers/amazon/aws/operators/s3_bucket_tagging.py @@ -0,0 +1,159 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
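# ---------------------------------------------------------------------------
# Editor's note: the snippet below is NOT part of the vendored provider code.
# It is a hypothetical sketch of the bucket lifecycle operators from the
# preceding module (s3_bucket.py): create a bucket, then force-delete it.
# The bucket name, DAG id, dates and region are placeholders.
# ---------------------------------------------------------------------------
def _editorial_example_s3_bucket_lifecycle():
    """Build a throwaway DAG that creates and then deletes an S3 bucket."""
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.s3_bucket import (
        S3CreateBucketOperator,
        S3DeleteBucketOperator,
    )

    with DAG("s3_bucket_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
        create = S3CreateBucketOperator(
            task_id="create_bucket",
            bucket_name="editorial-example-bucket",
            region_name="us-east-1",
        )
        delete = S3DeleteBucketOperator(
            task_id="delete_bucket",
            bucket_name="editorial-example-bucket",
            force_delete=True,  # remove any remaining objects before deleting the bucket
        )
        create >> delete

    return dag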
+"""This module contains AWS S3 operators.""" +from typing import Dict, List, Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook + +BUCKET_DOES_NOT_EXIST_MSG = "Bucket with name: %s doesn't exist" + + +class S3GetBucketTaggingOperator(BaseOperator): + """ + This operator gets tagging from an S3 bucket + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:S3GetBucketTaggingOperator` + + :param bucket_name: This is bucket name you want to reference + :type bucket_name: str + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is None or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :type aws_conn_id: Optional[str] + """ + + template_fields = ("bucket_name",) + + def __init__( + self, bucket_name: str, aws_conn_id: Optional[str] = "aws_default", **kwargs + ) -> None: + super().__init__(**kwargs) + self.bucket_name = bucket_name + self.aws_conn_id = aws_conn_id + + def execute(self, context): + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + + if s3_hook.check_for_bucket(self.bucket_name): + self.log.info("Getting tags for bucket %s", self.bucket_name) + return s3_hook.get_bucket_tagging(self.bucket_name) + else: + self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) + return None + + +class S3PutBucketTaggingOperator(BaseOperator): + """ + This operator puts tagging for an S3 bucket. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:S3PutBucketTaggingOperator` + + :param bucket_name: The name of the bucket to add tags to. + :type bucket_name: str + :param key: The key portion of the key/value pair for a tag to be added. + If a key is provided, a value must be provided as well. + :type key: str + :param value: The value portion of the key/value pair for a tag to be added. + If a value is provided, a key must be provided as well. + :param tag_set: A List of key/value pairs. + :type tag_set: List[Dict[str, str]] + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is None or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then the default boto3 configuration would be used (and must be + maintained on each worker node). 
+ :type aws_conn_id: Optional[str] + """ + + template_fields = ("bucket_name",) + + def __init__( + self, + bucket_name: str, + key: Optional[str] = None, + value: Optional[str] = None, + tag_set: Optional[List[Dict[str, str]]] = None, + aws_conn_id: Optional[str] = "aws_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.key = key + self.value = value + self.tag_set = tag_set + self.bucket_name = bucket_name + self.aws_conn_id = aws_conn_id + + def execute(self, context): + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + + if s3_hook.check_for_bucket(self.bucket_name): + self.log.info("Putting tags for bucket %s", self.bucket_name) + return s3_hook.put_bucket_tagging( + key=self.key, + value=self.value, + tag_set=self.tag_set, + bucket_name=self.bucket_name, + ) + else: + self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) + return None + + +class S3DeleteBucketTaggingOperator(BaseOperator): + """ + This operator deletes tagging from an S3 bucket. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:S3DeleteBucketTaggingOperator` + + :param bucket_name: This is the name of the bucket to delete tags from. + :type bucket_name: str + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is None or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :type aws_conn_id: Optional[str] + """ + + template_fields = ("bucket_name",) + + def __init__( + self, bucket_name: str, aws_conn_id: Optional[str] = "aws_default", **kwargs + ) -> None: + super().__init__(**kwargs) + self.bucket_name = bucket_name + self.aws_conn_id = aws_conn_id + + def execute(self, context): + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + + if s3_hook.check_for_bucket(self.bucket_name): + self.log.info("Deleting tags for bucket %s", self.bucket_name) + return s3_hook.delete_bucket_tagging(self.bucket_name) + else: + self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) + return None diff --git a/reference/providers/amazon/aws/operators/s3_copy_object.py b/reference/providers/amazon/aws/operators/s3_copy_object.py new file mode 100644 index 0000000..7abee7d --- /dev/null +++ b/reference/providers/amazon/aws/operators/s3_copy_object.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
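# ---------------------------------------------------------------------------
# Editor's note: the snippet below is NOT part of the vendored provider code.
# It is a hypothetical sketch of S3CopyObjectOperator (defined in this module)
# copying a key between two placeholder buckets; bucket names, keys, the DAG
# id and dates are made up for illustration.
# ---------------------------------------------------------------------------
def _editorial_example_s3_copy_object():
    """Build a throwaway DAG that copies one S3 key to another bucket."""
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator

    with DAG("s3_copy_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
        S3CopyObjectOperator(
            task_id="copy_report",
            source_bucket_name="editorial-source-bucket",
            source_bucket_key="reports/2021-01-01.csv",
            dest_bucket_name="editorial-dest-bucket",
            dest_bucket_key="archive/2021-01-01.csv",
        )

    return dag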
+from typing import Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.utils.decorators import apply_defaults + + +class S3CopyObjectOperator(BaseOperator): + """ + Creates a copy of an object that is already stored in S3. + + Note: the S3 connection used here needs to have access to both + source and destination bucket/key. + + :param source_bucket_key: The key of the source object. (templated) + + It can be either full s3:// style url or relative path from root level. + + When it's specified as a full s3:// url, please omit source_bucket_name. + :type source_bucket_key: str + :param dest_bucket_key: The key of the object to copy to. (templated) + + The convention to specify `dest_bucket_key` is the same as `source_bucket_key`. + :type dest_bucket_key: str + :param source_bucket_name: Name of the S3 bucket where the source object is in. (templated) + + It should be omitted when `source_bucket_key` is provided as a full s3:// url. + :type source_bucket_name: str + :param dest_bucket_name: Name of the S3 bucket to where the object is copied. (templated) + + It should be omitted when `dest_bucket_key` is provided as a full s3:// url. + :type dest_bucket_name: str + :param source_version_id: Version ID of the source object (OPTIONAL) + :type source_version_id: str + :param aws_conn_id: Connection id of the S3 connection to use + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + + You can provide the following values: + + - False: do not validate SSL certificates. SSL will still be used, + but SSL certificates will not be + verified. + - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str + :param acl_policy: String specifying the canned ACL policy for the file being + uploaded to the S3 bucket. 
+ :type acl_policy: str + """ + + template_fields = ( + "source_bucket_key", + "dest_bucket_key", + "source_bucket_name", + "dest_bucket_name", + ) + + @apply_defaults + def __init__( + self, + *, + source_bucket_key: str, + dest_bucket_key: str, + source_bucket_name: Optional[str] = None, + dest_bucket_name: Optional[str] = None, + source_version_id: Optional[str] = None, + aws_conn_id: str = "aws_default", + verify: Optional[Union[str, bool]] = None, + acl_policy: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.source_bucket_key = source_bucket_key + self.dest_bucket_key = dest_bucket_key + self.source_bucket_name = source_bucket_name + self.dest_bucket_name = dest_bucket_name + self.source_version_id = source_version_id + self.aws_conn_id = aws_conn_id + self.verify = verify + self.acl_policy = acl_policy + + def execute(self, context): + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + s3_hook.copy_object( + self.source_bucket_key, + self.dest_bucket_key, + self.source_bucket_name, + self.dest_bucket_name, + self.source_version_id, + self.acl_policy, + ) diff --git a/reference/providers/amazon/aws/operators/s3_delete_objects.py b/reference/providers/amazon/aws/operators/s3_delete_objects.py new file mode 100644 index 0000000..5e72993 --- /dev/null +++ b/reference/providers/amazon/aws/operators/s3_delete_objects.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.utils.decorators import apply_defaults + + +class S3DeleteObjectsOperator(BaseOperator): + """ + To enable users to delete single object or multiple objects from + a bucket using a single HTTP request. + + Users may specify up to 1000 keys to delete. + + :param bucket: Name of the bucket in which you are going to delete object(s). (templated) + :type bucket: str + :param keys: The key(s) to delete from S3 bucket. (templated) + + When ``keys`` is a string, it's supposed to be the key name of + the single object to delete. + + When ``keys`` is a list, it's supposed to be the list of the + keys to delete. + + You may specify up to 1000 keys. + :type keys: str or list + :param prefix: Prefix of objects to delete. (templated) + All objects matching this prefix in the bucket will be deleted. + :type prefix: str + :param aws_conn_id: Connection id of the S3 connection to use + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used, + but SSL certificates will not be + verified. 
+ - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str + """ + + template_fields = ("keys", "bucket", "prefix") + + @apply_defaults + def __init__( + self, + *, + bucket: str, + keys: Optional[Union[str, list]] = None, + prefix: Optional[str] = None, + aws_conn_id: str = "aws_default", + verify: Optional[Union[str, bool]] = None, + **kwargs, + ): + + if not bool(keys) ^ bool(prefix): + raise ValueError("Either keys or prefix should be set.") + + super().__init__(**kwargs) + self.bucket = bucket + self.keys = keys + self.prefix = prefix + self.aws_conn_id = aws_conn_id + self.verify = verify + + def execute(self, context): + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + + keys = self.keys or s3_hook.list_keys( + bucket_name=self.bucket, prefix=self.prefix + ) + if keys: + s3_hook.delete_objects(bucket=self.bucket, keys=keys) diff --git a/reference/providers/amazon/aws/operators/s3_file_transform.py b/reference/providers/amazon/aws/operators/s3_file_transform.py new file mode 100644 index 0000000..7c5f873 --- /dev/null +++ b/reference/providers/amazon/aws/operators/s3_file_transform.py @@ -0,0 +1,185 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import subprocess +import sys +from tempfile import NamedTemporaryFile +from typing import Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.utils.decorators import apply_defaults + + +class S3FileTransformOperator(BaseOperator): + """ + Copies data from a source S3 location to a temporary location on the + local filesystem. Runs a transformation on this file as specified by + the transformation script and uploads the output to a destination S3 + location. + + The locations of the source and the destination files in the local + filesystem is provided as an first and second arguments to the + transformation script. The transformation script is expected to read the + data from source, transform it and write the output to the local + destination file. The operator then takes over control and uploads the + local destination file to S3. + + S3 Select is also available to filter the source contents. Users can + omit the transformation script if S3 Select expression is specified. + + :param source_s3_key: The key to be retrieved from S3. (templated) + :type source_s3_key: str + :param dest_s3_key: The key to be written from S3. 
(templated) + :type dest_s3_key: str + :param transform_script: location of the executable transformation script + :type transform_script: str + :param select_expression: S3 Select expression + :type select_expression: str + :param script_args: arguments for transformation script (templated) + :type script_args: sequence of str + :param source_aws_conn_id: source s3 connection + :type source_aws_conn_id: str + :param source_verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + + This is also applicable to ``dest_verify``. + :type source_verify: bool or str + :param dest_aws_conn_id: destination s3 connection + :type dest_aws_conn_id: str + :param dest_verify: Whether or not to verify SSL certificates for S3 connection. + See: ``source_verify`` + :type dest_verify: bool or str + :param replace: Replace dest S3 key if it already exists + :type replace: bool + """ + + template_fields = ("source_s3_key", "dest_s3_key", "script_args") + template_ext = () + ui_color = "#f9c915" + + @apply_defaults + def __init__( + self, + *, + source_s3_key: str, + dest_s3_key: str, + transform_script: Optional[str] = None, + select_expression=None, + script_args: Optional[Sequence[str]] = None, + source_aws_conn_id: str = "aws_default", + source_verify: Optional[Union[bool, str]] = None, + dest_aws_conn_id: str = "aws_default", + dest_verify: Optional[Union[bool, str]] = None, + replace: bool = False, + **kwargs, + ) -> None: + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + self.source_s3_key = source_s3_key + self.source_aws_conn_id = source_aws_conn_id + self.source_verify = source_verify + self.dest_s3_key = dest_s3_key + self.dest_aws_conn_id = dest_aws_conn_id + self.dest_verify = dest_verify + self.replace = replace + self.transform_script = transform_script + self.select_expression = select_expression + self.script_args = script_args or [] + self.output_encoding = sys.getdefaultencoding() + + def execute(self, context): + if self.transform_script is None and self.select_expression is None: + raise AirflowException( + "Either transform_script or select_expression must be specified" + ) + + source_s3 = S3Hook( + aws_conn_id=self.source_aws_conn_id, verify=self.source_verify + ) + dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify) + + self.log.info("Downloading source S3 file %s", self.source_s3_key) + if not source_s3.check_for_key(self.source_s3_key): + raise AirflowException( + f"The source key {self.source_s3_key} does not exist" + ) + source_s3_key_object = source_s3.get_key(self.source_s3_key) + + with NamedTemporaryFile("wb") as f_source, NamedTemporaryFile("wb") as f_dest: + self.log.info( + "Dumping S3 file %s contents to local file %s", + self.source_s3_key, + f_source.name, + ) + + if self.select_expression is not None: + content = source_s3.select_key( + key=self.source_s3_key, expression=self.select_expression + ) + f_source.write(content.encode("utf-8")) + else: + source_s3_key_object.download_fileobj(Fileobj=f_source) + f_source.flush() + + if self.transform_script is not None: + process = subprocess.Popen( + [ + 
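+                    # The transform script is invoked as:
+                    #   transform_script <local source path> <local dest path> [script_args...]
+                    # It is expected to read the downloaded source file and write
+                    # its result to the local destination file, which is then
+                    # uploaded to dest_s3_key below.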
self.transform_script, + f_source.name, + f_dest.name, + *self.script_args, + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + close_fds=True, + ) + + self.log.info("Output:") + for line in iter(process.stdout.readline, b""): + self.log.info(line.decode(self.output_encoding).rstrip()) + + process.wait() + + if process.returncode: + raise AirflowException( + f"Transform script failed: {process.returncode}" + ) + else: + self.log.info( + "Transform script successful. Output temporarily located at %s", + f_dest.name, + ) + + self.log.info("Uploading transformed file to S3") + f_dest.flush() + dest_s3.load_file( + filename=f_dest.name if self.transform_script else f_source.name, + key=self.dest_s3_key, + replace=self.replace, + ) + self.log.info("Upload successful") diff --git a/reference/providers/amazon/aws/operators/s3_list.py b/reference/providers/amazon/aws/operators/s3_list.py new file mode 100644 index 0000000..b74402f --- /dev/null +++ b/reference/providers/amazon/aws/operators/s3_list.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterable, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.utils.decorators import apply_defaults + + +class S3ListOperator(BaseOperator): + """ + List all objects from the bucket with the given string prefix in name. + + This operator returns a python list with the name of objects which can be + used by `xcom` in the downstream task. + + :param bucket: The S3 bucket where to find the objects. (templated) + :type bucket: str + :param prefix: Prefix string to filters the objects whose name begin with + such prefix. (templated) + :type prefix: str + :param delimiter: the delimiter marks key hierarchy. (templated) + :type delimiter: str + :param aws_conn_id: The connection ID to use when connecting to S3 storage. + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str + + + **Example**: + The following operator would list all the files + (excluding subfolders) from the S3 + ``customers/2018/04/`` key in the ``data`` bucket. 
:: + + s3_file = S3ListOperator( + task_id='list_3s_files', + bucket='data', + prefix='customers/2018/04/', + delimiter='/', + aws_conn_id='aws_customers_conn' + ) + """ + + template_fields: Iterable[str] = ("bucket", "prefix", "delimiter") + ui_color = "#ffd700" + + @apply_defaults + def __init__( + self, + *, + bucket: str, + prefix: str = "", + delimiter: str = "", + aws_conn_id: str = "aws_default", + verify: Optional[Union[str, bool]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.bucket = bucket + self.prefix = prefix + self.delimiter = delimiter + self.aws_conn_id = aws_conn_id + self.verify = verify + + def execute(self, context): + hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + + self.log.info( + "Getting the list of files from bucket: %s in prefix: %s (Delimiter {%s)", + self.bucket, + self.prefix, + self.delimiter, + ) + + return hook.list_keys( + bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter + ) diff --git a/reference/providers/amazon/aws/operators/sagemaker_base.py b/reference/providers/amazon/aws/operators/sagemaker_base.py new file mode 100644 index 0000000..d0a70ac --- /dev/null +++ b/reference/providers/amazon/aws/operators/sagemaker_base.py @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from typing import Iterable + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.utils.decorators import apply_defaults + + +class SageMakerBaseOperator(BaseOperator): + """ + This is the base operator for all SageMaker operators. + + :param config: The configuration necessary to start a training job (templated) + :type config: dict + :param aws_conn_id: The AWS connection ID to use. 
+ :type aws_conn_id: str + """ + + template_fields = ["config"] + template_ext = () + ui_color = "#ededed" + + integer_fields = [] # type: Iterable[Iterable[str]] + + @apply_defaults + def __init__(self, *, config: dict, aws_conn_id: str = "aws_default", **kwargs): + super().__init__(**kwargs) + + self.aws_conn_id = aws_conn_id + self.config = config + + def parse_integer(self, config, field): + """Recursive method for parsing string fields holding integer values to integers.""" + if len(field) == 1: + if isinstance(config, list): + for sub_config in config: + self.parse_integer(sub_config, field) + return + head = field[0] + if head in config: + config[head] = int(config[head]) + return + + if isinstance(config, list): + for sub_config in config: + self.parse_integer(sub_config, field) + return + + head, tail = field[0], field[1:] + if head in config: + self.parse_integer(config[head], tail) + return + + def parse_config_integers(self): + """ + Parse the integer fields of training config to integers in case the config is rendered by Jinja and + all fields are str. + """ + for field in self.integer_fields: + self.parse_integer(self.config, field) + + def expand_role(self): # noqa: D402 + """Placeholder for calling boto3's expand_role(), which expands an IAM role name into an ARN.""" + + def preprocess_config(self): + """Process the config into a usable form.""" + self.log.info("Preprocessing the config and doing required s3_operations") + + self.hook.configure_s3_resources(self.config) + self.parse_config_integers() + self.expand_role() + + self.log.info( + "After preprocessing the config is:\n %s", + json.dumps(self.config, sort_keys=True, indent=4, separators=(",", ": ")), + ) + + def execute(self, context): + raise NotImplementedError("Please implement execute() in sub class!") + + @cached_property + def hook(self): + """Return SageMakerHook""" + return SageMakerHook(aws_conn_id=self.aws_conn_id) diff --git a/reference/providers/amazon/aws/operators/sagemaker_endpoint.py b/reference/providers/amazon/aws/operators/sagemaker_endpoint.py new file mode 100644 index 0000000..9e1d04d --- /dev/null +++ b/reference/providers/amazon/aws/operators/sagemaker_endpoint.py @@ -0,0 +1,170 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator +from airflow.utils.decorators import apply_defaults +from botocore.exceptions import ClientError + + +class SageMakerEndpointOperator(SageMakerBaseOperator): + """ + Create a SageMaker endpoint. 
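+
+    A minimal usage sketch (``deploy_endpoint`` and ``endpoint_configuration`` are
+    illustrative placeholders; see the ``config`` parameter below for the accepted
+    structures)::
+
+        deploy_endpoint = SageMakerEndpointOperator(
+            task_id='deploy_endpoint',
+            config=endpoint_configuration,
+            operation='create',
+        )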
+
+    This operator returns the ARN of the endpoint created in Amazon SageMaker.
+
+    :param config:
+        The configuration necessary to create an endpoint.
+
+        If you need to create a SageMaker endpoint based on an existing
+        SageMaker model and an existing SageMaker endpoint config::
+
+            config = endpoint_configuration
+
+        If you need to create all of the SageMaker model, SageMaker endpoint config and SageMaker endpoint::
+
+            config = {
+                'Model': model_configuration,
+                'EndpointConfig': endpoint_config_configuration,
+                'Endpoint': endpoint_configuration
+            }
+
+        For details of the configuration parameter of model_configuration see
+        :py:meth:`SageMaker.Client.create_model`
+
+        For details of the configuration parameter of endpoint_config_configuration see
+        :py:meth:`SageMaker.Client.create_endpoint_config`
+
+        For details of the configuration parameter of endpoint_configuration see
+        :py:meth:`SageMaker.Client.create_endpoint`
+
+    :type config: dict
+    :param aws_conn_id: The AWS connection ID to use.
+    :type aws_conn_id: str
+    :param wait_for_completion: Whether the operator should wait until the endpoint creation finishes.
+    :type wait_for_completion: bool
+    :param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation
+        waits before polling the status of the endpoint creation.
+    :type check_interval: int
+    :param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
+        finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
+    :type max_ingestion_time: int
+    :param operation: Whether to create an endpoint or update an endpoint. Must be either 'create' or 'update'.
+    :type operation: str
+    """
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        config: dict,
+        wait_for_completion: bool = True,
+        check_interval: int = 30,
+        max_ingestion_time: Optional[int] = None,
+        operation: str = "create",
+        **kwargs,
+    ):
+        super().__init__(config=config, **kwargs)
+
+        self.config = config
+        self.wait_for_completion = wait_for_completion
+        self.check_interval = check_interval
+        self.max_ingestion_time = max_ingestion_time
+        self.operation = operation.lower()
+        if self.operation not in ["create", "update"]:
+            raise ValueError(
+                'Invalid value!
Argument operation has to be one of "create" and "update"' + ) + self.create_integer_fields() + + def create_integer_fields(self) -> None: + """Set fields which should be casted to integers.""" + if "EndpointConfig" in self.config: + self.integer_fields = [ + ["EndpointConfig", "ProductionVariants", "InitialInstanceCount"] + ] + + def expand_role(self) -> None: + if "Model" not in self.config: + return + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + config = self.config["Model"] + if "ExecutionRoleArn" in config: + config["ExecutionRoleArn"] = hook.expand_role(config["ExecutionRoleArn"]) + + def execute(self, context) -> dict: + self.preprocess_config() + + model_info = self.config.get("Model") + endpoint_config_info = self.config.get("EndpointConfig") + endpoint_info = self.config.get("Endpoint", self.config) + + if model_info: + self.log.info("Creating SageMaker model %s.", model_info["ModelName"]) + self.hook.create_model(model_info) + + if endpoint_config_info: + self.log.info( + "Creating endpoint config %s.", + endpoint_config_info["EndpointConfigName"], + ) + self.hook.create_endpoint_config(endpoint_config_info) + + if self.operation == "create": + sagemaker_operation = self.hook.create_endpoint + log_str = "Creating" + elif self.operation == "update": + sagemaker_operation = self.hook.update_endpoint + log_str = "Updating" + else: + raise ValueError( + 'Invalid value! Argument operation has to be one of "create" and "update"' + ) + + self.log.info( + "%s SageMaker endpoint %s.", log_str, endpoint_info["EndpointName"] + ) + try: + response = sagemaker_operation( + endpoint_info, + wait_for_completion=self.wait_for_completion, + check_interval=self.check_interval, + max_ingestion_time=self.max_ingestion_time, + ) + except ClientError: # Botocore throws a ClientError if the endpoint is already created + self.operation = "update" + sagemaker_operation = self.hook.update_endpoint + log_str = "Updating" + response = sagemaker_operation( + endpoint_info, + wait_for_completion=self.wait_for_completion, + check_interval=self.check_interval, + max_ingestion_time=self.max_ingestion_time, + ) + + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Sagemaker endpoint creation failed: {response}") + else: + return { + "EndpointConfig": self.hook.describe_endpoint_config( + endpoint_info["EndpointConfigName"] + ), + "Endpoint": self.hook.describe_endpoint(endpoint_info["EndpointName"]), + } diff --git a/reference/providers/amazon/aws/operators/sagemaker_endpoint_config.py b/reference/providers/amazon/aws/operators/sagemaker_endpoint_config.py new file mode 100644 index 0000000..299dfee --- /dev/null +++ b/reference/providers/amazon/aws/operators/sagemaker_endpoint_config.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator +from airflow.utils.decorators import apply_defaults + + +class SageMakerEndpointConfigOperator(SageMakerBaseOperator): + """ + Create a SageMaker endpoint config. + + This operator returns The ARN of the endpoint config created in Amazon SageMaker + + :param config: The configuration necessary to create an endpoint config. + + For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config` + :type config: dict + :param aws_conn_id: The AWS connection ID to use. + :type aws_conn_id: str + """ + + integer_fields = [["ProductionVariants", "InitialInstanceCount"]] + + @apply_defaults + def __init__(self, *, config: dict, **kwargs): + super().__init__(config=config, **kwargs) + + self.config = config + + def execute(self, context) -> dict: + self.preprocess_config() + + self.log.info( + "Creating SageMaker Endpoint Config %s.", self.config["EndpointConfigName"] + ) + response = self.hook.create_endpoint_config(self.config) + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException( + f"Sagemaker endpoint config creation failed: {response}" + ) + else: + return { + "EndpointConfig": self.hook.describe_endpoint_config( + self.config["EndpointConfigName"] + ) + } diff --git a/reference/providers/amazon/aws/operators/sagemaker_model.py b/reference/providers/amazon/aws/operators/sagemaker_model.py new file mode 100644 index 0000000..92b23ab --- /dev/null +++ b/reference/providers/amazon/aws/operators/sagemaker_model.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator +from airflow.utils.decorators import apply_defaults + + +class SageMakerModelOperator(SageMakerBaseOperator): + """ + Create a SageMaker model. + + This operator returns The ARN of the model created in Amazon SageMaker + + :param config: The configuration necessary to create a model. + + For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model` + :type config: dict + :param aws_conn_id: The AWS connection ID to use. 
+ :type aws_conn_id: str + """ + + @apply_defaults + def __init__(self, *, config, **kwargs): + super().__init__(config=config, **kwargs) + + self.config = config + + def expand_role(self) -> None: + if "ExecutionRoleArn" in self.config: + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + self.config["ExecutionRoleArn"] = hook.expand_role( + self.config["ExecutionRoleArn"] + ) + + def execute(self, context) -> dict: + self.preprocess_config() + + self.log.info("Creating SageMaker Model %s.", self.config["ModelName"]) + response = self.hook.create_model(self.config) + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Sagemaker model creation failed: {response}") + else: + return {"Model": self.hook.describe_model(self.config["ModelName"])} diff --git a/reference/providers/amazon/aws/operators/sagemaker_processing.py b/reference/providers/amazon/aws/operators/sagemaker_processing.py new file mode 100644 index 0000000..5530d06 --- /dev/null +++ b/reference/providers/amazon/aws/operators/sagemaker_processing.py @@ -0,0 +1,139 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator +from airflow.utils.decorators import apply_defaults + + +class SageMakerProcessingOperator(SageMakerBaseOperator): + """ + Initiate a SageMaker processing job. + + This operator returns The ARN of the processing job created in Amazon SageMaker. + + :param config: The configuration necessary to start a processing job (templated). + + For details of the configuration parameter see :py:meth:`SageMaker.Client.create_processing_job` + :type config: dict + :param aws_conn_id: The AWS connection ID to use. + :type aws_conn_id: str + :param wait_for_completion: If wait is set to True, the time interval, in seconds, + that the operation waits to check the status of the processing job. + :type wait_for_completion: bool + :param print_log: if the operator should print the cloudwatch log during processing + :type print_log: bool + :param check_interval: if wait is set to be true, this is the time interval + in seconds which the operator will check the status of the processing job + :type check_interval: int + :param max_ingestion_time: If wait is set to True, the operation fails if the processing job + doesn't finish within max_ingestion_time seconds. If you set this parameter to None, + the operation does not timeout. + :type max_ingestion_time: int + :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" + (default) and "fail". 
+ :type action_if_job_exists: str + """ + + @apply_defaults + def __init__( + self, + *, + config: dict, + aws_conn_id: str, + wait_for_completion: bool = True, + print_log: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8 + **kwargs, + ): + super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) + + if action_if_job_exists not in ("increment", "fail"): + raise AirflowException( + "Argument action_if_job_exists accepts only 'increment' and 'fail'. " + f"Provided value: '{action_if_job_exists}'." + ) + self.action_if_job_exists = action_if_job_exists + self.wait_for_completion = wait_for_completion + self.print_log = print_log + self.check_interval = check_interval + self.max_ingestion_time = max_ingestion_time + self._create_integer_fields() + + def _create_integer_fields(self) -> None: + """Set fields which should be casted to integers.""" + self.integer_fields = [ + ["ProcessingResources", "ClusterConfig", "InstanceCount"], + ["ProcessingResources", "ClusterConfig", "VolumeSizeInGB"], + ] + if "StoppingCondition" in self.config: + self.integer_fields += [["StoppingCondition", "MaxRuntimeInSeconds"]] + + def expand_role(self) -> None: + if "RoleArn" in self.config: + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + self.config["RoleArn"] = hook.expand_role(self.config["RoleArn"]) + + def execute(self, context) -> dict: + self.preprocess_config() + + processing_job_name = self.config["ProcessingJobName"] + processing_jobs = self.hook.list_processing_jobs( + NameContains=processing_job_name + ) + + # Check if given ProcessingJobName already exists + if processing_job_name in [pj["ProcessingJobName"] for pj in processing_jobs]: + if self.action_if_job_exists == "fail": + raise AirflowException( + f"A SageMaker processing job with name {processing_job_name} already exists." + ) + if self.action_if_job_exists == "increment": + self.log.info( + "Found existing processing job with name '%s'.", processing_job_name + ) + new_processing_job_name = ( + f"{processing_job_name}-{len(processing_jobs) + 1}" + ) + self.config["ProcessingJobName"] = new_processing_job_name + self.log.info( + "Incremented processing job name to '%s'.", new_processing_job_name + ) + + self.log.info( + "Creating SageMaker processing job %s.", self.config["ProcessingJobName"] + ) + response = self.hook.create_processing_job( + self.config, + wait_for_completion=self.wait_for_completion, + check_interval=self.check_interval, + max_ingestion_time=self.max_ingestion_time, + ) + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException( + f"Sagemaker Processing Job creation failed: {response}" + ) + return { + "Processing": self.hook.describe_processing_job( + self.config["ProcessingJobName"] + ) + } diff --git a/reference/providers/amazon/aws/operators/sagemaker_training.py b/reference/providers/amazon/aws/operators/sagemaker_training.py new file mode 100644 index 0000000..d8d2d3b --- /dev/null +++ b/reference/providers/amazon/aws/operators/sagemaker_training.py @@ -0,0 +1,134 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator +from airflow.utils.decorators import apply_defaults + + +class SageMakerTrainingOperator(SageMakerBaseOperator): + """ + Initiate a SageMaker training job. + + This operator returns The ARN of the training job created in Amazon SageMaker. + + :param config: The configuration necessary to start a training job (templated). + + For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job` + :type config: dict + :param aws_conn_id: The AWS connection ID to use. + :type aws_conn_id: str + :param wait_for_completion: If wait is set to True, the time interval, in seconds, + that the operation waits to check the status of the training job. + :type wait_for_completion: bool + :param print_log: if the operator should print the cloudwatch log during training + :type print_log: bool + :param check_interval: if wait is set to be true, this is the time interval + in seconds which the operator will check the status of the training job + :type check_interval: int + :param max_ingestion_time: If wait is set to True, the operation fails if the training job + doesn't finish within max_ingestion_time seconds. If you set this parameter to None, + the operation does not timeout. + :type max_ingestion_time: int + :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" + (default) and "fail". + :type action_if_job_exists: str + """ + + integer_fields = [ + ["ResourceConfig", "InstanceCount"], + ["ResourceConfig", "VolumeSizeInGB"], + ["StoppingCondition", "MaxRuntimeInSeconds"], + ] + + @apply_defaults + def __init__( + self, + *, + config: dict, + wait_for_completion: bool = True, + print_log: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8 + **kwargs, + ): + super().__init__(config=config, **kwargs) + + self.wait_for_completion = wait_for_completion + self.print_log = print_log + self.check_interval = check_interval + self.max_ingestion_time = max_ingestion_time + + if action_if_job_exists in ("increment", "fail"): + self.action_if_job_exists = action_if_job_exists + else: + raise AirflowException( + "Argument action_if_job_exists accepts only 'increment' and 'fail'. " + f"Provided value: '{action_if_job_exists}'." 
+ ) + + def expand_role(self) -> None: + if "RoleArn" in self.config: + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + self.config["RoleArn"] = hook.expand_role(self.config["RoleArn"]) + + def execute(self, context) -> dict: + self.preprocess_config() + + training_job_name = self.config["TrainingJobName"] + training_jobs = self.hook.list_training_jobs(name_contains=training_job_name) + + # Check if given TrainingJobName already exists + if training_job_name in [tj["TrainingJobName"] for tj in training_jobs]: + if self.action_if_job_exists == "increment": + self.log.info( + "Found existing training job with name '%s'.", training_job_name + ) + new_training_job_name = f"{training_job_name}-{len(training_jobs) + 1}" + self.config["TrainingJobName"] = new_training_job_name + self.log.info( + "Incremented training job name to '%s'.", new_training_job_name + ) + elif self.action_if_job_exists == "fail": + raise AirflowException( + f"A SageMaker training job with name {training_job_name} already exists." + ) + + self.log.info( + "Creating SageMaker training job %s.", self.config["TrainingJobName"] + ) + response = self.hook.create_training_job( + self.config, + wait_for_completion=self.wait_for_completion, + print_log=self.print_log, + check_interval=self.check_interval, + max_ingestion_time=self.max_ingestion_time, + ) + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException( + f"Sagemaker Training Job creation failed: {response}" + ) + else: + return { + "Training": self.hook.describe_training_job( + self.config["TrainingJobName"] + ) + } diff --git a/reference/providers/amazon/aws/operators/sagemaker_transform.py b/reference/providers/amazon/aws/operators/sagemaker_transform.py new file mode 100644 index 0000000..83e4f54 --- /dev/null +++ b/reference/providers/amazon/aws/operators/sagemaker_transform.py @@ -0,0 +1,133 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator +from airflow.utils.decorators import apply_defaults + + +class SageMakerTransformOperator(SageMakerBaseOperator): + """ + Initiate a SageMaker transform job. + + This operator returns The ARN of the model created in Amazon SageMaker. + + :param config: The configuration necessary to start a transform job (templated). 
+ + If you need to create a SageMaker transform job based on an existed SageMaker model:: + + config = transform_config + + If you need to create both SageMaker model and SageMaker Transform job:: + + config = { + 'Model': model_config, + 'Transform': transform_config + } + + For details of the configuration parameter of transform_config see + :py:meth:`SageMaker.Client.create_transform_job` + + For details of the configuration parameter of model_config, See: + :py:meth:`SageMaker.Client.create_model` + + :type config: dict + :param aws_conn_id: The AWS connection ID to use. + :type aws_conn_id: str + :param wait_for_completion: Set to True to wait until the transform job finishes. + :type wait_for_completion: bool + :param check_interval: If wait is set to True, the time interval, in seconds, + that this operation waits to check the status of the transform job. + :type check_interval: int + :param max_ingestion_time: If wait is set to True, the operation fails + if the transform job doesn't finish within max_ingestion_time seconds. If you + set this parameter to None, the operation does not timeout. + :type max_ingestion_time: int + """ + + @apply_defaults + def __init__( + self, + *, + config: dict, + wait_for_completion: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + **kwargs, + ): + super().__init__(config=config, **kwargs) + self.config = config + self.wait_for_completion = wait_for_completion + self.check_interval = check_interval + self.max_ingestion_time = max_ingestion_time + self.create_integer_fields() + + def create_integer_fields(self) -> None: + """Set fields which should be casted to integers.""" + self.integer_fields: List[List[str]] = [ + ["Transform", "TransformResources", "InstanceCount"], + ["Transform", "MaxConcurrentTransforms"], + ["Transform", "MaxPayloadInMB"], + ] + if "Transform" not in self.config: + for field in self.integer_fields: + field.pop(0) + + def expand_role(self) -> None: + if "Model" not in self.config: + return + config = self.config["Model"] + if "ExecutionRoleArn" in config: + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + config["ExecutionRoleArn"] = hook.expand_role(config["ExecutionRoleArn"]) + + def execute(self, context) -> dict: + self.preprocess_config() + + model_config = self.config.get("Model") + transform_config = self.config.get("Transform", self.config) + + if model_config: + self.log.info( + "Creating SageMaker Model %s for transform job", + model_config["ModelName"], + ) + self.hook.create_model(model_config) + + self.log.info( + "Creating SageMaker transform Job %s.", transform_config["TransformJobName"] + ) + response = self.hook.create_transform_job( + transform_config, + wait_for_completion=self.wait_for_completion, + check_interval=self.check_interval, + max_ingestion_time=self.max_ingestion_time, + ) + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException( + f"Sagemaker transform Job creation failed: {response}" + ) + else: + return { + "Model": self.hook.describe_model(transform_config["ModelName"]), + "Transform": self.hook.describe_transform_job( + transform_config["TransformJobName"] + ), + } diff --git a/reference/providers/amazon/aws/operators/sagemaker_tuning.py b/reference/providers/amazon/aws/operators/sagemaker_tuning.py new file mode 100644 index 0000000..4f7a9c1 --- /dev/null +++ b/reference/providers/amazon/aws/operators/sagemaker_tuning.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator +from airflow.utils.decorators import apply_defaults + + +class SageMakerTuningOperator(SageMakerBaseOperator): + """ + Initiate a SageMaker hyperparameter tuning job. + + This operator returns The ARN of the tuning job created in Amazon SageMaker. + + :param config: The configuration necessary to start a tuning job (templated). + + For details of the configuration parameter see + :py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job` + :type config: dict + :param aws_conn_id: The AWS connection ID to use. + :type aws_conn_id: str + :param wait_for_completion: Set to True to wait until the tuning job finishes. + :type wait_for_completion: bool + :param check_interval: If wait is set to True, the time interval, in seconds, + that this operation waits to check the status of the tuning job. + :type check_interval: int + :param max_ingestion_time: If wait is set to True, the operation fails + if the tuning job doesn't finish within max_ingestion_time seconds. If you + set this parameter to None, the operation does not timeout. 
+ :type max_ingestion_time: int + """ + + integer_fields = [ + ["HyperParameterTuningJobConfig", "ResourceLimits", "MaxNumberOfTrainingJobs"], + ["HyperParameterTuningJobConfig", "ResourceLimits", "MaxParallelTrainingJobs"], + ["TrainingJobDefinition", "ResourceConfig", "InstanceCount"], + ["TrainingJobDefinition", "ResourceConfig", "VolumeSizeInGB"], + ["TrainingJobDefinition", "StoppingCondition", "MaxRuntimeInSeconds"], + ] + + @apply_defaults + def __init__( + self, + *, + config: dict, + wait_for_completion: bool = True, + check_interval: int = 30, + max_ingestion_time: Optional[int] = None, + **kwargs, + ): + super().__init__(config=config, **kwargs) + self.config = config + self.wait_for_completion = wait_for_completion + self.check_interval = check_interval + self.max_ingestion_time = max_ingestion_time + + def expand_role(self) -> None: + if "TrainingJobDefinition" in self.config: + config = self.config["TrainingJobDefinition"] + if "RoleArn" in config: + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + config["RoleArn"] = hook.expand_role(config["RoleArn"]) + + def execute(self, context) -> dict: + self.preprocess_config() + + self.log.info( + "Creating SageMaker Hyper-Parameter Tuning Job %s", + self.config["HyperParameterTuningJobName"], + ) + + response = self.hook.create_tuning_job( + self.config, + wait_for_completion=self.wait_for_completion, + check_interval=self.check_interval, + max_ingestion_time=self.max_ingestion_time, + ) + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Sagemaker Tuning Job creation failed: {response}") + else: + return { + "Tuning": self.hook.describe_tuning_job( + self.config["HyperParameterTuningJobName"] + ) + } diff --git a/reference/providers/amazon/aws/operators/sns.py b/reference/providers/amazon/aws/operators/sns.py new file mode 100644 index 0000000..a92b36f --- /dev/null +++ b/reference/providers/amazon/aws/operators/sns.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Publish message to SNS queue""" +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.sns import AwsSnsHook +from airflow.utils.decorators import apply_defaults + + +class SnsPublishOperator(BaseOperator): + """ + Publish a message to Amazon SNS. 
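+
+    A minimal usage sketch (the ``task_id``, topic ARN and message below are
+    illustrative placeholders, not values defined elsewhere in this project)::
+
+        notify = SnsPublishOperator(
+            task_id='notify_on_success',
+            target_arn='arn:aws:sns:us-east-1:123456789012:example-topic',
+            message='Pipeline run finished',
+            subject='Airflow notification',
+            aws_conn_id='aws_default',
+        )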
+ + :param aws_conn_id: aws connection to use + :type aws_conn_id: str + :param target_arn: either a TopicArn or an EndpointArn + :type target_arn: str + :param message: the default message you want to send (templated) + :type message: str + :param subject: the message subject you want to send (templated) + :type subject: str + :param message_attributes: the message attributes you want to send as a flat dict (data type will be + determined automatically) + :type message_attributes: dict + """ + + template_fields = ["message", "subject", "message_attributes"] + template_ext = () + + @apply_defaults + def __init__( + self, + *, + target_arn: str, + message: str, + aws_conn_id: str = "aws_default", + subject: Optional[str] = None, + message_attributes: Optional[dict] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.target_arn = target_arn + self.message = message + self.subject = subject + self.message_attributes = message_attributes + self.aws_conn_id = aws_conn_id + + def execute(self, context): + sns = AwsSnsHook(aws_conn_id=self.aws_conn_id) + + self.log.info( + "Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s", + self.target_arn, + self.aws_conn_id, + self.subject, + self.message_attributes, + self.message, + ) + + return sns.publish_to_target( + target_arn=self.target_arn, + message=self.message, + subject=self.subject, + message_attributes=self.message_attributes, + ) diff --git a/reference/providers/amazon/aws/operators/sqs.py b/reference/providers/amazon/aws/operators/sqs.py new file mode 100644 index 0000000..a508f5d --- /dev/null +++ b/reference/providers/amazon/aws/operators/sqs.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Publish message to SQS queue""" +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.sqs import SQSHook +from airflow.utils.decorators import apply_defaults + + +class SQSPublishOperator(BaseOperator): + """ + Publish message to a SQS queue. 
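+
+    A minimal usage sketch (the ``task_id`` and queue URL below are illustrative
+    placeholders, not values defined elsewhere in this project)::
+
+        publish = SQSPublishOperator(
+            task_id='publish_to_queue',
+            sqs_queue='https://sqs.us-east-1.amazonaws.com/123456789012/example-queue',
+            message_content='{"run_id": "{{ run_id }}"}',
+            aws_conn_id='aws_default',
+        )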
+
+    :param sqs_queue: The SQS queue url (templated)
+    :type sqs_queue: str
+    :param message_content: The message content (templated)
+    :type message_content: str
+    :param message_attributes: additional attributes for the message (default: None)
+        For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
+    :type message_attributes: dict
+    :param delay_seconds: message delay (templated) (default: 0 seconds)
+    :type delay_seconds: int
+    :param aws_conn_id: AWS connection id (default: aws_default)
+    :type aws_conn_id: str
+    """
+
+    template_fields = ("sqs_queue", "message_content", "delay_seconds")
+    ui_color = "#6ad3fa"
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        sqs_queue: str,
+        message_content: str,
+        message_attributes: Optional[dict] = None,
+        delay_seconds: int = 0,
+        aws_conn_id: str = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.sqs_queue = sqs_queue
+        self.aws_conn_id = aws_conn_id
+        self.message_content = message_content
+        self.delay_seconds = delay_seconds
+        self.message_attributes = message_attributes or {}
+
+    def execute(self, context):
+        """
+        Publish the message to the SQS queue
+
+        :param context: the context object
+        :type context: dict
+        :return: dict with information about the message sent
+            For details of the returned dict see :py:meth:`botocore.client.SQS.send_message`
+        :rtype: dict
+        """
+        hook = SQSHook(aws_conn_id=self.aws_conn_id)
+
+        result = hook.send_message(
+            queue_url=self.sqs_queue,
+            message_body=self.message_content,
+            delay_seconds=self.delay_seconds,
+            message_attributes=self.message_attributes,
+        )
+
+        self.log.info("send_message result is %s", result)
+
+        return result
diff --git a/reference/providers/amazon/aws/operators/step_function_get_execution_output.py b/reference/providers/amazon/aws/operators/step_function_get_execution_output.py
new file mode 100644
index 0000000..e5400c3
--- /dev/null
+++ b/reference/providers/amazon/aws/operators/step_function_get_execution_output.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+from typing import Optional
+
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.utils.decorators import apply_defaults
+
+
+class StepFunctionGetExecutionOutputOperator(BaseOperator):
+    """
+    An Operator that returns the output of a Step Function State Machine execution
+
+    Additional arguments may be specified and are passed down to the underlying BaseOperator.
+
+    ..
seealso:: + :class:`~airflow.models.BaseOperator` + + :param execution_arn: ARN of the Step Function State Machine Execution + :type execution_arn: str + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + :type aws_conn_id: str + """ + + template_fields = ["execution_arn"] + template_ext = () + ui_color = "#f9c915" + + @apply_defaults + def __init__( + self, + *, + execution_arn: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.execution_arn = execution_arn + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + def execute(self, context): + hook = StepFunctionHook( + aws_conn_id=self.aws_conn_id, region_name=self.region_name + ) + + execution_status = hook.describe_execution(self.execution_arn) + execution_output = ( + json.loads(execution_status["output"]) + if "output" in execution_status + else None + ) + + self.log.info("Got State Machine Execution output for %s", self.execution_arn) + + return execution_output diff --git a/reference/providers/amazon/aws/operators/step_function_start_execution.py b/reference/providers/amazon/aws/operators/step_function_start_execution.py new file mode 100644 index 0000000..9154fa4 --- /dev/null +++ b/reference/providers/amazon/aws/operators/step_function_start_execution.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook +from airflow.utils.decorators import apply_defaults + + +class StepFunctionStartExecutionOperator(BaseOperator): + """ + An Operator that begins execution of an Step Function State Machine + + Additional arguments may be specified and are passed down to the underlying BaseOperator. + + .. seealso:: + :class:`~airflow.models.BaseOperator` + + :param state_machine_arn: ARN of the Step Function State Machine + :type state_machine_arn: str + :param name: The name of the execution. + :type name: Optional[str] + :param state_machine_input: JSON data input to pass to the State Machine + :type state_machine_input: Union[Dict[str, any], str, None] + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn. 
+ :type do_xcom_push: bool + """ + + template_fields = ["state_machine_arn", "name", "input"] + template_ext = () + ui_color = "#f9c915" + + @apply_defaults + def __init__( + self, + *, + state_machine_arn: str, + name: Optional[str] = None, + state_machine_input: Union[dict, str, None] = None, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.state_machine_arn = state_machine_arn + self.name = name + self.input = state_machine_input + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + def execute(self, context): + hook = StepFunctionHook( + aws_conn_id=self.aws_conn_id, region_name=self.region_name + ) + + execution_arn = hook.start_execution( + self.state_machine_arn, self.name, self.input + ) + + if execution_arn is None: + raise AirflowException( + f"Failed to start State Machine execution for: {self.state_machine_arn}" + ) + + self.log.info( + "Started State Machine execution for %s: %s", + self.state_machine_arn, + execution_arn, + ) + + return execution_arn diff --git a/reference/providers/amazon/aws/secrets/__init__.py b/reference/providers/amazon/aws/secrets/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/amazon/aws/secrets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/amazon/aws/secrets/secrets_manager.py b/reference/providers/amazon/aws/secrets/secrets_manager.py new file mode 100644 index 0000000..3407c70 --- /dev/null +++ b/reference/providers/amazon/aws/secrets/secrets_manager.py @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Objects relating to sourcing secrets from AWS Secrets Manager""" + +from typing import Optional + +import boto3 + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.secrets import BaseSecretsBackend +from airflow.utils.log.logging_mixin import LoggingMixin + + +class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin): + """ + Retrieves Connection or Variables from AWS Secrets Manager + + Configurable via ``airflow.cfg`` like so: + + .. code-block:: ini + + [secrets] + backend = airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend + backend_kwargs = {"connections_prefix": "airflow/connections"} + + For example, if secrets prefix is ``airflow/connections/smtp_default``, this would be accessible + if you provide ``{"connections_prefix": "airflow/connections"}`` and request conn_id ``smtp_default``. + If variables prefix is ``airflow/variables/hello``, this would be accessible + if you provide ``{"variables_prefix": "airflow/variables"}`` and request variable key ``hello``. + And if config_prefix is ``airflow/config/sql_alchemy_conn``, this would be accessible + if you provide ``{"config_prefix": "airflow/config"}`` and request config + key ``sql_alchemy_conn``. + + You can also pass additional keyword arguments like ``aws_secret_access_key``, ``aws_access_key_id`` + or ``region_name`` to this class and they would be passed on to Boto3 client. + + :param connections_prefix: Specifies the prefix of the secret to read to get Connections. + If set to None (null), requests for connections will not be sent to AWS Secrets Manager + :type connections_prefix: str + :param variables_prefix: Specifies the prefix of the secret to read to get Variables. + If set to None (null), requests for variables will not be sent to AWS Secrets Manager + :type variables_prefix: str + :param config_prefix: Specifies the prefix of the secret to read to get Configurations. + If set to None (null), requests for configurations will not be sent to AWS Secrets Manager + :type config_prefix: str + :param profile_name: The name of a profile to use. If not given, then the default profile is used. + :type profile_name: str + :param sep: separator used to concatenate secret_prefix and secret_id. 
Default: "/" + :type sep: str + """ + + def __init__( + self, + connections_prefix: str = "airflow/connections", + variables_prefix: str = "airflow/variables", + config_prefix: str = "airflow/config", + profile_name: Optional[str] = None, + sep: str = "/", + **kwargs, + ): + super().__init__() + if connections_prefix is not None: + self.connections_prefix = connections_prefix.rstrip("/") + else: + self.connections_prefix = connections_prefix + if variables_prefix is not None: + self.variables_prefix = variables_prefix.rstrip("/") + else: + self.variables_prefix = variables_prefix + if config_prefix is not None: + self.config_prefix = config_prefix.rstrip("/") + else: + self.config_prefix = config_prefix + self.profile_name = profile_name + self.sep = sep + self.kwargs = kwargs + + @cached_property + def client(self): + """Create a Secrets Manager client""" + session = boto3.session.Session( + profile_name=self.profile_name, + ) + return session.client(service_name="secretsmanager", **self.kwargs) + + def get_conn_uri(self, conn_id: str) -> Optional[str]: + """ + Get Connection Value + + :param conn_id: connection id + :type conn_id: str + """ + if self.connections_prefix is None: + return None + + return self._get_secret(self.connections_prefix, conn_id) + + def get_variable(self, key: str) -> Optional[str]: + """ + Get Airflow Variable + + :param key: Variable Key + :return: Variable Value + """ + if self.variables_prefix is None: + return None + + return self._get_secret(self.variables_prefix, key) + + def get_config(self, key: str) -> Optional[str]: + """ + Get Airflow Configuration + + :param key: Configuration Option Key + :return: Configuration Option Value + """ + if self.config_prefix is None: + return None + + return self._get_secret(self.config_prefix, key) + + def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: + """ + Get secret value from Secrets Manager + + :param path_prefix: Prefix for the Path to get Secret + :type path_prefix: str + :param secret_id: Secret Key + :type secret_id: str + """ + secrets_path = self.build_path(path_prefix, secret_id, self.sep) + try: + response = self.client.get_secret_value( + SecretId=secrets_path, + ) + return response.get("SecretString") + except self.client.exceptions.ResourceNotFoundException: + self.log.debug( + "An error occurred (ResourceNotFoundException) when calling the " + "get_secret_value operation: " + "Secret %s not found.", + secrets_path, + ) + return None diff --git a/reference/providers/amazon/aws/secrets/systems_manager.py b/reference/providers/amazon/aws/secrets/systems_manager.py new file mode 100644 index 0000000..50b2f15 --- /dev/null +++ b/reference/providers/amazon/aws/secrets/systems_manager.py @@ -0,0 +1,148 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Objects relating to sourcing connections from AWS SSM Parameter Store""" +from typing import Optional + +import boto3 + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.secrets import BaseSecretsBackend +from airflow.utils.log.logging_mixin import LoggingMixin + + +class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin): + """ + Retrieves Connection or Variables from AWS SSM Parameter Store + + Configurable via ``airflow.cfg`` like so: + + .. code-block:: ini + + [secrets] + backend = airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend + backend_kwargs = {"connections_prefix": "/airflow/connections", "profile_name": null} + + For example, if ssm path is ``/airflow/connections/smtp_default``, this would be accessible + if you provide ``{"connections_prefix": "/airflow/connections"}`` and request conn_id ``smtp_default``. + And if ssm path is ``/airflow/variables/hello``, this would be accessible + if you provide ``{"variables_prefix": "/airflow/variables"}`` and request conn_id ``hello``. + + :param connections_prefix: Specifies the prefix of the secret to read to get Connections. + If set to None (null), requests for connections will not be sent to AWS SSM Parameter Store. + :type connections_prefix: str + :param variables_prefix: Specifies the prefix of the secret to read to get Variables. + If set to None (null), requests for variables will not be sent to AWS SSM Parameter Store. + :type variables_prefix: str + :param config_prefix: Specifies the prefix of the secret to read to get Variables. + If set to None (null), requests for configurations will not be sent to AWS SSM Parameter Store. + :type config_prefix: str + :param profile_name: The name of a profile to use. If not given, then the default profile is used. 
+ :type profile_name: str + """ + + def __init__( + self, + connections_prefix: str = "/airflow/connections", + variables_prefix: str = "/airflow/variables", + config_prefix: str = "/airflow/config", + profile_name: Optional[str] = None, + **kwargs, + ): + super().__init__() + if connections_prefix is not None: + self.connections_prefix = connections_prefix.rstrip("/") + else: + self.connections_prefix = connections_prefix + if variables_prefix is not None: + self.variables_prefix = variables_prefix.rstrip("/") + else: + self.variables_prefix = variables_prefix + if config_prefix is not None: + self.config_prefix = config_prefix.rstrip("/") + else: + self.config_prefix = config_prefix + self.profile_name = profile_name + self.kwargs = kwargs + + @cached_property + def client(self): + """Create a SSM client""" + session = boto3.Session(profile_name=self.profile_name) + return session.client("ssm", **self.kwargs) + + def get_conn_uri(self, conn_id: str) -> Optional[str]: + """ + Get param value + + :param conn_id: connection id + :type conn_id: str + """ + if self.connections_prefix is None: + return None + + return self._get_secret(self.connections_prefix, conn_id) + + def get_variable(self, key: str) -> Optional[str]: + """ + Get Airflow Variable from Environment Variable + + :param key: Variable Key + :return: Variable Value + """ + if self.variables_prefix is None: + return None + + return self._get_secret(self.variables_prefix, key) + + def get_config(self, key: str) -> Optional[str]: + """ + Get Airflow Configuration + + :param key: Configuration Option Key + :return: Configuration Option Value + """ + if self.config_prefix is None: + return None + + return self._get_secret(self.config_prefix, key) + + def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: + """ + Get secret value from Parameter Store. + + :param path_prefix: Prefix for the Path to get Secret + :type path_prefix: str + :param secret_id: Secret Key + :type secret_id: str + """ + ssm_path = self.build_path(path_prefix, secret_id) + try: + response = self.client.get_parameter(Name=ssm_path, WithDecryption=True) + value = response["Parameter"]["Value"] + return value + except self.client.exceptions.ParameterNotFound: + self.log.info( + "An error occurred (ParameterNotFound) when calling the GetParameter operation: " + "Parameter %s not found.", + ssm_path, + ) + return None diff --git a/reference/providers/amazon/aws/sensors/__init__.py b/reference/providers/amazon/aws/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
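The two secrets backends above share the same lookup pattern: join the configured prefix and the requested id with the separator, then fetch and decrypt the secret. A rough sketch of that round trip against SSM Parameter Store follows; the parameter name, connection URI and region are placeholders, and in a live deployment the backend would be enabled through the ``[secrets]`` section of ``airflow.cfg`` rather than instantiated directly.

import boto3

from airflow.providers.amazon.aws.secrets.systems_manager import (
    SystemsManagerParameterStoreBackend,
)

# Store a connection URI under the default connections prefix (placeholder values).
ssm = boto3.client("ssm", region_name="us-east-1")
ssm.put_parameter(
    Name="/airflow/connections/smtp_default",
    Value="smtps://user:password@smtp.example.com:465",
    Type="SecureString",
    Overwrite=True,
)

# The backend joins its prefix and the conn_id with "/" and reads the decrypted
# value, exactly as _get_secret() does above.
backend = SystemsManagerParameterStoreBackend()
print(backend.get_conn_uri(conn_id="smtp_default"))  # -> the URI stored above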
diff --git a/reference/providers/amazon/aws/sensors/athena.py b/reference/providers/amazon/aws/sensors/athena.py new file mode 100644 index 0000000..9d42b4c --- /dev/null +++ b/reference/providers/amazon/aws/sensors/athena.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Optional + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class AthenaSensor(BaseSensorOperator): + """ + Asks for the state of the Query until it reaches a failure state or success state. + If the query fails, the task will fail. + + :param query_execution_id: query_execution_id to check the state of + :type query_execution_id: str + :param max_retries: Number of times to poll for query state before + returning the current state, defaults to None + :type max_retries: int + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + :type aws_conn_id: str + :param sleep_time: Time in seconds to wait between two consecutive call to + check query status on athena, defaults to 10 + :type sleep_time: int + """ + + INTERMEDIATE_STATES = ( + "QUEUED", + "RUNNING", + ) + FAILURE_STATES = ( + "FAILED", + "CANCELLED", + ) + SUCCESS_STATES = ("SUCCEEDED",) + + template_fields = ["query_execution_id"] + template_ext = () + ui_color = "#66c3ff" + + @apply_defaults + def __init__( + self, + *, + query_execution_id: str, + max_retries: Optional[int] = None, + aws_conn_id: str = "aws_default", + sleep_time: int = 10, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.query_execution_id = query_execution_id + self.sleep_time = sleep_time + self.max_retries = max_retries + + def poke(self, context: dict) -> bool: + state = self.hook.poll_query_status(self.query_execution_id, self.max_retries) + + if state in self.FAILURE_STATES: + raise AirflowException("Athena sensor failed") + + if state in self.INTERMEDIATE_STATES: + return False + return True + + @cached_property + def hook(self) -> AWSAthenaHook: + """Create and return an AWSAthenaHook""" + return AWSAthenaHook(self.aws_conn_id, self.sleep_time) diff --git a/reference/providers/amazon/aws/sensors/cloud_formation.py b/reference/providers/amazon/aws/sensors/cloud_formation.py new file mode 100644 index 0000000..92327e5 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/cloud_formation.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains sensors for AWS CloudFormation.""" +from typing import Optional + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.providers.amazon.aws.hooks.cloud_formation import AWSCloudFormationHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class CloudFormationCreateStackSensor(BaseSensorOperator): + """ + Waits for a stack to be created successfully on AWS CloudFormation. + + :param stack_name: The name of the stack to wait for (templated) + :type stack_name: str + :param aws_conn_id: ID of the Airflow connection where credentials and extra configuration are + stored + :type aws_conn_id: str + :param poke_interval: Time in seconds that the job should wait between each try + :type poke_interval: int + """ + + template_fields = ["stack_name"] + ui_color = "#C5CAE9" + + @apply_defaults + def __init__( + self, *, stack_name, aws_conn_id="aws_default", region_name=None, **kwargs + ): + super().__init__(**kwargs) + self.stack_name = stack_name + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + def poke(self, context): + stack_status = self.hook.get_stack_status(self.stack_name) + if stack_status == "CREATE_COMPLETE": + return True + if stack_status in ("CREATE_IN_PROGRESS", None): + return False + raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}") + + @cached_property + def hook(self) -> AWSCloudFormationHook: + """Create and return an AWSCloudFormationHook""" + return AWSCloudFormationHook( + aws_conn_id=self.aws_conn_id, region_name=self.region_name + ) + + +class CloudFormationDeleteStackSensor(BaseSensorOperator): + """ + Waits for a stack to be deleted successfully on AWS CloudFormation. 
+ + :param stack_name: The name of the stack to wait for (templated) + :type stack_name: str + :param aws_conn_id: ID of the Airflow connection where credentials and extra configuration are + stored + :type aws_conn_id: str + :param poke_interval: Time in seconds that the job should wait between each try + :type poke_interval: int + """ + + template_fields = ["stack_name"] + ui_color = "#C5CAE9" + + @apply_defaults + def __init__( + self, + *, + stack_name: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.stack_name = stack_name + + def poke(self, context): + stack_status = self.hook.get_stack_status(self.stack_name) + if stack_status in ("DELETE_COMPLETE", None): + return True + if stack_status == "DELETE_IN_PROGRESS": + return False + raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}") + + @cached_property + def hook(self) -> AWSCloudFormationHook: + """Create and return an AWSCloudFormationHook""" + return AWSCloudFormationHook( + aws_conn_id=self.aws_conn_id, region_name=self.region_name + ) diff --git a/reference/providers/amazon/aws/sensors/ec2_instance_state.py b/reference/providers/amazon/aws/sensors/ec2_instance_state.py new file mode 100644 index 0000000..f3fbd9f --- /dev/null +++ b/reference/providers/amazon/aws/sensors/ec2_instance_state.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import Optional + +from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class EC2InstanceStateSensor(BaseSensorOperator): + """ + Check the state of the AWS EC2 instance until + state of the instance become equal to the target state. 
+ + :param target_state: target state of instance + :type target_state: str + :param instance_id: id of the AWS EC2 instance + :type instance_id: str + :param region_name: (optional) aws region name associated with the client + :type region_name: Optional[str] + """ + + template_fields = ("target_state", "instance_id", "region_name") + ui_color = "#cc8811" + ui_fgcolor = "#ffffff" + valid_states = ["running", "stopped", "terminated"] + + @apply_defaults + def __init__( + self, + *, + target_state: str, + instance_id: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + **kwargs, + ): + if target_state not in self.valid_states: + raise ValueError(f"Invalid target_state: {target_state}") + super().__init__(**kwargs) + self.target_state = target_state + self.instance_id = instance_id + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + def poke(self, context): + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + instance_state = ec2_hook.get_instance_state(instance_id=self.instance_id) + self.log.info("instance state: %s", instance_state) + return instance_state == self.target_state diff --git a/reference/providers/amazon/aws/sensors/emr_base.py b/reference/providers/amazon/aws/sensors/emr_base.py new file mode 100644 index 0000000..2997752 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/emr_base.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Iterable, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class EmrBaseSensor(BaseSensorOperator): + """ + Contains general sensor behavior for EMR. + + Subclasses should implement following methods: + - ``get_emr_response()`` + - ``state_from_response()`` + - ``failure_message_from_response()`` + + Subclasses should set ``target_states`` and ``failed_states`` fields. 
+ + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + """ + + ui_color = "#66c3ff" + + @apply_defaults + def __init__(self, *, aws_conn_id: str = "aws_default", **kwargs): + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.target_states: Optional[Iterable[str]] = None # will be set in subclasses + self.failed_states: Optional[Iterable[str]] = None # will be set in subclasses + self.hook: Optional[EmrHook] = None + + def get_hook(self) -> EmrHook: + """Get EmrHook""" + if self.hook: + return self.hook + + self.hook = EmrHook(aws_conn_id=self.aws_conn_id) + return self.hook + + def poke(self, context): + response = self.get_emr_response() + + if not response["ResponseMetadata"]["HTTPStatusCode"] == 200: + self.log.info("Bad HTTP response: %s", response) + return False + + state = self.state_from_response(response) + self.log.info("Job flow currently %s", state) + + if state in self.target_states: + return True + + if state in self.failed_states: + final_message = "EMR job failed" + failure_message = self.failure_message_from_response(response) + if failure_message: + final_message += " " + failure_message + raise AirflowException(final_message) + + return False + + def get_emr_response(self) -> Dict[str, Any]: + """ + Make an API call with boto3 and get response. + + :return: response + :rtype: dict[str, Any] + """ + raise NotImplementedError("Please implement get_emr_response() in subclass") + + @staticmethod + def state_from_response(response: Dict[str, Any]) -> str: + """ + Get state from response dictionary. + + :param response: response from AWS API + :type response: dict[str, Any] + :return: state + :rtype: str + """ + raise NotImplementedError("Please implement state_from_response() in subclass") + + @staticmethod + def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: + """ + Get failure message from response dictionary. + + :param response: response from AWS API + :type response: dict[str, Any] + :return: failure message + :rtype: Optional[str] + """ + raise NotImplementedError( + "Please implement failure_message_from_response() in subclass" + ) diff --git a/reference/providers/amazon/aws/sensors/emr_job_flow.py b/reference/providers/amazon/aws/sensors/emr_job_flow.py new file mode 100644 index 0000000..cff7c84 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/emr_job_flow.py @@ -0,0 +1,106 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Iterable, Optional + +from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor +from airflow.utils.decorators import apply_defaults + + +class EmrJobFlowSensor(EmrBaseSensor): + """ + Asks for the state of the EMR JobFlow (Cluster) until it reaches + any of the target states. 
+ If it fails the sensor errors, failing the task. + + With the default target states, sensor waits cluster to be terminated. + When target_states is set to ['RUNNING', 'WAITING'] sensor waits + until job flow to be ready (after 'STARTING' and 'BOOTSTRAPPING' states) + + :param job_flow_id: job_flow_id to check the state of + :type job_flow_id: str + :param target_states: the target states, sensor waits until + job flow reaches any of these states + :type target_states: list[str] + :param failed_states: the failure states, sensor fails when + job flow reaches any of these states + :type failed_states: list[str] + """ + + template_fields = ["job_flow_id", "target_states", "failed_states"] + template_ext = () + + @apply_defaults + def __init__( + self, + *, + job_flow_id: str, + target_states: Optional[Iterable[str]] = None, + failed_states: Optional[Iterable[str]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.job_flow_id = job_flow_id + self.target_states = target_states or ["TERMINATED"] + self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"] + + def get_emr_response(self) -> Dict[str, Any]: + """ + Make an API call with boto3 and get cluster-level details. + + .. seealso:: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_cluster + + :return: response + :rtype: dict[str, Any] + """ + emr_client = self.get_hook().get_conn() + + self.log.info("Poking cluster %s", self.job_flow_id) + return emr_client.describe_cluster(ClusterId=self.job_flow_id) + + @staticmethod + def state_from_response(response: Dict[str, Any]) -> str: + """ + Get state from response dictionary. + + :param response: response from AWS API + :type response: dict[str, Any] + :return: current state of the cluster + :rtype: str + """ + return response["Cluster"]["Status"]["State"] + + @staticmethod + def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: + """ + Get failure message from response dictionary. + + :param response: response from AWS API + :type response: dict[str, Any] + :return: failure message + :rtype: Optional[str] + """ + cluster_status = response["Cluster"]["Status"] + state_change_reason = cluster_status.get("StateChangeReason") + if state_change_reason: + return "for code: {} with message {}".format( + state_change_reason.get("Code", "No code"), + state_change_reason.get("Message", "Unknown"), + ) + return None diff --git a/reference/providers/amazon/aws/sensors/emr_step.py b/reference/providers/amazon/aws/sensors/emr_step.py new file mode 100644 index 0000000..bedcd5b --- /dev/null +++ b/reference/providers/amazon/aws/sensors/emr_step.py @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
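# Illustrative sketch (editorial addition, separate from the vendored
# emr_step.py below): EmrJobFlowSensor configured to succeed once the cluster
# is ready rather than once it terminates. The job flow id template assumes a
# preceding task named 'create_job_flow' that returned the cluster id; adjust
# to your own DAG and place the task inside a `with DAG(...):` block.
from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor

wait_for_cluster = EmrJobFlowSensor(
    task_id="wait_for_cluster",
    job_flow_id="{{ ti.xcom_pull(task_ids='create_job_flow') }}",
    # Override the defaults ('TERMINATED' / 'TERMINATED_WITH_ERRORS') so the
    # sensor returns True as soon as the cluster reaches a ready state.
    target_states=["RUNNING", "WAITING"],
    failed_states=["TERMINATED", "TERMINATED_WITH_ERRORS"],
    poke_interval=60,
)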
+ +from typing import Any, Dict, Iterable, Optional + +from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor +from airflow.utils.decorators import apply_defaults + + +class EmrStepSensor(EmrBaseSensor): + """ + Asks for the state of the step until it reaches any of the target states. + If it fails the sensor errors, failing the task. + + With the default target states, sensor waits step to be completed. + + :param job_flow_id: job_flow_id which contains the step check the state of + :type job_flow_id: str + :param step_id: step to check the state of + :type step_id: str + :param target_states: the target states, sensor waits until + step reaches any of these states + :type target_states: list[str] + :param failed_states: the failure states, sensor fails when + step reaches any of these states + :type failed_states: list[str] + """ + + template_fields = ["job_flow_id", "step_id", "target_states", "failed_states"] + template_ext = () + + @apply_defaults + def __init__( + self, + *, + job_flow_id: str, + step_id: str, + target_states: Optional[Iterable[str]] = None, + failed_states: Optional[Iterable[str]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.job_flow_id = job_flow_id + self.step_id = step_id + self.target_states = target_states or ["COMPLETED"] + self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"] + + def get_emr_response(self) -> Dict[str, Any]: + """ + Make an API call with boto3 and get details about the cluster step. + + .. seealso:: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_step + + :return: response + :rtype: dict[str, Any] + """ + emr_client = self.get_hook().get_conn() + + self.log.info("Poking step %s on cluster %s", self.step_id, self.job_flow_id) + return emr_client.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id) + + @staticmethod + def state_from_response(response: Dict[str, Any]) -> str: + """ + Get state from response dictionary. + + :param response: response from AWS API + :type response: dict[str, Any] + :return: execution state of the cluster step + :rtype: str + """ + return response["Step"]["Status"]["State"] + + @staticmethod + def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: + """ + Get failure message from response dictionary. + + :param response: response from AWS API + :type response: dict[str, Any] + :return: failure message + :rtype: Optional[str] + """ + fail_details = response["Step"]["Status"].get("FailureDetails") + if fail_details: + return "for reason {} with message {} and log file {}".format( + fail_details.get("Reason"), + fail_details.get("Message"), + fail_details.get("LogFile"), + ) + return None diff --git a/reference/providers/amazon/aws/sensors/glacier.py b/reference/providers/amazon/aws/sensors/glacier.py new file mode 100644 index 0000000..4389eb4 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/glacier.py @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from enum import Enum +from typing import Any + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.glacier import GlacierHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class JobStatus(Enum): + """Glacier jobs description""" + + IN_PROGRESS = "InProgress" + SUCCEEDED = "Succeeded" + + +class GlacierJobOperationSensor(BaseSensorOperator): + """ + Glacier sensor for checking job state. This operator runs only in reschedule mode. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GlacierJobOperationSensor` + + :param aws_conn_id: The reference to the AWS connection details + :type aws_conn_id: str + :param vault_name: name of Glacier vault on which job is executed + :type vault_name: str + :param job_id: the job ID was returned by retrieve_inventory() + :type job_id: str + :param poke_interval: Time in seconds that the job should wait in + between each tries + :type poke_interval: float + :param mode: How the sensor operates. + Options are: ``{ poke | reschedule }``, default is ``poke``. + When set to ``poke`` the sensor is taking up a worker slot for its + whole execution time and sleeps between pokes. Use this mode if the + expected runtime of the sensor is short or if a short poke interval + is required. Note that the sensor will hold onto a worker slot and + a pool slot for the duration of the sensor's runtime in this mode. + When set to ``reschedule`` the sensor task frees the worker slot when + the criteria is not yet met and it's rescheduled at a later time. Use + this mode if the time before the criteria is met is expected to be + quite long. The poke interval should be more than one minute to + prevent too much load on the scheduler. + :type mode: str + """ + + template_fields = ["vault_name", "job_id"] + + @apply_defaults + def __init__( + self, + *, + aws_conn_id: str = "aws_default", + vault_name: str, + job_id: str, + poke_interval: int = 60 * 20, + mode: str = "reschedule", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.vault_name = vault_name + self.job_id = job_id + self.poke_interval = poke_interval + self.mode = mode + + def poke(self, context) -> bool: + hook = GlacierHook(aws_conn_id=self.aws_conn_id) + response = hook.describe_job(vault_name=self.vault_name, job_id=self.job_id) + + if response["StatusCode"] == JobStatus.SUCCEEDED.value: + self.log.info( + "Job status: %s, code status: %s", + response["Action"], + response["StatusCode"], + ) + self.log.info("Job finished successfully") + return True + elif response["StatusCode"] == JobStatus.IN_PROGRESS.value: + self.log.info("Processing...") + self.log.warning("Code status: %s", response["StatusCode"]) + return False + else: + raise AirflowException( + f'Sensor failed. 
Job status: {response["Action"]}, code status: {response["StatusCode"]}' + ) diff --git a/reference/providers/amazon/aws/sensors/glue.py b/reference/providers/amazon/aws/sensors/glue.py new file mode 100644 index 0000000..1e85b9a --- /dev/null +++ b/reference/providers/amazon/aws/sensors/glue.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.glue import AwsGlueJobHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class AwsGlueJobSensor(BaseSensorOperator): + """ + Waits for an AWS Glue Job to reach any of the status below + 'FAILED', 'STOPPED', 'SUCCEEDED' + + :param job_name: The AWS Glue Job unique name + :type job_name: str + :param run_id: The AWS Glue current running job identifier + :type run_id: str + """ + + template_fields = ("job_name", "run_id") + + @apply_defaults + def __init__( + self, *, job_name: str, run_id: str, aws_conn_id: str = "aws_default", **kwargs + ): + super().__init__(**kwargs) + self.job_name = job_name + self.run_id = run_id + self.aws_conn_id = aws_conn_id + self.success_states = ["SUCCEEDED"] + self.errored_states = ["FAILED", "STOPPED", "TIMEOUT"] + + def poke(self, context): + hook = AwsGlueJobHook(aws_conn_id=self.aws_conn_id) + self.log.info( + "Poking for job run status :for Glue Job %s and ID %s", + self.job_name, + self.run_id, + ) + job_state = hook.get_job_state(job_name=self.job_name, run_id=self.run_id) + if job_state in self.success_states: + self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state) + return True + elif job_state in self.errored_states: + job_error_message = ( + "Exiting Job " + self.run_id + " Run State: " + job_state + ) + raise AirflowException(job_error_message) + else: + return False diff --git a/reference/providers/amazon/aws/sensors/glue_catalog_partition.py b/reference/providers/amazon/aws/sensors/glue_catalog_partition.py new file mode 100644 index 0000000..7fbf7f2 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/glue_catalog_partition.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class AwsGlueCatalogPartitionSensor(BaseSensorOperator): + """ + Waits for a partition to show up in AWS Glue Catalog. + + :param table_name: The name of the table to wait for, supports the dot + notation (my_database.my_table) + :type table_name: str + :param expression: The partition clause to wait for. This is passed as + is to the AWS Glue Catalog API's get_partitions function, + and supports SQL like notation as in ``ds='2015-01-01' + AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``. + See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html + #aws-glue-api-catalog-partitions-GetPartitions + :type expression: str + :param aws_conn_id: ID of the Airflow connection where + credentials and extra configuration are stored + :type aws_conn_id: str + :param region_name: Optional aws region name (example: us-east-1). Uses region from connection + if not specified. + :type region_name: str + :param database_name: The name of the catalog database where the partitions reside. + :type database_name: str + :param poke_interval: Time in seconds that the job should wait in + between each tries + :type poke_interval: int + """ + + template_fields = ( + "database_name", + "table_name", + "expression", + ) + ui_color = "#C5CAE9" + + @apply_defaults + def __init__( + self, + *, + table_name: str, + expression: str = "ds='{{ ds }}'", + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + database_name: str = "default", + poke_interval: int = 60 * 3, + **kwargs, + ): + super().__init__(poke_interval=poke_interval, **kwargs) + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.table_name = table_name + self.expression = expression + self.database_name = database_name + self.hook: Optional[AwsGlueCatalogHook] = None + + def poke(self, context): + """Checks for existence of the partition in the AWS Glue Catalog table""" + if "." in self.table_name: + self.database_name, self.table_name = self.table_name.split(".") + self.log.info( + "Poking for table %s. %s, expression %s", + self.database_name, + self.table_name, + self.expression, + ) + + return self.get_hook().check_for_partition( + self.database_name, self.table_name, self.expression + ) + + def get_hook(self) -> AwsGlueCatalogHook: + """Gets the AwsGlueCatalogHook""" + if self.hook: + return self.hook + + self.hook = AwsGlueCatalogHook( + aws_conn_id=self.aws_conn_id, region_name=self.region_name + ) + return self.hook diff --git a/reference/providers/amazon/aws/sensors/glue_crawler.py b/reference/providers/amazon/aws/sensors/glue_crawler.py new file mode 100644 index 0000000..529557a --- /dev/null +++ b/reference/providers/amazon/aws/sensors/glue_crawler.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.glue_crawler import AwsGlueCrawlerHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class AwsGlueCrawlerSensor(BaseSensorOperator): + """ + Waits for an AWS Glue crawler to reach any of the statuses below + 'FAILED', 'CANCELLED', 'SUCCEEDED' + + :param crawler_name: The AWS Glue crawler unique name + :type crawler_name: str + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + :type aws_conn_id: str + """ + + @apply_defaults + def __init__( + self, *, crawler_name: str, aws_conn_id: str = "aws_default", **kwargs + ) -> None: + super().__init__(**kwargs) + self.crawler_name = crawler_name + self.aws_conn_id = aws_conn_id + self.success_statuses = "SUCCEEDED" + self.errored_statuses = ("FAILED", "CANCELLED") + self.hook: Optional[AwsGlueCrawlerHook] = None + + def poke(self, context): + hook = self.get_hook() + self.log.info("Poking for AWS Glue crawler: %s", self.crawler_name) + crawler_state = hook.get_crawler(self.crawler_name)["State"] + if crawler_state == "READY": + self.log.info("State: %s", crawler_state) + crawler_status = hook.get_crawler(self.crawler_name)["LastCrawl"]["Status"] + if crawler_status == self.success_statuses: + self.log.info("Status: %s", crawler_status) + return True + else: + raise AirflowException(f"Status: {crawler_status}") + else: + return False + + def get_hook(self) -> AwsGlueCrawlerHook: + """Returns a new or pre-existing AwsGlueCrawlerHook""" + if self.hook: + return self.hook + + self.hook = AwsGlueCrawlerHook(aws_conn_id=self.aws_conn_id) + return self.hook diff --git a/reference/providers/amazon/aws/sensors/redshift.py b/reference/providers/amazon/aws/sensors/redshift.py new file mode 100644 index 0000000..0abca06 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/redshift.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
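# Rough boto3 sketch (editorial addition, separate from the vendored
# redshift.py below) of the two-step check AwsGlueCrawlerSensor.poke()
# performs; the hook's get_crawler() returns the "Crawler" block of this
# response. Crawler name and region are placeholders, and "LastCrawl" is only
# present once the crawler has run at least once.
import boto3

glue = boto3.client("glue", region_name="us-east-1")
crawler = glue.get_crawler(Name="example-crawler")["Crawler"]

if crawler["State"] == "READY":
    # A crawl has finished; the sensor then inspects how the last run ended.
    done = crawler["LastCrawl"]["Status"] == "SUCCEEDED"  # FAILED/CANCELLED would raise
else:
    # Still RUNNING or STOPPING; the sensor returns False and pokes again.
    done = False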
+from typing import Optional + +from airflow.providers.amazon.aws.hooks.redshift import RedshiftHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class AwsRedshiftClusterSensor(BaseSensorOperator): + """ + Waits for a Redshift cluster to reach a specific status. + + :param cluster_identifier: The identifier for the cluster being pinged. + :type cluster_identifier: str + :param target_status: The cluster status desired. + :type target_status: str + """ + + template_fields = ("cluster_identifier", "target_status") + + @apply_defaults + def __init__( + self, + *, + cluster_identifier: str, + target_status: str = "available", + aws_conn_id: str = "aws_default", + **kwargs, + ): + super().__init__(**kwargs) + self.cluster_identifier = cluster_identifier + self.target_status = target_status + self.aws_conn_id = aws_conn_id + self.hook: Optional[RedshiftHook] = None + + def poke(self, context): + self.log.info( + "Poking for status : %s\nfor cluster %s", + self.target_status, + self.cluster_identifier, + ) + return ( + self.get_hook().cluster_status(self.cluster_identifier) + == self.target_status + ) + + def get_hook(self) -> RedshiftHook: + """Create and return a RedshiftHook""" + if self.hook: + return self.hook + + self.hook = RedshiftHook(aws_conn_id=self.aws_conn_id) + return self.hook diff --git a/reference/providers/amazon/aws/sensors/s3_key.py b/reference/providers/amazon/aws/sensors/s3_key.py new file mode 100644 index 0000000..e5c2b09 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/s3_key.py @@ -0,0 +1,217 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import re +from typing import Callable, List, Optional, Union +from urllib.parse import urlparse + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class S3KeySensor(BaseSensorOperator): + """ + Waits for a key (a file-like instance on S3) to be present in a S3 bucket. + S3 being a key/value it does not support folders. The path is just a key + a resource. + + :param bucket_key: The key being waited on. Supports full s3:// style url + or relative path from root level. When it's specified as a full s3:// + url, please leave bucket_name as `None`. + :type bucket_key: str + :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key`` + is not provided as a full s3:// url. 
+ :type bucket_name: str + :param wildcard_match: whether the bucket_key should be interpreted as a + Unix wildcard pattern + :type wildcard_match: bool + :param aws_conn_id: a reference to the s3 connection + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str + """ + + template_fields = ("bucket_key", "bucket_name") + + @apply_defaults + def __init__( + self, + *, + bucket_key: str, + bucket_name: Optional[str] = None, + wildcard_match: bool = False, + aws_conn_id: str = "aws_default", + verify: Optional[Union[str, bool]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.bucket_name = bucket_name + self.bucket_key = bucket_key + self.wildcard_match = wildcard_match + self.aws_conn_id = aws_conn_id + self.verify = verify + self.hook: Optional[S3Hook] = None + + def poke(self, context): + + if self.bucket_name is None: + parsed_url = urlparse(self.bucket_key) + if parsed_url.netloc == "": + raise AirflowException( + "If key is a relative path from root, please provide a bucket_name" + ) + self.bucket_name = parsed_url.netloc + self.bucket_key = parsed_url.path.lstrip("/") + else: + parsed_url = urlparse(self.bucket_key) + if parsed_url.scheme != "" or parsed_url.netloc != "": + raise AirflowException( + "If bucket_name is provided, bucket_key" + + " should be relative path from root" + + " level, rather than a full s3:// url" + ) + + self.log.info("Poking for key : s3://%s/%s", self.bucket_name, self.bucket_key) + if self.wildcard_match: + return self.get_hook().check_for_wildcard_key( + self.bucket_key, self.bucket_name + ) + return self.get_hook().check_for_key(self.bucket_key, self.bucket_name) + + def get_hook(self) -> S3Hook: + """Create and return an S3Hook""" + if self.hook: + return self.hook + + self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + return self.hook + + +class S3KeySizeSensor(S3KeySensor): + """ + Waits for a key (a file-like instance on S3) to be present and be more than + some size in a S3 bucket. + S3 being a key/value it does not support folders. The path is just a key + a resource. + + :param bucket_key: The key being waited on. Supports full s3:// style url + or relative path from root level. When it's specified as a full s3:// + url, please leave bucket_name as `None`. + :type bucket_key: str + :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key`` + is not provided as a full s3:// url. + :type bucket_name: str + :param wildcard_match: whether the bucket_key should be interpreted as a + Unix wildcard pattern + :type wildcard_match: bool + :param aws_conn_id: a reference to the s3 connection + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. 
+ You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str + :type check_fn: Optional[Callable[..., bool]] + :param check_fn: Function that receives the list of the S3 objects, + and returns the boolean: + - ``True``: a certain criteria is met + - ``False``: the criteria isn't met + **Example**: Wait for any S3 object size more than 1 megabyte :: + + def check_fn(self, data: List) -> bool: + return any(f.get('Size', 0) > 1048576 for f in data if isinstance(f, dict)) + """ + + @apply_defaults + def __init__( + self, + *, + check_fn: Optional[Callable[..., bool]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.check_fn_user = check_fn + + def poke(self, context): + if super().poke(context=context) is False: + return False + + s3_objects = self.get_files(s3_hook=self.get_hook()) + if not s3_objects: + return False + check_fn = self.check_fn if self.check_fn_user is None else self.check_fn_user + return check_fn(s3_objects) + + def get_files(self, s3_hook: S3Hook, delimiter: Optional[str] = "/") -> List: + """Gets a list of files in the bucket""" + prefix = self.bucket_key + config = { + "PageSize": None, + "MaxItems": None, + } + if self.wildcard_match: + prefix = re.split(r"[*]", self.bucket_key, 1)[0] + + paginator = s3_hook.get_conn().get_paginator("list_objects_v2") + response = paginator.paginate( + Bucket=self.bucket_name, + Prefix=prefix, + Delimiter=delimiter, + PaginationConfig=config, + ) + keys = [] + for page in response: + if "Contents" in page: + _temp = [ + k + for k in page["Contents"] + if isinstance(k.get("Size", None), (int, float)) + ] + keys = keys + _temp + return keys + + def check_fn( + self, data: List, object_min_size: Optional[Union[int, float]] = 0 + ) -> bool: + """Default function for checking that S3 Objects have size more than 0 + + :param data: List of the objects in S3 bucket. + :type data: list + :param object_min_size: Checks if the objects sizes are greater then this value. + :type object_min_size: int + """ + return all( + f.get("Size", 0) > object_min_size for f in data if isinstance(f, dict) + ) diff --git a/reference/providers/amazon/aws/sensors/s3_keys_unchanged.py b/reference/providers/amazon/aws/sensors/s3_keys_unchanged.py new file mode 100644 index 0000000..87ca581 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/s3_keys_unchanged.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
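# Hedged usage sketch (editorial addition, separate from the vendored
# s3_keys_unchanged.py below): S3KeySizeSensor with a user-supplied check_fn.
# Bucket, key pattern, task id and the 1 MiB threshold are placeholders; place
# the task inside a `with DAG(...):` block.
from typing import List

from airflow.providers.amazon.aws.sensors.s3_key import S3KeySizeSensor


def at_least_one_megabyte(objects: List) -> bool:
    """Succeed once any matched object is larger than 1 MiB."""
    return any(o.get("Size", 0) > 1_048_576 for o in objects if isinstance(o, dict))


wait_for_large_export = S3KeySizeSensor(
    task_id="wait_for_large_export",
    bucket_name="example-bucket",
    bucket_key="exports/daily_*.csv",
    wildcard_match=True,
    check_fn=at_least_one_megabyte,  # receives the object list built by get_files()
    poke_interval=300,
)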
+ +import os +from datetime import datetime +from typing import Optional, Set, Union + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.sensors.base import BaseSensorOperator, poke_mode_only +from airflow.utils.decorators import apply_defaults + + +@poke_mode_only +class S3KeysUnchangedSensor(BaseSensorOperator): + """ + Checks for changes in the number of objects at prefix in AWS S3 + bucket and returns True if the inactivity period has passed with no + increase in the number of objects. Note, this sensor will not behave correctly + in reschedule mode, as the state of the listed objects in the S3 bucket will + be lost between rescheduled invocations. + + :param bucket_name: Name of the S3 bucket + :type bucket_name: str + :param prefix: The prefix being waited on. Relative path from bucket root level. + :type prefix: str + :param aws_conn_id: a reference to the s3 connection + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: Optional[Union[bool, str]] + :param inactivity_period: The total seconds of inactivity to designate + keys unchanged. Note, this mechanism is not real time and + this operator may not return until a poke_interval after this period + has passed with no additional objects sensed. + :type inactivity_period: float + :param min_objects: The minimum number of objects needed for keys unchanged + sensor to be considered valid. + :type min_objects: int + :param previous_objects: The set of object ids found during the last poke. + :type previous_objects: Optional[Set[str]] + :param allow_delete: Should this sensor consider objects being deleted + between pokes valid behavior. If true a warning message will be logged + when this happens. If false an error will be raised. 
+ :type allow_delete: bool + """ + + template_fields = ("bucket_name", "prefix") + + @apply_defaults + def __init__( + self, + *, + bucket_name: str, + prefix: str, + aws_conn_id: str = "aws_default", + verify: Optional[Union[bool, str]] = None, + inactivity_period: float = 60 * 60, + min_objects: int = 1, + previous_objects: Optional[Set[str]] = None, + allow_delete: bool = True, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + + self.bucket_name = bucket_name + self.prefix = prefix + if inactivity_period < 0: + raise ValueError("inactivity_period must be non-negative") + self.inactivity_period = inactivity_period + self.min_objects = min_objects + self.previous_objects = previous_objects or set() + self.inactivity_seconds = 0 + self.allow_delete = allow_delete + self.aws_conn_id = aws_conn_id + self.verify = verify + self.last_activity_time: Optional[datetime] = None + + @cached_property + def hook(self): + """Returns S3Hook.""" + return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + + def is_keys_unchanged(self, current_objects: Set[str]) -> bool: + """ + Checks whether new objects have been uploaded and the inactivity_period + has passed and updates the state of the sensor accordingly. + + :param current_objects: set of object ids in bucket during last poke. + :type current_objects: set[str] + """ + current_num_objects = len(current_objects) + if current_objects > self.previous_objects: + # When new objects arrived, reset the inactivity_seconds + # and update previous_objects for the next poke. + self.log.info( + "New objects found at %s, resetting last_activity_time.", + os.path.join(self.bucket_name, self.prefix), + ) + self.log.debug("New objects: %s", current_objects - self.previous_objects) + self.last_activity_time = datetime.now() + self.inactivity_seconds = 0 + self.previous_objects = current_objects + return False + + if self.previous_objects - current_objects: + # During the last poke interval objects were deleted. + if self.allow_delete: + deleted_objects = self.previous_objects - current_objects + self.previous_objects = current_objects + self.last_activity_time = datetime.now() + self.log.info( + "Objects were deleted during the last poke interval. Updating the " + "file counter and resetting last_activity_time:\n%s", + deleted_objects, + ) + return False + + raise AirflowException( + "Illegal behavior: objects were deleted in %s between pokes." + % os.path.join(self.bucket_name, self.prefix) + ) + + if self.last_activity_time: + self.inactivity_seconds = int( + (datetime.now() - self.last_activity_time).total_seconds() + ) + else: + # Handles the first poke where last inactivity time is None. 
+ self.last_activity_time = datetime.now() + self.inactivity_seconds = 0 + + if self.inactivity_seconds >= self.inactivity_period: + path = os.path.join(self.bucket_name, self.prefix) + + if current_num_objects >= self.min_objects: + self.log.info( + "SUCCESS: \nSensor found %s objects at %s.\n" + "Waited at least %s seconds, with no new objects uploaded.", + current_num_objects, + path, + self.inactivity_period, + ) + return True + + self.log.error( + "FAILURE: Inactivity Period passed, not enough objects found in %s", + path, + ) + + return False + return False + + def poke(self, context): + return self.is_keys_unchanged( + set(self.hook.list_keys(self.bucket_name, prefix=self.prefix)) + ) diff --git a/reference/providers/amazon/aws/sensors/s3_prefix.py b/reference/providers/amazon/aws/sensors/s3_prefix.py new file mode 100644 index 0000000..841fb0c --- /dev/null +++ b/reference/providers/amazon/aws/sensors/s3_prefix.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional, Union + +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class S3PrefixSensor(BaseSensorOperator): + """ + Waits for a prefix to exist. A prefix is the first part of a key, + thus enabling checking of constructs similar to glob airfl* or + SQL LIKE 'airfl%'. There is the possibility to precise a delimiter to + indicate the hierarchy or keys, meaning that the match will stop at that + delimiter. Current code accepts sane delimiters, i.e. characters that + are NOT special characters in the Python regex engine. + + :param bucket_name: Name of the S3 bucket + :type bucket_name: str + :param prefix: The prefix being waited on. Relative path from bucket root level. + :type prefix: str + :param delimiter: The delimiter intended to show hierarchy. + Defaults to '/'. + :type delimiter: str + :param aws_conn_id: a reference to the s3 connection + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. 
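# A minimal usage sketch for S3KeysUnchangedSensor as defined above, assuming the
# module is importable at its upstream provider path (this repo vendors it under
# reference/providers/...); the DAG id, bucket and prefix are placeholders. The
# sensor succeeds once the object count under ``prefix`` has stopped growing for
# ``inactivity_period`` seconds and at least ``min_objects`` objects exist.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3_keys_unchanged import S3KeysUnchangedSensor

with DAG(
    dag_id="example_s3_keys_unchanged",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    wait_for_stable_prefix = S3KeysUnchangedSensor(
        task_id="wait_for_stable_prefix",
        bucket_name="my-landing-bucket",   # placeholder
        prefix="incoming/daily/",          # placeholder
        inactivity_period=10 * 60,         # treat the prefix as settled after 10 quiet minutes
        min_objects=1,
        poke_interval=60,
        mode="poke",                       # reschedule mode is unsupported for this sensor
    )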
+ :type verify: bool or str + """ + + template_fields = ("prefix", "bucket_name") + + @apply_defaults + def __init__( + self, + *, + bucket_name: str, + prefix: str, + delimiter: str = "/", + aws_conn_id: str = "aws_default", + verify: Optional[Union[str, bool]] = None, + **kwargs, + ): + super().__init__(**kwargs) + # Parse + self.bucket_name = bucket_name + self.prefix = prefix + self.delimiter = delimiter + self.full_url = "s3://" + bucket_name + "/" + prefix + self.aws_conn_id = aws_conn_id + self.verify = verify + self.hook: Optional[S3Hook] = None + + def poke(self, context): + self.log.info( + "Poking for prefix : %s in bucket s3://%s", self.prefix, self.bucket_name + ) + return self.get_hook().check_for_prefix( + prefix=self.prefix, delimiter=self.delimiter, bucket_name=self.bucket_name + ) + + def get_hook(self) -> S3Hook: + """Create and return an S3Hook""" + if self.hook: + return self.hook + + self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + return self.hook diff --git a/reference/providers/amazon/aws/sensors/sagemaker_base.py b/reference/providers/amazon/aws/sensors/sagemaker_base.py new file mode 100644 index 0000000..cd63d04 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/sagemaker_base.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional, Set + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class SageMakerBaseSensor(BaseSensorOperator): + """ + Contains general sensor behavior for SageMaker. + Subclasses should implement get_sagemaker_response() + and state_from_response() methods. + Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods. 
+ """ + + ui_color = "#ededed" + + @apply_defaults + def __init__(self, *, aws_conn_id: str = "aws_default", **kwargs): + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.hook: Optional[SageMakerHook] = None + + def get_hook(self) -> SageMakerHook: + """Get SageMakerHook""" + if self.hook: + return self.hook + + self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id) + return self.hook + + def poke(self, context): + response = self.get_sagemaker_response() + + if not response["ResponseMetadata"]["HTTPStatusCode"] == 200: + self.log.info("Bad HTTP response: %s", response) + return False + + state = self.state_from_response(response) + + self.log.info("Job currently %s", state) + + if state in self.non_terminal_states(): + return False + + if state in self.failed_states(): + failed_reason = self.get_failed_reason_from_response(response) + raise AirflowException( + f"Sagemaker job failed for the following reason: {failed_reason}" + ) + return True + + def non_terminal_states(self) -> Set[str]: + """Placeholder for returning states with should not terminate.""" + raise NotImplementedError("Please implement non_terminal_states() in subclass") + + def failed_states(self) -> Set[str]: + """Placeholder for returning states with are considered failed.""" + raise NotImplementedError("Please implement failed_states() in subclass") + + def get_sagemaker_response(self) -> Optional[dict]: + """Placeholder for checking status of a SageMaker task.""" + raise NotImplementedError( + "Please implement get_sagemaker_response() in subclass" + ) + + def get_failed_reason_from_response( + self, response: dict + ) -> str: # pylint: disable=unused-argument + """Placeholder for extracting the reason for failure from an AWS response.""" + return "Unknown" + + def state_from_response(self, response: dict) -> str: + """Placeholder for extracting the state from an AWS response.""" + raise NotImplementedError("Please implement state_from_response() in subclass") diff --git a/reference/providers/amazon/aws/sensors/sagemaker_endpoint.py b/reference/providers/amazon/aws/sensors/sagemaker_endpoint.py new file mode 100644 index 0000000..45fd9bd --- /dev/null +++ b/reference/providers/amazon/aws/sensors/sagemaker_endpoint.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor +from airflow.utils.decorators import apply_defaults + + +class SageMakerEndpointSensor(SageMakerBaseSensor): + """ + Asks for the state of the endpoint state until it reaches a terminal state. + If it fails the sensor errors, the task fails. 
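# A minimal sketch of the subclassing contract for SageMakerBaseSensor shown above,
# using a hypothetical "processing job" sensor; the describe call, states and
# parameter names are illustrative assumptions, not part of the provider.
from typing import Set

from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults


class SageMakerProcessingSensorSketch(SageMakerBaseSensor):
    """Polls a SageMaker processing job until it leaves its running states."""

    @apply_defaults
    def __init__(self, *, job_name: str, **kwargs):
        super().__init__(**kwargs)
        self.job_name = job_name

    def non_terminal_states(self) -> Set[str]:
        return {"InProgress", "Stopping"}

    def failed_states(self) -> Set[str]:
        return {"Failed"}

    def get_sagemaker_response(self) -> dict:
        # get_hook() comes from SageMakerBaseSensor; get_conn() exposes the boto3 client.
        return self.get_hook().get_conn().describe_processing_job(
            ProcessingJobName=self.job_name
        )

    def state_from_response(self, response: dict) -> str:
        return response["ProcessingJobStatus"]

    def get_failed_reason_from_response(self, response: dict) -> str:
        return response.get("FailureReason", "Unknown")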
+ + :param job_name: job_name of the endpoint instance to check the state of + :type job_name: str + """ + + template_fields = ["endpoint_name"] + template_ext = () + + @apply_defaults + def __init__(self, *, endpoint_name, **kwargs): + super().__init__(**kwargs) + self.endpoint_name = endpoint_name + + def non_terminal_states(self): + return SageMakerHook.endpoint_non_terminal_states + + def failed_states(self): + return SageMakerHook.failed_states + + def get_sagemaker_response(self): + self.log.info("Poking Sagemaker Endpoint %s", self.endpoint_name) + return self.get_hook().describe_endpoint(self.endpoint_name) + + def get_failed_reason_from_response(self, response): + return response["FailureReason"] + + def state_from_response(self, response): + return response["EndpointStatus"] diff --git a/reference/providers/amazon/aws/sensors/sagemaker_training.py b/reference/providers/amazon/aws/sensors/sagemaker_training.py new file mode 100644 index 0000000..6c50d71 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/sagemaker_training.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import time +from typing import Optional + +from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook +from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor +from airflow.utils.decorators import apply_defaults + + +class SageMakerTrainingSensor(SageMakerBaseSensor): + """ + Asks for the state of the training state until it reaches a terminal state. + If it fails the sensor errors, failing the task. 
+ + :param job_name: name of the SageMaker training job to check the state of + :type job_name: str + :param print_log: if the operator should print the cloudwatch log + :type print_log: bool + """ + + template_fields = ["job_name"] + template_ext = () + + @apply_defaults + def __init__(self, *, job_name, print_log=True, **kwargs): + super().__init__(**kwargs) + self.job_name = job_name + self.print_log = print_log + self.positions = {} + self.stream_names = [] + self.instance_count: Optional[int] = None + self.state: Optional[int] = None + self.last_description = None + self.last_describe_job_call = None + self.log_resource_inited = False + + def init_log_resource(self, hook: SageMakerHook) -> None: + """Set tailing LogState for associated training job.""" + description = hook.describe_training_job(self.job_name) + self.instance_count = description["ResourceConfig"]["InstanceCount"] + + status = description["TrainingJobStatus"] + job_already_completed = status not in self.non_terminal_states() + self.state = ( + LogState.TAILING if not job_already_completed else LogState.COMPLETE + ) + self.last_description = description + self.last_describe_job_call = time.monotonic() + self.log_resource_inited = True + + def non_terminal_states(self): + return SageMakerHook.non_terminal_states + + def failed_states(self): + return SageMakerHook.failed_states + + def get_sagemaker_response(self): + if self.print_log: + if not self.log_resource_inited: + self.init_log_resource(self.get_hook()) + ( + self.state, + self.last_description, + self.last_describe_job_call, + ) = self.get_hook().describe_training_job_with_log( + self.job_name, + self.positions, + self.stream_names, + self.instance_count, + self.state, + self.last_description, + self.last_describe_job_call, + ) + else: + self.last_description = self.get_hook().describe_training_job(self.job_name) + + status = self.state_from_response(self.last_description) + if ( + status not in self.non_terminal_states() + and status not in self.failed_states() + ): + billable_time = ( + self.last_description["TrainingEndTime"] + - self.last_description["TrainingStartTime"] + ) * self.last_description["ResourceConfig"]["InstanceCount"] + self.log.info( + "Billable seconds: %s", int(billable_time.total_seconds()) + 1 + ) + + return self.last_description + + def get_failed_reason_from_response(self, response): + return response["FailureReason"] + + def state_from_response(self, response): + return response["TrainingJobStatus"] diff --git a/reference/providers/amazon/aws/sensors/sagemaker_transform.py b/reference/providers/amazon/aws/sensors/sagemaker_transform.py new file mode 100644 index 0000000..af2789c --- /dev/null +++ b/reference/providers/amazon/aws/sensors/sagemaker_transform.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor +from airflow.utils.decorators import apply_defaults + + +class SageMakerTransformSensor(SageMakerBaseSensor): + """ + Asks for the state of the transform state until it reaches a terminal state. + The sensor will error if the job errors, throwing a AirflowException + containing the failure reason. + + :param job_name: job_name of the transform job instance to check the state of + :type job_name: str + """ + + template_fields = ["job_name"] + template_ext = () + + @apply_defaults + def __init__(self, *, job_name: str, **kwargs): + super().__init__(**kwargs) + self.job_name = job_name + + def non_terminal_states(self): + return SageMakerHook.non_terminal_states + + def failed_states(self): + return SageMakerHook.failed_states + + def get_sagemaker_response(self): + self.log.info("Poking Sagemaker Transform Job %s", self.job_name) + return self.get_hook().describe_transform_job(self.job_name) + + def get_failed_reason_from_response(self, response): + return response["FailureReason"] + + def state_from_response(self, response): + return response["TransformJobStatus"] diff --git a/reference/providers/amazon/aws/sensors/sagemaker_tuning.py b/reference/providers/amazon/aws/sensors/sagemaker_tuning.py new file mode 100644 index 0000000..3619b0d --- /dev/null +++ b/reference/providers/amazon/aws/sensors/sagemaker_tuning.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor +from airflow.utils.decorators import apply_defaults + + +class SageMakerTuningSensor(SageMakerBaseSensor): + """ + Asks for the state of the tuning state until it reaches a terminal state. + The sensor will error if the job errors, throwing a AirflowException + containing the failure reason. 
+ + :param job_name: job_name of the tuning instance to check the state of + :type job_name: str + """ + + template_fields = ["job_name"] + template_ext = () + + @apply_defaults + def __init__(self, *, job_name: str, **kwargs): + super().__init__(**kwargs) + self.job_name = job_name + + def non_terminal_states(self): + return SageMakerHook.non_terminal_states + + def failed_states(self): + return SageMakerHook.failed_states + + def get_sagemaker_response(self): + self.log.info("Poking Sagemaker Tuning Job %s", self.job_name) + return self.get_hook().describe_tuning_job(self.job_name) + + def get_failed_reason_from_response(self, response): + return response["FailureReason"] + + def state_from_response(self, response): + return response["HyperParameterTuningJobStatus"] diff --git a/reference/providers/amazon/aws/sensors/sqs.py b/reference/providers/amazon/aws/sensors/sqs.py new file mode 100644 index 0000000..57afd78 --- /dev/null +++ b/reference/providers/amazon/aws/sensors/sqs.py @@ -0,0 +1,111 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Reads and then deletes the message from SQS queue""" +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.sqs import SQSHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class SQSSensor(BaseSensorOperator): + """ + Get messages from an SQS queue and then deletes the message from the SQS queue. + If deletion of messages fails an AirflowException is thrown otherwise, the message + is pushed through XCom with the key ``message``. 
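# A minimal wiring sketch, assuming the upstream provider import path; all four
# SageMaker sensors above follow the same pattern, shown here with the tuning
# sensor. The tuning job name is a placeholder that would normally come from the
# operator that started the job.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.sagemaker_tuning import SageMakerTuningSensor

with DAG(
    dag_id="example_sagemaker_tuning_wait",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    wait_for_tuning = SageMakerTuningSensor(
        task_id="wait_for_tuning",
        job_name="my-hpo-job",      # placeholder tuning job name
        aws_conn_id="aws_default",
        poke_interval=120,
    )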
+ + :param aws_conn_id: AWS connection id + :type aws_conn_id: str + :param sqs_queue: The SQS queue url (templated) + :type sqs_queue: str + :param max_messages: The maximum number of messages to retrieve for each poke (templated) + :type max_messages: int + :param wait_time_seconds: The time in seconds to wait for receiving messages (default: 1 second) + :type wait_time_seconds: int + """ + + template_fields = ("sqs_queue", "max_messages") + + @apply_defaults + def __init__( + self, + *, + sqs_queue, + aws_conn_id: str = "aws_default", + max_messages: int = 5, + wait_time_seconds: int = 1, + **kwargs, + ): + super().__init__(**kwargs) + self.sqs_queue = sqs_queue + self.aws_conn_id = aws_conn_id + self.max_messages = max_messages + self.wait_time_seconds = wait_time_seconds + self.hook: Optional[SQSHook] = None + + def poke(self, context): + """ + Check for message on subscribed queue and write to xcom the message with key ``messages`` + + :param context: the context object + :type context: dict + :return: ``True`` if message is available or ``False`` + """ + sqs_conn = self.get_hook().get_conn() + + self.log.info("SQSSensor checking for message on queue: %s", self.sqs_queue) + + messages = sqs_conn.receive_message( + QueueUrl=self.sqs_queue, + MaxNumberOfMessages=self.max_messages, + WaitTimeSeconds=self.wait_time_seconds, + ) + + self.log.info("received message %s", str(messages)) + + if "Messages" in messages and messages["Messages"]: + entries = [ + {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]} + for message in messages["Messages"] + ] + + result = sqs_conn.delete_message_batch( + QueueUrl=self.sqs_queue, Entries=entries + ) + + if "Successful" in result: + context["ti"].xcom_push(key="messages", value=messages) + return True + else: + raise AirflowException( + "Delete SQS Messages failed " + + str(result) + + " for messages " + + str(messages) + ) + + return False + + def get_hook(self) -> SQSHook: + """Create and return an SQSHook""" + if self.hook: + return self.hook + + self.hook = SQSHook(aws_conn_id=self.aws_conn_id) + return self.hook diff --git a/reference/providers/amazon/aws/sensors/step_function_execution.py b/reference/providers/amazon/aws/sensors/step_function_execution.py new file mode 100644 index 0000000..b7ff7da --- /dev/null +++ b/reference/providers/amazon/aws/sensors/step_function_execution.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
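# A minimal usage sketch for SQSSensor as defined above, assuming the upstream
# provider import path and a placeholder queue URL. The sensor deletes the messages
# it receives and pushes the raw receive_message payload to XCom under the key
# "messages", which a downstream task can pull as shown.
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.sensors.sqs import SQSSensor


def summarize_messages(ti):
    payload = ti.xcom_pull(task_ids="wait_for_sqs_message", key="messages")
    bodies = [m["Body"] for m in payload.get("Messages", [])]
    print(f"Received {len(bodies)} message(s): {bodies}")


with DAG(
    dag_id="example_sqs_sensor",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    wait_for_sqs_message = SQSSensor(
        task_id="wait_for_sqs_message",
        sqs_queue="https://sqs.us-east-1.amazonaws.com/123456789012/my-queue",  # placeholder
        max_messages=1,
        wait_time_seconds=20,
    )
    summarize = PythonOperator(task_id="summarize", python_callable=summarize_messages)
    wait_for_sqs_message >> summarize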
+ +import json +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class StepFunctionExecutionSensor(BaseSensorOperator): + """ + Asks for the state of the Step Function State Machine Execution until it + reaches a failure state or success state. + If it fails, failing the task. + + On successful completion of the Execution the Sensor will do an XCom Push + of the State Machine's output to `output` + + :param execution_arn: execution_arn to check the state of + :type execution_arn: str + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + :type aws_conn_id: str + """ + + INTERMEDIATE_STATES = ("RUNNING",) + FAILURE_STATES = ( + "FAILED", + "TIMED_OUT", + "ABORTED", + ) + SUCCESS_STATES = ("SUCCEEDED",) + + template_fields = ["execution_arn"] + template_ext = () + ui_color = "#66c3ff" + + @apply_defaults + def __init__( + self, + *, + execution_arn: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.execution_arn = execution_arn + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.hook: Optional[StepFunctionHook] = None + + def poke(self, context): + execution_status = self.get_hook().describe_execution(self.execution_arn) + state = execution_status["status"] + output = ( + json.loads(execution_status["output"]) + if "output" in execution_status + else None + ) + + if state in self.FAILURE_STATES: + raise AirflowException( + f"Step Function sensor failed. State Machine Output: {output}" + ) + + if state in self.INTERMEDIATE_STATES: + return False + + self.log.info("Doing xcom_push of output") + self.xcom_push(context, "output", output) + return True + + def get_hook(self) -> StepFunctionHook: + """Create and return a StepFunctionHook""" + if self.hook: + return self.hook + + self.hook = StepFunctionHook( + aws_conn_id=self.aws_conn_id, region_name=self.region_name + ) + return self.hook diff --git a/reference/providers/amazon/aws/transfers/__init__.py b/reference/providers/amazon/aws/transfers/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/amazon/aws/transfers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
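# A minimal usage sketch for StepFunctionExecutionSensor as defined above, assuming
# the module path matches the file name in this repo; the execution ARN is a
# placeholder and would typically be the XCom result of an operator that started
# the execution. On success the sensor pushes the state machine output to XCom
# under the key "output".
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.step_function_execution import (
    StepFunctionExecutionSensor,
)

with DAG(
    dag_id="example_step_function_wait",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    wait_for_execution = StepFunctionExecutionSensor(
        task_id="wait_for_execution",
        execution_arn="{{ ti.xcom_pull(task_ids='start_execution') }}",  # assumed upstream task
        aws_conn_id="aws_default",
        poke_interval=60,
    )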
diff --git a/reference/providers/amazon/aws/transfers/dynamodb_to_s3.py b/reference/providers/amazon/aws/transfers/dynamodb_to_s3.py new file mode 100644 index 0000000..5a45165 --- /dev/null +++ b/reference/providers/amazon/aws/transfers/dynamodb_to_s3.py @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +""" +This module contains operators to replicate records from +DynamoDB table to S3. +""" +import json +from copy import copy +from os.path import getsize +from tempfile import NamedTemporaryFile +from typing import IO, Any, Callable, Dict, Optional +from uuid import uuid4 + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.utils.decorators import apply_defaults + + +def _convert_item_to_json_bytes(item: Dict[str, Any]) -> bytes: + return (json.dumps(item) + "\n").encode("utf-8") + + +def _upload_file_to_s3(file_obj: IO, bucket_name: str, s3_key_prefix: str) -> None: + s3_client = S3Hook().get_conn() + file_obj.seek(0) + s3_client.upload_file( + Filename=file_obj.name, + Bucket=bucket_name, + Key=s3_key_prefix + str(uuid4()), + ) + + +class DynamoDBToS3Operator(BaseOperator): + """ + Replicates records from a DynamoDB table to S3. + It scans a DynamoDB table and write the received records to a file + on the local filesystem. It flushes the file to S3 once the file size + exceeds the file size limit specified by the user. + + Users can also specify a filtering criteria using dynamodb_scan_kwargs + to only replicate records that satisfy the criteria. + + To parallelize the replication, users can create multiple tasks of DynamoDBToS3Operator. + For instance to replicate with parallelism of 2, create two tasks like: + + .. code-block:: python + + op1 = DynamoDBToS3Operator( + task_id='replicator-1', + dynamodb_table_name='hello', + dynamodb_scan_kwargs={ + 'TotalSegments': 2, + 'Segment': 0, + }, + ... + ) + + op2 = DynamoDBToS3Operator( + task_id='replicator-2', + dynamodb_table_name='hello', + dynamodb_scan_kwargs={ + 'TotalSegments': 2, + 'Segment': 1, + }, + ... + ) + + :param dynamodb_table_name: Dynamodb table to replicate data from + :type dynamodb_table_name: str + :param s3_bucket_name: S3 bucket to replicate data to + :type s3_bucket_name: str + :param file_size: Flush file to s3 if file size >= file_size + :type file_size: int + :param dynamodb_scan_kwargs: kwargs pass to # noqa: E501 pylint: disable=line-too-long + :type dynamodb_scan_kwargs: Optional[Dict[str, Any]] + :param s3_key_prefix: Prefix of s3 object key + :type s3_key_prefix: Optional[str] + :param process_func: How we transforms a dynamodb item to bytes. 
By default we dump the json + :type process_func: Callable[[Dict[str, Any]], bytes] + """ + + @apply_defaults + def __init__( + self, + *, + dynamodb_table_name: str, + s3_bucket_name: str, + file_size: int, + dynamodb_scan_kwargs: Optional[Dict[str, Any]] = None, + s3_key_prefix: str = "", + process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.file_size = file_size + self.process_func = process_func + self.dynamodb_table_name = dynamodb_table_name + self.dynamodb_scan_kwargs = dynamodb_scan_kwargs + self.s3_bucket_name = s3_bucket_name + self.s3_key_prefix = s3_key_prefix + + def execute(self, context) -> None: + table = AwsDynamoDBHook().get_conn().Table(self.dynamodb_table_name) + scan_kwargs = ( + copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {} + ) + err = None + f = NamedTemporaryFile() + try: + f = self._scan_dynamodb_and_upload_to_s3(f, scan_kwargs, table) + except Exception as e: + err = e + raise e + finally: + if err is None: + _upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix) + f.close() + + def _scan_dynamodb_and_upload_to_s3( + self, temp_file: IO, scan_kwargs: dict, table: Any + ) -> IO: + while True: + response = table.scan(**scan_kwargs) + items = response["Items"] + for item in items: + temp_file.write(self.process_func(item)) + + if "LastEvaluatedKey" not in response: + # no more items to scan + break + + last_evaluated_key = response["LastEvaluatedKey"] + scan_kwargs["ExclusiveStartKey"] = last_evaluated_key + + # Upload the file to S3 if reach file size limit + if getsize(temp_file.name) >= self.file_size: + _upload_file_to_s3(temp_file, self.s3_bucket_name, self.s3_key_prefix) + temp_file.close() + temp_file = NamedTemporaryFile() + return temp_file diff --git a/reference/providers/amazon/aws/transfers/exasol_to_s3.py b/reference/providers/amazon/aws/transfers/exasol_to_s3.py new file mode 100644 index 0000000..41c09ab --- /dev/null +++ b/reference/providers/amazon/aws/transfers/exasol_to_s3.py @@ -0,0 +1,127 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Transfers data from Exasol database into a S3 Bucket.""" + +from tempfile import NamedTemporaryFile +from typing import Dict, Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.exasol.hooks.exasol import ExasolHook +from airflow.utils.decorators import apply_defaults + + +class ExasolToS3Operator(BaseOperator): + """ + Export data from Exasol database to AWS S3 bucket. 
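# A minimal usage sketch for DynamoDBToS3Operator as defined above, assuming the
# upstream provider import path; the table, bucket and pipe-delimited process_func
# are placeholders. process_func must return bytes for a single DynamoDB item,
# matching the Callable[[Dict[str, Any]], bytes] signature used by the operator.
from datetime import datetime
from typing import Any, Dict

from airflow import DAG
from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator


def item_to_pipe_delimited(item: Dict[str, Any]) -> bytes:
    # Flatten one item into a pipe-delimited line; key order is not guaranteed.
    return ("|".join(str(v) for v in item.values()) + "\n").encode("utf-8")


with DAG(
    dag_id="example_dynamodb_to_s3",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    backup_table = DynamoDBToS3Operator(
        task_id="backup_table",
        dynamodb_table_name="my-table",        # placeholder
        s3_bucket_name="my-backup-bucket",     # placeholder
        s3_key_prefix="dynamodb/my-table/",
        file_size=20 * 1024 * 1024,            # flush to S3 roughly every 20 MB
        process_func=item_to_pipe_delimited,
    )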
+ + :param query_or_table: the sql statement to be executed or table name to export + :type query_or_table: str + :param key: S3 key that will point to the file + :type key: str + :param bucket_name: Name of the bucket in which to store the file + :type bucket_name: str + :param replace: A flag to decide whether or not to overwrite the key + if it already exists. If replace is False and the key exists, an + error will be raised. + :type replace: bool + :param encrypt: If True, the file will be encrypted on the server-side + by S3 and will be stored in an encrypted form while at rest in S3. + :type encrypt: bool + :param gzip: If True, the file will be compressed locally + :type gzip: bool + :param acl_policy: String specifying the canned ACL policy for the file being + uploaded to the S3 bucket. + :type acl_policy: str + :param query_params: Query parameters passed to underlying ``export_to_file`` + method of :class:`~pyexasol.connection.ExaConnection`. + :type query_params: dict + :param export_params: Extra parameters passed to underlying ``export_to_file`` + method of :class:`~pyexasol.connection.ExaConnection`. + :type export_params: dict + """ + + template_fields = ( + "query_or_table", + "key", + "bucket_name", + "query_params", + "export_params", + ) + template_fields_renderers = { + "query_or_table": "sql", + "query_params": "json", + "export_params": "json", + } + template_ext = (".sql",) + ui_color = "#ededed" + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + query_or_table: str, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + gzip: bool = False, + acl_policy: Optional[str] = None, + query_params: Optional[Dict] = None, + export_params: Optional[Dict] = None, + exasol_conn_id: str = "exasol_default", + aws_conn_id: str = "aws_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.query_or_table = query_or_table + self.key = key + self.bucket_name = bucket_name + self.replace = replace + self.encrypt = encrypt + self.gzip = gzip + self.acl_policy = acl_policy + self.query_params = query_params + self.export_params = export_params + self.exasol_conn_id = exasol_conn_id + self.aws_conn_id = aws_conn_id + + def execute(self, context): + exasol_hook = ExasolHook(exasol_conn_id=self.exasol_conn_id) + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + + with NamedTemporaryFile("w+") as file: + exasol_hook.export_to_file( + filename=file.name, + query_or_table=self.query_or_table, + export_params=self.export_params, + query_params=self.query_params, + ) + file.flush() + self.log.info("Uploading the data as %s", self.key) + s3_hook.load_file( + filename=file.name, + key=self.key, + bucket_name=self.bucket_name, + replace=self.replace, + encrypt=self.encrypt, + gzip=self.gzip, + acl_policy=self.acl_policy, + ) + self.log.info("Data uploaded") + return self.key diff --git a/reference/providers/amazon/aws/transfers/ftp_to_s3.py b/reference/providers/amazon/aws/transfers/ftp_to_s3.py new file mode 100644 index 0000000..d8c664e --- /dev/null +++ b/reference/providers/amazon/aws/transfers/ftp_to_s3.py @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
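# A minimal usage sketch for ExasolToS3Operator as defined above, assuming the
# upstream provider import path; the SQL, key, bucket and export_params are
# placeholders, with export_params passed straight through to pyexasol's
# export_to_file.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.exasol_to_s3 import ExasolToS3Operator

with DAG(
    dag_id="example_exasol_to_s3",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    export_orders = ExasolToS3Operator(
        task_id="export_orders",
        query_or_table="SELECT * FROM retail.orders WHERE order_date = '{{ ds }}'",  # placeholder
        key="exports/orders/{{ ds }}.csv",
        bucket_name="my-export-bucket",          # placeholder
        replace=True,
        export_params={"with_column_names": True},  # illustrative pyexasol option
    )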
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tempfile import NamedTemporaryFile + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.ftp.hooks.ftp import FTPHook +from airflow.utils.decorators import apply_defaults + + +class FTPToS3Operator(BaseOperator): + """ + This operator enables the transferring of files from FTP server to S3. + + :param s3_bucket: The targeted s3 bucket in which upload the file to + :type s3_bucket: str + :param s3_key: The targeted s3 key. This is the specified file path for + uploading the file to S3. + :type s3_key: str + :param ftp_path: The ftp remote path, including the file. + :type ftp_path: str + :param ftp_conn_id: The ftp connection id. The name or identifier for + establishing a connection to the FTP server. + :type ftp_conn_id: str + :param aws_conn_id: The s3 connection id. The name or identifier for + establishing a connection to S3 + :type aws_conn_id: str + :param replace: A flag to decide whether or not to overwrite the key + if it already exists. If replace is False and the key exists, an + error will be raised. + :type replace: bool + :param encrypt: If True, the file will be encrypted on the server-side + by S3 and will be stored in an encrypted form while at rest in S3. + :type encrypt: bool + :param gzip: If True, the file will be compressed locally + :type gzip: bool + :param acl_policy: String specifying the canned ACL policy for the file being + uploaded to the S3 bucket. + :type acl_policy: str + """ + + template_fields = ( + "s3_bucket", + "s3_key", + "ftp_path", + ) + + @apply_defaults + def __init__( + self, + s3_bucket, + s3_key, + ftp_path, + ftp_conn_id="ftp_default", + aws_conn_id="aws_default", + replace=False, + encrypt=False, + gzip=False, + acl_policy=None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.ftp_path = ftp_path + self.aws_conn_id = aws_conn_id + self.ftp_conn_id = ftp_conn_id + self.replace = replace + self.encrypt = encrypt + self.gzip = gzip + self.acl_policy = acl_policy + + def execute(self, context): + s3_hook = S3Hook(self.aws_conn_id) + ftp_hook = FTPHook(ftp_conn_id=self.ftp_conn_id) + + with NamedTemporaryFile() as local_tmp_file: + ftp_hook.retrieve_file( + remote_full_path=self.ftp_path, + local_full_path_or_buffer=local_tmp_file.name, + ) + + s3_hook.load_file( + filename=local_tmp_file.name, + key=self.s3_key, + bucket_name=self.s3_bucket, + replace=self.replace, + encrypt=self.encrypt, + gzip=self.gzip, + acl_policy=self.acl_policy, + ) diff --git a/reference/providers/amazon/aws/transfers/gcs_to_s3.py b/reference/providers/amazon/aws/transfers/gcs_to_s3.py new file mode 100644 index 0000000..1885b3f --- /dev/null +++ b/reference/providers/amazon/aws/transfers/gcs_to_s3.py @@ -0,0 +1,198 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Cloud Storage to S3 operator.""" +import warnings +from typing import Dict, Iterable, List, Optional, Sequence, Union, cast + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults + + +class GCSToS3Operator(BaseOperator): + """ + Synchronizes a Google Cloud Storage bucket with an S3 bucket. + + :param bucket: The Google Cloud Storage bucket to find the objects. (templated) + :type bucket: str + :param prefix: Prefix string which filters objects whose name begin with + this prefix. (templated) + :type prefix: str + :param delimiter: The delimiter by which you want to filter the objects. (templated) + For e.g to lists the CSV files from in a directory in GCS you would use + delimiter='.csv'. + :type delimiter: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: Google account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param dest_aws_conn_id: The destination S3 connection + :type dest_aws_conn_id: str + :param dest_s3_key: The base S3 key to be used to store the files. (templated) + :type dest_s3_key: str + :param dest_verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + + :type dest_verify: bool or str + :param replace: Whether or not to verify the existence of the files in the + destination bucket. + By default is set to False + If set to True, will upload all the files replacing the existing ones in + the destination bucket. + If set to False, will upload only the files that are in the origin but not + in the destination bucket. + :type replace: bool + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + :param s3_acl_policy: Optional The string to specify the canned ACL policy for the + object to be uploaded in S3 + :type s3_acl_policy: str + """ + + template_fields: Iterable[str] = ( + "bucket", + "prefix", + "delimiter", + "dest_s3_key", + "google_impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, # pylint: disable=too-many-arguments + bucket: str, + prefix: Optional[str] = None, + delimiter: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + dest_aws_conn_id: str = "aws_default", + dest_s3_key: str, + dest_verify: Optional[Union[str, bool]] = None, + replace: bool = False, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + dest_s3_extra_args: Optional[Dict] = None, + s3_acl_policy: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + self.bucket = bucket + self.prefix = prefix + self.delimiter = delimiter + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.dest_aws_conn_id = dest_aws_conn_id + self.dest_s3_key = dest_s3_key + self.dest_verify = dest_verify + self.replace = replace + self.google_impersonation_chain = google_impersonation_chain + self.dest_s3_extra_args = dest_s3_extra_args or {} + self.s3_acl_policy = s3_acl_policy + + def execute(self, context) -> List[str]: + # list all files in an Google Cloud Storage bucket + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.google_impersonation_chain, + ) + + self.log.info( + "Getting list of the files. 
Bucket: %s; Delimiter: %s; Prefix: %s", + self.bucket, + self.delimiter, + self.prefix, + ) + + files = hook.list( + bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter + ) + + s3_hook = S3Hook( + aws_conn_id=self.dest_aws_conn_id, + verify=self.dest_verify, + extra_args=self.dest_s3_extra_args, + ) + + if not self.replace: + # if we are not replacing -> list all files in the S3 bucket + # and only keep those files which are present in + # Google Cloud Storage and not in S3 + bucket_name, prefix = S3Hook.parse_s3_url(self.dest_s3_key) + # look for the bucket and the prefix to avoid look into + # parent directories/keys + existing_files = s3_hook.list_keys(bucket_name, prefix=prefix) + # in case that no files exists, return an empty array to avoid errors + existing_files = existing_files if existing_files is not None else [] + # remove the prefix for the existing files to allow the match + existing_files = [file.replace(prefix, "", 1) for file in existing_files] + files = list(set(files) - set(existing_files)) + + if files: + + for file in files: + file_bytes = hook.download(object_name=file, bucket_name=self.bucket) + + dest_key = self.dest_s3_key + file + self.log.info("Saving file to %s", dest_key) + + s3_hook.load_bytes( + cast(bytes, file_bytes), + key=dest_key, + replace=self.replace, + acl_policy=self.s3_acl_policy, + ) + + self.log.info("All done, uploaded %d files to S3", len(files)) + else: + self.log.info("In sync, no files needed to be uploaded to S3") + + return files diff --git a/reference/providers/amazon/aws/transfers/glacier_to_gcs.py b/reference/providers/amazon/aws/transfers/glacier_to_gcs.py new file mode 100644 index 0000000..f936666 --- /dev/null +++ b/reference/providers/amazon/aws/transfers/glacier_to_gcs.py @@ -0,0 +1,122 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tempfile +from typing import Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.glacier import GlacierHook +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults + + +class GlacierToGCSOperator(BaseOperator): + """ + Transfers data from Amazon Glacier to Google Cloud Storage + + .. note:: + Please be warn that GlacierToGCSOperator may depends on memory usage. + Transferring big files may not working well. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GlacierToGCSOperator` + + :param aws_conn_id: The reference to the AWS connection details + :type aws_conn_id: str + :param gcp_conn_id: The reference to the GCP connection details + :type gcp_conn_id: str + :param vault_name: the Glacier vault on which job is executed + :type vault_name: string + :param bucket_name: the Google Cloud Storage bucket where the data will be transferred + :type bucket_name: str + :param object_name: the name of the object to check in the Google cloud + storage bucket. + :type object_name: str + :param gzip: option to compress local file or file data for upload + :type gzip: bool + :param chunk_size: size of chunk in bytes the that will downloaded from Glacier vault + :type chunk_size: int + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("vault_name", "bucket_name", "object_name") + + @apply_defaults + def __init__( + self, + *, + aws_conn_id: str = "aws_default", + gcp_conn_id: str = "google_cloud_default", + vault_name: str, + bucket_name: str, + object_name: str, + gzip: bool, + chunk_size: int = 1024, + delegate_to: Optional[str] = None, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.gcp_conn_id = gcp_conn_id + self.vault_name = vault_name + self.bucket_name = bucket_name + self.object_name = object_name + self.gzip = gzip + self.chunk_size = chunk_size + self.delegate_to = delegate_to + self.impersonation_chain = google_impersonation_chain + + def execute(self, context) -> str: + glacier_hook = GlacierHook(aws_conn_id=self.aws_conn_id) + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + job_id = glacier_hook.retrieve_inventory(vault_name=self.vault_name) + + with tempfile.NamedTemporaryFile() as temp_file: + glacier_data = glacier_hook.retrieve_inventory_results( + vault_name=self.vault_name, job_id=job_id["jobId"] + ) + # Read the file content in chunks using StreamingBody + # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/response.html + stream = glacier_data["body"] + for chunk in stream.iter_chunk(chunk_size=self.chunk_size): + temp_file.write(chunk) + temp_file.flush() + gcs_hook.upload( + bucket_name=self.bucket_name, + object_name=self.object_name, + filename=temp_file.name, + gzip=self.gzip, + ) + return f"gs://{self.bucket_name}/{self.object_name}" diff --git a/reference/providers/amazon/aws/transfers/google_api_to_s3.py 
b/reference/providers/amazon/aws/transfers/google_api_to_s3.py new file mode 100644 index 0000000..db98943 --- /dev/null +++ b/reference/providers/amazon/aws/transfers/google_api_to_s3.py @@ -0,0 +1,206 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module allows you to transfer data from any Google API endpoint into a S3 Bucket.""" +import json +import sys +from typing import Optional, Sequence, Union + +from airflow.models import BaseOperator, TaskInstance +from airflow.models.xcom import MAX_XCOM_SIZE, XCOM_RETURN_KEY +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryApiHook +from airflow.utils.decorators import apply_defaults + + +class GoogleApiToS3Operator(BaseOperator): + """ + Basic class for transferring data from a Google API endpoint into a S3 Bucket. + + This discovery-based operator use + :class:`~airflow.providers.google.common.hooks.discovery_api.GoogleDiscoveryApiHook` to communicate + with Google Services via the + `Google API Python Client `__. + Please note that this library is in maintenance mode hence it won't fully support Google Cloud in + the future. + Therefore it is recommended that you use the custom Google Cloud Service Operators for working + with the Google Cloud Platform. + + :param google_api_service_name: The specific API service that is being requested. + :type google_api_service_name: str + :param google_api_service_version: The version of the API that is being requested. + :type google_api_service_version: str + :param google_api_endpoint_path: The client libraries path to the api call's executing method. + For example: 'analyticsreporting.reports.batchGet' + + .. note:: See https://developers.google.com/apis-explorer + for more information on which methods are available. + + :type google_api_endpoint_path: str + :param google_api_endpoint_params: The params to control the corresponding endpoint result. + :type google_api_endpoint_params: dict + :param s3_destination_key: The url where to put the data retrieved from the endpoint in S3. + :type s3_destination_key: str + :param google_api_response_via_xcom: Can be set to expose the google api response to xcom. + :type google_api_response_via_xcom: str + :param google_api_endpoint_params_via_xcom: If set to a value this value will be used as a key + for pulling from xcom and updating the google api endpoint params. + :type google_api_endpoint_params_via_xcom: str + :param google_api_endpoint_params_via_xcom_task_ids: Task ids to filter xcom by. + :type google_api_endpoint_params_via_xcom_task_ids: str or list of str + :param google_api_pagination: If set to True Pagination will be enabled for this request + to retrieve all data. + + .. 
note:: This means the response will be a list of responses. + + :type google_api_pagination: bool + :param google_api_num_retries: Define the number of retries for the google api requests being made + if it fails. + :type google_api_num_retries: int + :param s3_overwrite: Specifies whether the s3 file will be overwritten if exists. + :type s3_overwrite: bool + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: Google account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param aws_conn_id: The connection id specifying the authentication information for the S3 Bucket. + :type aws_conn_id: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "google_api_endpoint_params", + "s3_destination_key", + "google_impersonation_chain", + ) + template_ext = () + ui_color = "#cc181e" + + @apply_defaults + def __init__( + self, + *, + google_api_service_name: str, + google_api_service_version: str, + google_api_endpoint_path: str, + google_api_endpoint_params: dict, + s3_destination_key: str, + google_api_response_via_xcom: Optional[str] = None, + google_api_endpoint_params_via_xcom: Optional[str] = None, + google_api_endpoint_params_via_xcom_task_ids: Optional[str] = None, + google_api_pagination: bool = False, + google_api_num_retries: int = 0, + s3_overwrite: bool = False, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + aws_conn_id: str = "aws_default", + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.google_api_service_name = google_api_service_name + self.google_api_service_version = google_api_service_version + self.google_api_endpoint_path = google_api_endpoint_path + self.google_api_endpoint_params = google_api_endpoint_params + self.s3_destination_key = s3_destination_key + self.google_api_response_via_xcom = google_api_response_via_xcom + self.google_api_endpoint_params_via_xcom = google_api_endpoint_params_via_xcom + self.google_api_endpoint_params_via_xcom_task_ids = ( + google_api_endpoint_params_via_xcom_task_ids + ) + self.google_api_pagination = google_api_pagination + self.google_api_num_retries = google_api_num_retries + self.s3_overwrite = s3_overwrite + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.aws_conn_id = aws_conn_id + self.google_impersonation_chain = google_impersonation_chain + + def execute(self, context) -> None: + """ + Transfers Google APIs json data to S3. + + :param context: The context that is being provided when executing. 
+ :type context: dict + """ + self.log.info("Transferring data from %s to s3", self.google_api_service_name) + + if self.google_api_endpoint_params_via_xcom: + self._update_google_api_endpoint_params_via_xcom(context["task_instance"]) + + data = self._retrieve_data_from_google_api() + + self._load_data_to_s3(data) + + if self.google_api_response_via_xcom: + self._expose_google_api_response_via_xcom(context["task_instance"], data) + + def _retrieve_data_from_google_api(self) -> dict: + google_discovery_api_hook = GoogleDiscoveryApiHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_service_name=self.google_api_service_name, + api_version=self.google_api_service_version, + impersonation_chain=self.google_impersonation_chain, + ) + google_api_response = google_discovery_api_hook.query( + endpoint=self.google_api_endpoint_path, + data=self.google_api_endpoint_params, + paginate=self.google_api_pagination, + num_retries=self.google_api_num_retries, + ) + return google_api_response + + def _load_data_to_s3(self, data: dict) -> None: + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + s3_hook.load_string( + string_data=json.dumps(data), + key=self.s3_destination_key, + replace=self.s3_overwrite, + ) + + def _update_google_api_endpoint_params_via_xcom( + self, task_instance: TaskInstance + ) -> None: + google_api_endpoint_params = task_instance.xcom_pull( + task_ids=self.google_api_endpoint_params_via_xcom_task_ids, + key=self.google_api_endpoint_params_via_xcom, + ) + self.google_api_endpoint_params.update(google_api_endpoint_params) + + def _expose_google_api_response_via_xcom( + self, task_instance: TaskInstance, data: dict + ) -> None: + if sys.getsizeof(data) < MAX_XCOM_SIZE: + task_instance.xcom_push( + key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data + ) + else: + raise RuntimeError( + "The size of the downloaded data is too large to push to XCom!" + ) diff --git a/reference/providers/amazon/aws/transfers/hive_to_dynamodb.py b/reference/providers/amazon/aws/transfers/hive_to_dynamodb.py new file mode 100644 index 0000000..c79afd0 --- /dev/null +++ b/reference/providers/amazon/aws/transfers/hive_to_dynamodb.py @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
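A minimal usage sketch for the GoogleApiToS3Operator defined above (not part of the commit itself), assuming the default "google_cloud_default" and "aws_default" connections exist and that the spreadsheet id, bucket and key are placeholders:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.transfers.google_api_to_s3 import GoogleApiToS3Operator

    with DAG(
        dag_id="example_google_api_to_s3",
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        # Pull a Sheets range through the discovery API and write the JSON response to S3.
        sheet_to_s3 = GoogleApiToS3Operator(
            task_id="google_sheet_to_s3",
            google_api_service_name="sheets",
            google_api_service_version="v4",
            google_api_endpoint_path="sheets.spreadsheets.values.get",
            google_api_endpoint_params={"spreadsheetId": "my-spreadsheet-id", "range": "Sheet1"},
            s3_destination_key="s3://my-bucket/google/sheet1.json",
            s3_overwrite=True,
        )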
+ +"""This module contains operator to move data from Hive to DynamoDB.""" + +import json +from typing import Callable, Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook +from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook +from airflow.utils.decorators import apply_defaults + + +class HiveToDynamoDBOperator(BaseOperator): + """ + Moves data from Hive to DynamoDB, note that for now the data is loaded + into memory before being pushed to DynamoDB, so this operator should + be used for smallish amount of data. + + :param sql: SQL query to execute against the hive database. (templated) + :type sql: str + :param table_name: target DynamoDB table + :type table_name: str + :param table_keys: partition key and sort key + :type table_keys: list + :param pre_process: implement pre-processing of source data + :type pre_process: function + :param pre_process_args: list of pre_process function arguments + :type pre_process_args: list + :param pre_process_kwargs: dict of pre_process function arguments + :type pre_process_kwargs: dict + :param region_name: aws region name (example: us-east-1) + :type region_name: str + :param schema: hive database schema + :type schema: str + :param hiveserver2_conn_id: source hive connection + :type hiveserver2_conn_id: str + :param aws_conn_id: aws connection + :type aws_conn_id: str + """ + + template_fields = ("sql",) + template_ext = (".sql",) + ui_color = "#a0e08c" + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + sql: str, + table_name: str, + table_keys: list, + pre_process: Optional[Callable] = None, + pre_process_args: Optional[list] = None, + pre_process_kwargs: Optional[list] = None, + region_name: Optional[str] = None, + schema: str = "default", + hiveserver2_conn_id: str = "hiveserver2_default", + aws_conn_id: str = "aws_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sql = sql + self.table_name = table_name + self.table_keys = table_keys + self.pre_process = pre_process + self.pre_process_args = pre_process_args + self.pre_process_kwargs = pre_process_kwargs + self.region_name = region_name + self.schema = schema + self.hiveserver2_conn_id = hiveserver2_conn_id + self.aws_conn_id = aws_conn_id + + def execute(self, context): + hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id) + + self.log.info("Extracting data from Hive") + self.log.info(self.sql) + + data = hive.get_pandas_df(self.sql, schema=self.schema) + dynamodb = AwsDynamoDBHook( + aws_conn_id=self.aws_conn_id, + table_name=self.table_name, + table_keys=self.table_keys, + region_name=self.region_name, + ) + + self.log.info("Inserting rows into dynamodb") + + if self.pre_process is None: + dynamodb.write_batch_data(json.loads(data.to_json(orient="records"))) + else: + dynamodb.write_batch_data( + self.pre_process( + data=data, + args=self.pre_process_args, + kwargs=self.pre_process_kwargs, + ) + ) + + self.log.info("Done.") diff --git a/reference/providers/amazon/aws/transfers/imap_attachment_to_s3.py b/reference/providers/amazon/aws/transfers/imap_attachment_to_s3.py new file mode 100644 index 0000000..f59c86b --- /dev/null +++ b/reference/providers/amazon/aws/transfers/imap_attachment_to_s3.py @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module allows you to transfer mail attachments from a mail server into s3 bucket.""" +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.imap.hooks.imap import ImapHook +from airflow.utils.decorators import apply_defaults + + +class ImapAttachmentToS3Operator(BaseOperator): + """ + Transfers a mail attachment from a mail server into s3 bucket. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ImapAttachmentToS3Operator` + + :param imap_attachment_name: The file name of the mail attachment that you want to transfer. + :type imap_attachment_name: str + :param s3_key: The destination file name in the s3 bucket for the attachment. + :type s3_key: str + :param imap_check_regex: If set checks the `imap_attachment_name` for a regular expression. + :type imap_check_regex: bool + :param imap_mail_folder: The folder on the mail server to look for the attachment. + :type imap_mail_folder: str + :param imap_mail_filter: If set other than 'All' only specific mails will be checked. + See :py:meth:`imaplib.IMAP4.search` for details. + :type imap_mail_filter: str + :param s3_overwrite: If set overwrites the s3 key if already exists. + :type s3_overwrite: bool + :param imap_conn_id: The reference to the connection details of the mail server. + :type imap_conn_id: str + :param s3_conn_id: The reference to the s3 connection details. + :type s3_conn_id: str + """ + + template_fields = ("imap_attachment_name", "s3_key", "imap_mail_filter") + + @apply_defaults + def __init__( + self, + *, + imap_attachment_name: str, + s3_key: str, + imap_check_regex: bool = False, + imap_mail_folder: str = "INBOX", + imap_mail_filter: str = "All", + s3_overwrite: bool = False, + imap_conn_id: str = "imap_default", + s3_conn_id: str = "aws_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.imap_attachment_name = imap_attachment_name + self.s3_key = s3_key + self.imap_check_regex = imap_check_regex + self.imap_mail_folder = imap_mail_folder + self.imap_mail_filter = imap_mail_filter + self.s3_overwrite = s3_overwrite + self.imap_conn_id = imap_conn_id + self.s3_conn_id = s3_conn_id + + def execute(self, context) -> None: + """ + This function executes the transfer from the email server (via imap) into s3. + + :param context: The context while executing. 
+ :type context: dict + """ + self.log.info( + "Transferring mail attachment %s from mail server via imap to s3 key %s...", + self.imap_attachment_name, + self.s3_key, + ) + + with ImapHook(imap_conn_id=self.imap_conn_id) as imap_hook: + imap_mail_attachments = imap_hook.retrieve_mail_attachments( + name=self.imap_attachment_name, + check_regex=self.imap_check_regex, + latest_only=True, + mail_folder=self.imap_mail_folder, + mail_filter=self.imap_mail_filter, + ) + + s3_hook = S3Hook(aws_conn_id=self.s3_conn_id) + s3_hook.load_bytes( + bytes_data=imap_mail_attachments[0][1], + key=self.s3_key, + replace=self.s3_overwrite, + ) diff --git a/reference/providers/amazon/aws/transfers/mongo_to_s3.py b/reference/providers/amazon/aws/transfers/mongo_to_s3.py new file mode 100644 index 0000000..26d6410 --- /dev/null +++ b/reference/providers/amazon/aws/transfers/mongo_to_s3.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import warnings +from typing import Any, Iterable, Optional, Union, cast + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.mongo.hooks.mongo import MongoHook +from airflow.utils.decorators import apply_defaults +from bson import json_util + +_DEPRECATION_MSG = "The s3_conn_id parameter has been deprecated. You should pass instead the aws_conn_id parameter." + + +class MongoToS3Operator(BaseOperator): + """Operator meant to move data from mongo via pymongo to s3 via boto. + + :param mongo_conn_id: reference to a specific mongo connection + :type mongo_conn_id: str + :param aws_conn_id: reference to a specific S3 connection + :type aws_conn_id: str + :param mongo_collection: reference to a specific collection in your mongo db + :type mongo_collection: str + :param mongo_query: query to execute. A list including a dict of the query + :type mongo_query: list + :param s3_bucket: reference to a specific S3 bucket to store the data + :type s3_bucket: str + :param s3_key: in which S3 key the file will be stored + :type s3_key: str + :param mongo_db: reference to a specific mongo database + :type mongo_db: str + :param replace: whether or not to replace the file in S3 if it previously existed + :type replace: bool + :param allow_disk_use: in the case you are retrieving a lot of data, you may have + to use the disk to save it instead of saving all in the RAM + :type allow_disk_use: bool + :param compression: type of compression to use for output file in S3. Currently only gzip is supported. 
+ :type compression: str + """ + + template_fields = ("s3_bucket", "s3_key", "mongo_query", "mongo_collection") + ui_color = "#589636" + # pylint: disable=too-many-instance-attributes + + @apply_defaults + def __init__( + self, + *, + s3_conn_id: Optional[str] = None, + mongo_conn_id: str = "mongo_default", + aws_conn_id: str = "aws_default", + mongo_collection: str, + mongo_query: Union[list, dict], + s3_bucket: str, + s3_key: str, + mongo_db: Optional[str] = None, + replace: bool = False, + allow_disk_use: bool = False, + compression: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if s3_conn_id: + warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3) + aws_conn_id = s3_conn_id + + self.mongo_conn_id = mongo_conn_id + self.aws_conn_id = aws_conn_id + self.mongo_db = mongo_db + self.mongo_collection = mongo_collection + + # Grab query and determine if we need to run an aggregate pipeline + self.mongo_query = mongo_query + self.is_pipeline = isinstance(self.mongo_query, list) + + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.replace = replace + self.allow_disk_use = allow_disk_use + self.compression = compression + + def execute(self, context) -> bool: + """Is written to depend on transform method""" + s3_conn = S3Hook(self.aws_conn_id) + + # Grab collection and execute query according to whether or not it is a pipeline + if self.is_pipeline: + results = MongoHook(self.mongo_conn_id).aggregate( + mongo_collection=self.mongo_collection, + aggregate_query=cast(list, self.mongo_query), + mongo_db=self.mongo_db, + allowDiskUse=self.allow_disk_use, + ) + + else: + results = MongoHook(self.mongo_conn_id).find( + mongo_collection=self.mongo_collection, + query=cast(dict, self.mongo_query), + mongo_db=self.mongo_db, + allowDiskUse=self.allow_disk_use, + ) + + # Performs transform then stringifies the docs results into json format + docs_str = self._stringify(self.transform(results)) + + s3_conn.load_string( + string_data=docs_str, + key=self.s3_key, + bucket_name=self.s3_bucket, + replace=self.replace, + compression=self.compression, + ) + + @staticmethod + def _stringify(iterable: Iterable, joinable: str = "\n") -> str: + """ + Takes an iterable (pymongo Cursor or Array) containing dictionaries and + returns a stringified version using python join + """ + return joinable.join( + [json.dumps(doc, default=json_util.default) for doc in iterable] + ) + + @staticmethod + def transform(docs: Any) -> Any: + """This method is meant to be extended by child classes + to perform transformations unique to those operators needs. + Processes pyMongo cursor and returns an iterable with each element being + a JSON serializable dictionary + + Base transform() assumes no processing is needed + ie. docs is a pyMongo cursor of documents and cursor just + needs to be passed through + + Override this method for custom transformations + """ + return docs diff --git a/reference/providers/amazon/aws/transfers/mysql_to_s3.py b/reference/providers/amazon/aws/transfers/mysql_to_s3.py new file mode 100644 index 0000000..0b63dba --- /dev/null +++ b/reference/providers/amazon/aws/transfers/mysql_to_s3.py @@ -0,0 +1,131 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
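Because MongoToS3Operator.transform() above is explicitly meant to be overridden, here is a small sketch (not part of the commit itself) of a subclass that makes the documents JSON-friendly and runs an aggregation pipeline; the collection, database, bucket and key are placeholders:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.transfers.mongo_to_s3 import MongoToS3Operator


    class StringifyIdMongoToS3Operator(MongoToS3Operator):
        """transform() receives the pymongo cursor and must return an iterable
        of JSON-serializable dicts; here ObjectId is converted to a string."""

        @staticmethod
        def transform(docs):
            return ({**doc, "_id": str(doc["_id"])} for doc in docs)


    with DAG(
        dag_id="example_mongo_to_s3",
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        mongo_to_s3 = StringifyIdMongoToS3Operator(
            task_id="mongo_to_s3",
            mongo_collection="orders",
            mongo_query=[{"$match": {"status": "shipped"}}],  # a list triggers the aggregate path
            mongo_db="shop",
            s3_bucket="my-bucket",
            s3_key="mongo/orders.ndjson",
            replace=True,
            allow_disk_use=True,
        )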
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from tempfile import NamedTemporaryFile +from typing import Optional, Union + +import numpy as np +import pandas as pd +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.utils.decorators import apply_defaults + + +class MySQLToS3Operator(BaseOperator): + """ + Saves data from an specific MySQL query into a file in S3. + + :param query: the sql query to be executed. If you want to execute a file, place the absolute path of it, + ending with .sql extension. (templated) + :type query: str + :param s3_bucket: bucket where the data will be stored. (templated) + :type s3_bucket: str + :param s3_key: desired key for the file. It includes the name of the file. (templated) + :type s3_key: str + :param mysql_conn_id: reference to a specific mysql database + :type mysql_conn_id: str + :param aws_conn_id: reference to a specific S3 connection + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str + :param pd_csv_kwargs: arguments to include in pd.to_csv (header, index, columns...) 
+ :type pd_csv_kwargs: dict + :param index: whether to have the index or not in the dataframe + :type index: str + :param header: whether to include header or not into the S3 file + :type header: bool + """ + + template_fields = ( + "s3_bucket", + "s3_key", + "query", + ) + template_ext = (".sql",) + + @apply_defaults + def __init__( + self, + *, + query: str, + s3_bucket: str, + s3_key: str, + mysql_conn_id: str = "mysql_default", + aws_conn_id: str = "aws_default", + verify: Optional[Union[bool, str]] = None, + pd_csv_kwargs: Optional[dict] = None, + index: bool = False, + header: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.query = query + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.mysql_conn_id = mysql_conn_id + self.aws_conn_id = aws_conn_id + self.verify = verify + + self.pd_csv_kwargs = pd_csv_kwargs or {} + if "path_or_buf" in self.pd_csv_kwargs: + raise AirflowException( + "The argument path_or_buf is not allowed, please remove it" + ) + if "index" not in self.pd_csv_kwargs: + self.pd_csv_kwargs["index"] = index + if "header" not in self.pd_csv_kwargs: + self.pd_csv_kwargs["header"] = header + + def _fix_int_dtypes(self, df: pd.DataFrame) -> None: + """Mutate DataFrame to set dtypes for int columns containing NaN values.""" + for col in df: + if "float" in df[col].dtype.name and df[col].hasnans: + # inspect values to determine if dtype of non-null values is int or float + notna_series = df[col].dropna().values + if np.isclose(notna_series, notna_series.astype(int)).all(): + # set to dtype that retains integers and supports NaNs + df[col] = np.where(df[col].isnull(), None, df[col]) + df[col] = df[col].astype(pd.Int64Dtype()) + + def execute(self, context) -> None: + mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id) + s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + data_df = mysql_hook.get_pandas_df(self.query) + self.log.info("Data from MySQL obtained") + + self._fix_int_dtypes(data_df) + with NamedTemporaryFile(mode="r+", suffix=".csv") as tmp_csv: + data_df.to_csv(tmp_csv.name, **self.pd_csv_kwargs) + s3_conn.load_file( + filename=tmp_csv.name, key=self.s3_key, bucket_name=self.s3_bucket + ) + + if s3_conn.check_for_key(self.s3_key, bucket_name=self.s3_bucket): + file_location = os.path.join(self.s3_bucket, self.s3_key) + self.log.info("File saved correctly in %s", file_location) diff --git a/reference/providers/amazon/aws/transfers/redshift_to_s3.py b/reference/providers/amazon/aws/transfers/redshift_to_s3.py new file mode 100644 index 0000000..7269122 --- /dev/null +++ b/reference/providers/amazon/aws/transfers/redshift_to_s3.py @@ -0,0 +1,156 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
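A minimal usage sketch for the MySQLToS3Operator defined above (not part of the commit itself), assuming the default "mysql_default" and "aws_default" connections and placeholder table, bucket and key; pd_csv_kwargs is forwarded unchanged to pandas.DataFrame.to_csv:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.transfers.mysql_to_s3 import MySQLToS3Operator

    with DAG(
        dag_id="example_mysql_to_s3",
        start_date=datetime(2021, 1, 1),
        schedule_interval="@daily",
        catchup=False,
    ) as dag:
        orders_to_s3 = MySQLToS3Operator(
            task_id="orders_to_s3",
            query="SELECT * FROM orders WHERE order_date = '{{ ds }}'",
            s3_bucket="my-bucket",
            s3_key="mysql/orders/{{ ds }}.csv",
            pd_csv_kwargs={"sep": "|", "header": True, "index": False},
        )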
+"""Transfers data from AWS Redshift into a S3 Bucket.""" +from typing import List, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.utils.redshift import build_credentials_block +from airflow.providers.postgres.hooks.postgres import PostgresHook +from airflow.utils.decorators import apply_defaults + + +class RedshiftToS3Operator(BaseOperator): + """ + Executes an UNLOAD command to s3 as a CSV with headers + + :param s3_bucket: reference to a specific S3 bucket + :type s3_bucket: str + :param s3_key: reference to a specific S3 key. If ``table_as_file_name`` is set + to False, this param must include the desired file name + :type s3_key: str + :param schema: reference to a specific schema in redshift database + Applicable when ``table`` param provided. + :type schema: str + :param table: reference to a specific table in redshift database + Used when ``select_query`` param not provided. + :type table: str + :param select_query: custom select query to fetch data from redshift database + :type select_query: str + :param redshift_conn_id: reference to a specific redshift database + :type redshift_conn_id: str + :param aws_conn_id: reference to a specific S3 connection + If the AWS connection contains 'aws_iam_role' in ``extras`` + the operator will use AWS STS credentials with a token + https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str + :param unload_options: reference to a list of UNLOAD options + :type unload_options: list + :param autocommit: If set to True it will automatically commit the UNLOAD statement. + Otherwise it will be committed right before the redshift connection gets closed. + :type autocommit: bool + :param include_header: If set to True the s3 file contains the header columns. + :type include_header: bool + :param table_as_file_name: If set to True, the s3 file will be named as the table. + Applicable when ``table`` param provided. 
+ :type table_as_file_name: bool + """ + + template_fields = ("s3_bucket", "s3_key", "schema", "table", "unload_options") + template_ext = () + ui_color = "#ededed" + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + s3_bucket: str, + s3_key: str, + schema: str = None, + table: str = None, + select_query: str = None, + redshift_conn_id: str = "redshift_default", + aws_conn_id: str = "aws_default", + verify: Optional[Union[bool, str]] = None, + unload_options: Optional[List] = None, + autocommit: bool = False, + include_header: bool = False, + table_as_file_name: bool = True, # Set to True by default for not breaking current workflows + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.s3_bucket = s3_bucket + self.s3_key = f"{s3_key}/{table}_" if (table and table_as_file_name) else s3_key + self.schema = schema + self.table = table + self.redshift_conn_id = redshift_conn_id + self.aws_conn_id = aws_conn_id + self.verify = verify + self.unload_options = unload_options or [] # type: List + self.autocommit = autocommit + self.include_header = include_header + self.table_as_file_name = table_as_file_name + + self._select_query = None + if select_query: + self._select_query = select_query + elif self.schema and self.table: + self._select_query = f"SELECT * FROM {self.schema}.{self.table}" + else: + raise ValueError( + "Please provide both `schema` and `table` params or `select_query` to fetch the data." + ) + + if self.include_header and "HEADER" not in [ + uo.upper().strip() for uo in self.unload_options + ]: + self.unload_options = list(self.unload_options) + [ + "HEADER", + ] + + def _build_unload_query( + self, + credentials_block: str, + select_query: str, + s3_key: str, + unload_options: str, + ) -> str: + return f""" + UNLOAD ('{select_query}') + TO 's3://{self.s3_bucket}/{s3_key}' + with credentials + '{credentials_block}' + {unload_options}; + """ + + def execute(self, context) -> None: + postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id) + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + + credentials = s3_hook.get_credentials() + credentials_block = build_credentials_block(credentials) + unload_options = "\n\t\t\t".join(self.unload_options) + + unload_query = self._build_unload_query( + credentials_block, self._select_query, self.s3_key, unload_options + ) + + self.log.info("Executing UNLOAD command...") + postgres_hook.run(unload_query, self.autocommit) + self.log.info("UNLOAD command complete...") diff --git a/reference/providers/amazon/aws/transfers/s3_to_ftp.py b/reference/providers/amazon/aws/transfers/s3_to_ftp.py new file mode 100644 index 0000000..3ecf1e5 --- /dev/null +++ b/reference/providers/amazon/aws/transfers/s3_to_ftp.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
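A minimal usage sketch for the RedshiftToS3Operator defined above (not part of the commit itself), with placeholder schema, table, bucket and key; include_header=True appends HEADER to the UNLOAD options automatically, and with table_as_file_name left at True the key prefix becomes "redshift/events/events_":

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

    with DAG(
        dag_id="example_redshift_to_s3",
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        unload_events = RedshiftToS3Operator(
            task_id="unload_events",
            schema="public",
            table="events",
            s3_bucket="my-bucket",
            s3_key="redshift/events",
            unload_options=["CSV", "ALLOWOVERWRITE", "PARALLEL OFF"],
            include_header=True,
        )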
See the License for the +# specific language governing permissions and limitations +# under the License. + +from tempfile import NamedTemporaryFile + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.ftp.hooks.ftp import FTPHook +from airflow.utils.decorators import apply_defaults + + +class S3ToFTPOperator(BaseOperator): + """ + This operator enables the transferring of files from S3 to a FTP server. + + :param ftp_conn_id: The ftp connection id. The name or identifier for + establishing a connection to the FTP server. + :type ftp_conn_id: str + :param ftp_path: The ftp remote path. This is the specified file path for + uploading file to the FTP server. + :type ftp_path: str + :param s3_bucket: The targeted s3 bucket. This is the S3 bucket from + where the file is downloaded. + :type s3_bucket: str + :param s3_key: The targeted s3 key. This is the specified file path for + downloading the file from S3. + :type s3_key: str + """ + + template_fields = ("s3_bucket", "s3_key") + + @apply_defaults + def __init__( + self, + *, + s3_bucket, + s3_key, + ftp_path, + aws_conn_id="aws_default", + ftp_conn_id="ftp_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.ftp_path = ftp_path + self.aws_conn_id = aws_conn_id + self.ftp_conn_id = ftp_conn_id + + def execute(self, context): + s3_hook = S3Hook(self.aws_conn_id) + ftp_hook = FTPHook(ftp_conn_id=self.ftp_conn_id) + + s3_obj = s3_hook.get_key(self.s3_key, self.s3_bucket) + + with NamedTemporaryFile() as local_tmp_file: + s3_obj.download_fileobj(local_tmp_file) + ftp_hook.store_file(self.ftp_path, local_tmp_file.name) diff --git a/reference/providers/amazon/aws/transfers/s3_to_redshift.py b/reference/providers/amazon/aws/transfers/s3_to_redshift.py new file mode 100644 index 0000000..6f256ed --- /dev/null +++ b/reference/providers/amazon/aws/transfers/s3_to_redshift.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.utils.redshift import build_credentials_block +from airflow.providers.postgres.hooks.postgres import PostgresHook +from airflow.utils.decorators import apply_defaults + + +class S3ToRedshiftOperator(BaseOperator): + """ + Executes an COPY command to load files from s3 to Redshift + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:S3ToRedshiftOperator` + + :param schema: reference to a specific schema in redshift database + :type schema: str + :param table: reference to a specific table in redshift database + :type table: str + :param s3_bucket: reference to a specific S3 bucket + :type s3_bucket: str + :param s3_key: reference to a specific S3 key + :type s3_key: str + :param redshift_conn_id: reference to a specific redshift database + :type redshift_conn_id: str + :param aws_conn_id: reference to a specific S3 connection + If the AWS connection contains 'aws_iam_role' in ``extras`` + the operator will use AWS STS credentials with a token + https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str + :param copy_options: reference to a list of COPY options + :type copy_options: list + :param truncate_table: whether or not to truncate the destination table before the copy + :type truncate_table: bool + """ + + template_fields = ("s3_bucket", "s3_key", "schema", "table", "copy_options") + template_ext = () + ui_color = "#99e699" + + @apply_defaults + def __init__( + self, + *, + schema: str, + table: str, + s3_bucket: str, + s3_key: str, + redshift_conn_id: str = "redshift_default", + aws_conn_id: str = "aws_default", + verify: Optional[Union[bool, str]] = None, + copy_options: Optional[List] = None, + autocommit: bool = False, + truncate_table: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.schema = schema + self.table = table + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.redshift_conn_id = redshift_conn_id + self.aws_conn_id = aws_conn_id + self.verify = verify + self.copy_options = copy_options or [] + self.autocommit = autocommit + self.truncate_table = truncate_table + + def _build_copy_query(self, credentials_block: str, copy_options: str) -> str: + return f""" + COPY {self.schema}.{self.table} + FROM 's3://{self.s3_bucket}/{self.s3_key}' + with credentials + '{credentials_block}' + {copy_options}; + """ + + def execute(self, context) -> None: + postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id) + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + credentials = s3_hook.get_credentials() + credentials_block = build_credentials_block(credentials) + copy_options = "\n\t\t\t".join(self.copy_options) + + copy_statement = self._build_copy_query(credentials_block, copy_options) + + if self.truncate_table: + truncate_statement = f"TRUNCATE TABLE {self.schema}.{self.table};" + sql = f""" + BEGIN; + {truncate_statement} + {copy_statement} + COMMIT + """ + else: + sql = copy_statement + + self.log.info("Executing COPY command...") + postgres_hook.run(sql, self.autocommit) + self.log.info("COPY command complete...") diff --git a/reference/providers/amazon/aws/transfers/s3_to_sftp.py b/reference/providers/amazon/aws/transfers/s3_to_sftp.py new file mode 100644 
index 0000000..1a1a09b --- /dev/null +++ b/reference/providers/amazon/aws/transfers/s3_to_sftp.py @@ -0,0 +1,85 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tempfile import NamedTemporaryFile +from urllib.parse import urlparse + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.utils.decorators import apply_defaults + + +class S3ToSFTPOperator(BaseOperator): + """ + This operator enables the transferring of files from S3 to a SFTP server. + + :param sftp_conn_id: The sftp connection id. The name or identifier for + establishing a connection to the SFTP server. + :type sftp_conn_id: str + :param sftp_path: The sftp remote path. This is the specified file path for + uploading file to the SFTP server. + :type sftp_path: str + :param s3_conn_id: The s3 connection id. The name or identifier for + establishing a connection to S3 + :type s3_conn_id: str + :param s3_bucket: The targeted s3 bucket. This is the S3 bucket from + where the file is downloaded. + :type s3_bucket: str + :param s3_key: The targeted s3 key. This is the specified file path for + downloading the file from S3. + :type s3_key: str + """ + + template_fields = ("s3_key", "sftp_path") + + @apply_defaults + def __init__( + self, + *, + s3_bucket: str, + s3_key: str, + sftp_path: str, + sftp_conn_id: str = "ssh_default", + s3_conn_id: str = "aws_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sftp_conn_id = sftp_conn_id + self.sftp_path = sftp_path + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.s3_conn_id = s3_conn_id + + @staticmethod + def get_s3_key(s3_key: str) -> str: + """This parses the correct format for S3 keys regardless of how the S3 url is passed.""" + parsed_s3_key = urlparse(s3_key) + return parsed_s3_key.path.lstrip("/") + + def execute(self, context) -> None: + self.s3_key = self.get_s3_key(self.s3_key) + ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id) + s3_hook = S3Hook(self.s3_conn_id) + + s3_client = s3_hook.get_conn() + sftp_client = ssh_hook.get_conn().open_sftp() + + with NamedTemporaryFile("w") as f: + s3_client.download_file(self.s3_bucket, self.s3_key, f.name) + sftp_client.put(f.name, self.sftp_path) diff --git a/reference/providers/amazon/aws/transfers/sftp_to_s3.py b/reference/providers/amazon/aws/transfers/sftp_to_s3.py new file mode 100644 index 0000000..9b882eb --- /dev/null +++ b/reference/providers/amazon/aws/transfers/sftp_to_s3.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
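A minimal usage sketch for the S3ToRedshiftOperator shown a little earlier (not part of the commit itself), with placeholder schema, table, bucket and key; truncate_table=True makes execute() wrap the TRUNCATE and COPY statements in a single BEGIN/COMMIT block:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    with DAG(
        dag_id="example_s3_to_redshift",
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        load_events = S3ToRedshiftOperator(
            task_id="load_events",
            schema="public",
            table="events",
            s3_bucket="my-bucket",
            s3_key="redshift/events.csv",
            copy_options=["CSV", "IGNOREHEADER 1"],
            truncate_table=True,
        )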
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tempfile import NamedTemporaryFile +from urllib.parse import urlparse + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.utils.decorators import apply_defaults + + +class SFTPToS3Operator(BaseOperator): + """ + This operator enables the transferring of files from a SFTP server to + Amazon S3. + + :param sftp_conn_id: The sftp connection id. The name or identifier for + establishing a connection to the SFTP server. + :type sftp_conn_id: str + :param sftp_path: The sftp remote path. This is the specified file path + for downloading the file from the SFTP server. + :type sftp_path: str + :param s3_conn_id: The s3 connection id. The name or identifier for + establishing a connection to S3 + :type s3_conn_id: str + :param s3_bucket: The targeted s3 bucket. This is the S3 bucket to where + the file is uploaded. + :type s3_bucket: str + :param s3_key: The targeted s3 key. This is the specified path for + uploading the file to S3. + :type s3_key: str + """ + + template_fields = ("s3_key", "sftp_path") + + @apply_defaults + def __init__( + self, + *, + s3_bucket: str, + s3_key: str, + sftp_path: str, + sftp_conn_id: str = "ssh_default", + s3_conn_id: str = "aws_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sftp_conn_id = sftp_conn_id + self.sftp_path = sftp_path + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.s3_conn_id = s3_conn_id + + @staticmethod + def get_s3_key(s3_key: str) -> str: + """This parses the correct format for S3 keys regardless of how the S3 url is passed.""" + parsed_s3_key = urlparse(s3_key) + return parsed_s3_key.path.lstrip("/") + + def execute(self, context) -> None: + self.s3_key = self.get_s3_key(self.s3_key) + ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id) + s3_hook = S3Hook(self.s3_conn_id) + + sftp_client = ssh_hook.get_conn().open_sftp() + + with NamedTemporaryFile("w") as f: + sftp_client.get(self.sftp_path, f.name) + + s3_hook.load_file( + filename=f.name, + key=self.s3_key, + bucket_name=self.s3_bucket, + replace=True, + ) diff --git a/reference/providers/amazon/aws/utils/__init__.py b/reference/providers/amazon/aws/utils/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/amazon/aws/utils/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
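A minimal usage sketch for the SFTPToS3Operator defined above (not part of the commit itself); sftp_conn_id names an SSH-type connection, and the remote path, bucket and key are placeholders (both sftp_path and s3_key are templated):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.transfers.sftp_to_s3 import SFTPToS3Operator

    with DAG(
        dag_id="example_sftp_to_s3",
        start_date=datetime(2021, 1, 1),
        schedule_interval="@daily",
        catchup=False,
    ) as dag:
        pull_report = SFTPToS3Operator(
            task_id="pull_report",
            sftp_conn_id="ssh_default",
            sftp_path="/outgoing/report_{{ ds }}.csv",
            s3_conn_id="aws_default",
            s3_bucket="my-bucket",
            s3_key="partner/report_{{ ds }}.csv",
        )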
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/amazon/aws/utils/emailer.py b/reference/providers/amazon/aws/utils/emailer.py new file mode 100644 index 0000000..6895703 --- /dev/null +++ b/reference/providers/amazon/aws/utils/emailer.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Airflow module for email backend using AWS SES""" + +from typing import List, Optional, Union + +from airflow.providers.amazon.aws.hooks.ses import SESHook + + +def send_email( + to: Union[List[str], str], + subject: str, + html_content: str, + files: Optional[List] = None, + cc: Optional[Union[List[str], str]] = None, + bcc: Optional[Union[List[str], str]] = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + conn_id: str = "aws_default", + **kwargs, +) -> None: + """Email backend for SES.""" + hook = SESHook(aws_conn_id=conn_id) + hook.send_email( + mail_from=None, + to=to, + subject=subject, + html_content=html_content, + files=files, + cc=cc, + bcc=bcc, + mime_subtype=mime_subtype, + mime_charset=mime_charset, + ) diff --git a/reference/providers/amazon/aws/utils/redshift.py b/reference/providers/amazon/aws/utils/redshift.py new file mode 100644 index 0000000..0c6799b --- /dev/null +++ b/reference/providers/amazon/aws/utils/redshift.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
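The send_email function above is intended to be plugged in as Airflow's email backend. A configuration sketch (not part of the commit itself), assuming the Amazon provider is installed and the default "aws_default" connection (or the environment's AWS credentials) is allowed to send through SES:

    [email]
    email_backend = airflow.providers.amazon.aws.utils.emailer.send_email

    # or, equivalently, via environment variable:
    # AIRFLOW__EMAIL__EMAIL_BACKEND=airflow.providers.amazon.aws.utils.emailer.send_email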
+ +import logging + +from botocore.credentials import ReadOnlyCredentials + +log = logging.getLogger(__name__) + + +def build_credentials_block(credentials: ReadOnlyCredentials) -> str: + """ + Generate AWS credentials block for Redshift COPY and UNLOAD + commands, as noted in AWS docs + https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials + + :param credentials: ReadOnlyCredentials object from `botocore` + :return: str + """ + if credentials.token: + log.debug("STS token found in credentials, including it in the command") + # these credentials are obtained from AWS STS + # so the token must be included in the CREDENTIALS clause + credentials_line = ( + f"aws_access_key_id={credentials.access_key};" + f"aws_secret_access_key={credentials.secret_key};" + f"token={credentials.token}" + ) + + else: + credentials_line = ( + f"aws_access_key_id={credentials.access_key};" + f"aws_secret_access_key={credentials.secret_key}" + ) + + return credentials_line diff --git a/reference/providers/amazon/provider.yaml b/reference/providers/amazon/provider.yaml new file mode 100644 index 0000000..9cc1ed1 --- /dev/null +++ b/reference/providers/amazon/provider.yaml @@ -0,0 +1,369 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-amazon +name: Amazon +description: | + Amazon integration (including `Amazon Web Services (AWS) `__). 
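build_credentials_block above simply formats botocore's ReadOnlyCredentials into the CREDENTIALS clause used by the Redshift COPY/UNLOAD transfer operators. A quick illustration (not part of the commit itself), with placeholder values:

    from botocore.credentials import ReadOnlyCredentials

    from airflow.providers.amazon.aws.utils.redshift import build_credentials_block

    # ReadOnlyCredentials is a namedtuple of (access_key, secret_key, token);
    # when a session token is present the block gains a ";token=..." segment.
    creds = ReadOnlyCredentials("AKIAEXAMPLEKEY", "examplesecret", None)
    print(build_credentials_block(creds))
    # aws_access_key_id=AKIAEXAMPLEKEY;aws_secret_access_key=examplesecret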
+ +versions: + - 1.2.0 + - 1.1.0 + - 1.0.0 + +integrations: + - integration-name: Amazon Athena + external-doc-url: https://aws.amazon.com/athena/ + logo: /integration-logos/aws/Amazon-Athena_light-bg@4x.png + tags: [aws] + - integration-name: Amazon CloudFormation + external-doc-url: https://aws.amazon.com/cloudformation/ + logo: /integration-logos/aws/AWS-CloudFormation_light-bg@4x.png + tags: [aws] + - integration-name: Amazon CloudWatch Logs + external-doc-url: https://aws.amazon.com/cloudwatch/ + logo: /integration-logos/aws/Amazon-CloudWatch_light-bg@4x.png + tags: [aws] + - integration-name: Amazon DataSync + external-doc-url: https://aws.amazon.com/datasync/ + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/datasync.rst + tags: [aws] + - integration-name: Amazon DynamoDB + external-doc-url: https://aws.amazon.com/dynamodb/ + logo: /integration-logos/aws/Amazon-DynamoDB_light-bg@4x.png + tags: [aws] + - integration-name: Amazon EC2 + external-doc-url: https://aws.amazon.com/ec2/ + logo: /integration-logos/aws/Amazon-EC2_light-bg@4x.png + tags: [aws] + - integration-name: Amazon ECS + external-doc-url: https://aws.amazon.com/ecs/ + logo: /integration-logos/aws/Amazon-Elastic-Container-Service_light-bg@4x.png + tags: [aws] + - integration-name: Amazon ElastiCache + external-doc-url: https://aws.amazon.com/elasticache/redis// + logo: /integration-logos/aws/Amazon-ElastiCache_light-bg@4x.png + tags: [aws] + - integration-name: Amazon EMR + external-doc-url: https://aws.amazon.com/emr/ + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/emr.rst + logo: /integration-logos/aws/Amazon-EMR_light-bg@4x.png + tags: [aws] + - integration-name: Amazon Glacier + external-doc-url: https://aws.amazon.com/glacier/ + logo: /integration-logos/aws/Amazon-S3-Glacier_light-bg@4x.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/glacier.rst + tags: [aws] + - integration-name: Amazon Kinesis Data Firehose + external-doc-url: https://aws.amazon.com/kinesis/data-firehose/ + logo: /integration-logos/aws/Amazon-Kinesis-Data-Firehose_light-bg@4x.png + tags: [aws] + - integration-name: Amazon Redshift + external-doc-url: https://aws.amazon.com/redshift/ + logo: /integration-logos/aws/Amazon-Redshift_light-bg@4x.png + tags: [aws] + - integration-name: Amazon SageMaker + external-doc-url: https://aws.amazon.com/sagemaker/ + logo: /integration-logos/aws/Amazon-SageMaker_light-bg@4x.png + tags: [aws] + - integration-name: Amazon SecretsManager + external-doc-url: https://aws.amazon.com/secrets-manager/ + logo: /integration-logos/aws/AWS-Secrets-Manager_light-bg@4x.png + tags: [aws] + - integration-name: Amazon Simple Email Service (SES) + external-doc-url: https://aws.amazon.com/ses/ + logo: /integration-logos/aws/Amazon-Simple-Email-Service-SES_light-bg@4x.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/ecs.rst + tags: [aws] + - integration-name: Amazon Simple Notification Service (SNS) + external-doc-url: https://aws.amazon.com/sns/ + logo: /integration-logos/aws/Amazon-Simple-Notification-Service-SNS_light-bg@4x.png + tags: [aws] + - integration-name: Amazon Simple Queue Service (SQS) + external-doc-url: https://aws.amazon.com/sqs/ + logo: /integration-logos/aws/Amazon-Simple-Queue-Service-SQS_light-bg@4x.png + tags: [aws] + - integration-name: Amazon Simple Storage Service (S3) + external-doc-url: https://aws.amazon.com/s3/ + logo: /integration-logos/aws/Amazon-Simple-Storage-Service-S3_light-bg@4x.png + how-to-guide: + - 
/docs/apache-airflow-providers-amazon/operators/s3.rst + tags: [aws] + - integration-name: Amazon Web Services + external-doc-url: https://aws.amazon.com/ + logo: /integration-logos/aws/AWS-Cloud-alt_light-bg@4x.png + tags: [aws] + - integration-name: AWS Batch + external-doc-url: https://aws.amazon.com/batch/ + logo: /integration-logos/aws/AWS-Batch_light-bg@4x.png + tags: [aws] + - integration-name: AWS DataSync + external-doc-url: https://aws.amazon.com/datasync/ + logo: /integration-logos/aws/AWS-DataSync_light-bg@4x.png + tags: [aws] + - integration-name: AWS Glue + external-doc-url: https://aws.amazon.com/glue/ + logo: /integration-logos/aws/AWS-Glue_light-bg@4x.png + tags: [aws] + - integration-name: AWS Lambda + external-doc-url: https://aws.amazon.com/lambda/ + logo: /integration-logos/aws/AWS-Lambda_light-bg@4x.png + tags: [aws] + - integration-name: AWS Step Functions + external-doc-url: https://aws.amazon.com/step-functions/ + logo: /integration-logos/aws/AWS-Step-Functions_light-bg@4x.png + tags: [aws] + +operators: + - integration-name: Amazon Athena + python-modules: + - airflow.providers.amazon.aws.operators.athena + - integration-name: AWS Batch + python-modules: + - airflow.providers.amazon.aws.operators.batch + - integration-name: Amazon CloudFormation + python-modules: + - airflow.providers.amazon.aws.operators.cloud_formation + - integration-name: Amazon DataSync + python-modules: + - airflow.providers.amazon.aws.operators.datasync + - integration-name: Amazon EC2 + python-modules: + - airflow.providers.amazon.aws.operators.ec2_start_instance + - airflow.providers.amazon.aws.operators.ec2_stop_instance + - integration-name: Amazon ECS + python-modules: + - airflow.providers.amazon.aws.operators.ecs + - integration-name: Amazon EMR + python-modules: + - airflow.providers.amazon.aws.operators.emr_add_steps + - airflow.providers.amazon.aws.operators.emr_create_job_flow + - airflow.providers.amazon.aws.operators.emr_modify_cluster + - airflow.providers.amazon.aws.operators.emr_terminate_job_flow + - integration-name: Amazon Glacier + python-modules: + - airflow.providers.amazon.aws.operators.glacier + - integration-name: AWS Glue + python-modules: + - airflow.providers.amazon.aws.operators.glue + - airflow.providers.amazon.aws.operators.glue_crawler + - integration-name: Amazon Simple Storage Service (S3) + python-modules: + - airflow.providers.amazon.aws.operators.s3_bucket + - airflow.providers.amazon.aws.operators.s3_bucket_tagging + - airflow.providers.amazon.aws.operators.s3_copy_object + - airflow.providers.amazon.aws.operators.s3_delete_objects + - airflow.providers.amazon.aws.operators.s3_file_transform + - airflow.providers.amazon.aws.operators.s3_list + - integration-name: Amazon SageMaker + python-modules: + - airflow.providers.amazon.aws.operators.sagemaker_base + - airflow.providers.amazon.aws.operators.sagemaker_endpoint + - airflow.providers.amazon.aws.operators.sagemaker_endpoint_config + - airflow.providers.amazon.aws.operators.sagemaker_model + - airflow.providers.amazon.aws.operators.sagemaker_processing + - airflow.providers.amazon.aws.operators.sagemaker_training + - airflow.providers.amazon.aws.operators.sagemaker_transform + - airflow.providers.amazon.aws.operators.sagemaker_tuning + - integration-name: Amazon Simple Notification Service (SNS) + python-modules: + - airflow.providers.amazon.aws.operators.sns + - integration-name: Amazon Simple Queue Service (SQS) + python-modules: + - airflow.providers.amazon.aws.operators.sqs + - integration-name: 
AWS Step Functions + python-modules: + - airflow.providers.amazon.aws.operators.step_function_get_execution_output + - airflow.providers.amazon.aws.operators.step_function_start_execution + +sensors: + - integration-name: Amazon Athena + python-modules: + - airflow.providers.amazon.aws.sensors.athena + - integration-name: Amazon CloudFormation + python-modules: + - airflow.providers.amazon.aws.sensors.cloud_formation + - integration-name: Amazon EC2 + python-modules: + - airflow.providers.amazon.aws.sensors.ec2_instance_state + - integration-name: Amazon EMR + python-modules: + - airflow.providers.amazon.aws.sensors.emr_base + - airflow.providers.amazon.aws.sensors.emr_job_flow + - airflow.providers.amazon.aws.sensors.emr_step + - integration-name: Amazon Glacier + python-modules: + - airflow.providers.amazon.aws.sensors.glacier + - integration-name: AWS Glue + python-modules: + - airflow.providers.amazon.aws.sensors.glue + - airflow.providers.amazon.aws.sensors.glue_crawler + - airflow.providers.amazon.aws.sensors.glue_catalog_partition + - integration-name: Amazon Redshift + python-modules: + - airflow.providers.amazon.aws.sensors.redshift + - integration-name: Amazon Simple Storage Service (S3) + python-modules: + - airflow.providers.amazon.aws.sensors.s3_key + - airflow.providers.amazon.aws.sensors.s3_keys_unchanged + - airflow.providers.amazon.aws.sensors.s3_prefix + - integration-name: Amazon SageMaker + python-modules: + - airflow.providers.amazon.aws.sensors.sagemaker_base + - airflow.providers.amazon.aws.sensors.sagemaker_endpoint + - airflow.providers.amazon.aws.sensors.sagemaker_training + - airflow.providers.amazon.aws.sensors.sagemaker_transform + - airflow.providers.amazon.aws.sensors.sagemaker_tuning + - integration-name: Amazon Simple Queue Service (SQS) + python-modules: + - airflow.providers.amazon.aws.sensors.sqs + - integration-name: AWS Step Functions + python-modules: + - airflow.providers.amazon.aws.sensors.step_function_execution + +hooks: + - integration-name: Amazon Athena + python-modules: + - airflow.providers.amazon.aws.hooks.athena + - integration-name: Amazon DynamoDB + python-modules: + - airflow.providers.amazon.aws.hooks.dynamodb + - airflow.providers.amazon.aws.hooks.aws_dynamodb + - integration-name: Amazon Web Services + python-modules: + - airflow.providers.amazon.aws.hooks.base_aws + - integration-name: AWS Batch + python-modules: + - airflow.providers.amazon.aws.hooks.batch_client + - airflow.providers.amazon.aws.hooks.batch_waiters + - integration-name: Amazon CloudFormation + python-modules: + - airflow.providers.amazon.aws.hooks.cloud_formation + - integration-name: Amazon DataSync + python-modules: + - airflow.providers.amazon.aws.hooks.datasync + - integration-name: Amazon EC2 + python-modules: + - airflow.providers.amazon.aws.hooks.ec2 + - integration-name: Amazon ElastiCache + python-modules: + - airflow.providers.amazon.aws.hooks.elasticache_replication_group + - integration-name: Amazon EMR + python-modules: + - airflow.providers.amazon.aws.hooks.emr + - integration-name: Amazon Glacier + python-modules: + - airflow.providers.amazon.aws.hooks.glacier + - integration-name: AWS Glue + python-modules: + - airflow.providers.amazon.aws.hooks.glue + - airflow.providers.amazon.aws.hooks.glue_crawler + - airflow.providers.amazon.aws.hooks.glue_catalog + - integration-name: Amazon Kinesis Data Firehose + python-modules: + - airflow.providers.amazon.aws.hooks.kinesis + - integration-name: AWS Lambda + python-modules: + - 
airflow.providers.amazon.aws.hooks.lambda_function + - integration-name: Amazon CloudWatch Logs + python-modules: + - airflow.providers.amazon.aws.hooks.logs + - integration-name: Amazon Redshift + python-modules: + - airflow.providers.amazon.aws.hooks.redshift + - integration-name: Amazon Simple Storage Service (S3) + python-modules: + - airflow.providers.amazon.aws.hooks.s3 + - integration-name: Amazon SageMaker + python-modules: + - airflow.providers.amazon.aws.hooks.sagemaker + - integration-name: Amazon Simple Email Service (SES) + python-modules: + - airflow.providers.amazon.aws.hooks.ses + - integration-name: Amazon SecretsManager + python-modules: + - airflow.providers.amazon.aws.hooks.secrets_manager + - integration-name: Amazon Simple Notification Service (SNS) + python-modules: + - airflow.providers.amazon.aws.hooks.sns + - integration-name: Amazon Simple Queue Service (SQS) + python-modules: + - airflow.providers.amazon.aws.hooks.sqs + - integration-name: AWS Step Functions + python-modules: + - airflow.providers.amazon.aws.hooks.step_function + +transfers: + - source-integration-name: Amazon DynamoDB + target-integration-name: Amazon Simple Storage Service (S3) + python-module: airflow.providers.amazon.aws.transfers.dynamodb_to_s3 + - source-integration-name: Google Cloud Storage (GCS) + target-integration-name: Amazon Simple Storage Service (S3) + python-module: airflow.providers.amazon.aws.transfers.gcs_to_s3 + - source-integration-name: Amazon Glacier + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/glacier_to_gcs.rst + python-module: airflow.providers.amazon.aws.transfers.glacier_to_gcs + - source-integration-name: Google + target-integration-name: Amazon Simple Storage Service (S3) + how-to-guide: /docs/apache-airflow-providers-amazon/operators/google_api_to_s3_transfer.rst + python-module: airflow.providers.amazon.aws.transfers.google_api_to_s3 + - source-integration-name: Apache Hive + target-integration-name: Amazon DynamoDB + python-module: airflow.providers.amazon.aws.transfers.hive_to_dynamodb + - source-integration-name: Internet Message Access Protocol (IMAP) + target-integration-name: Amazon Simple Storage Service (S3) + how-to-guide: /docs/apache-airflow-providers-amazon/operators/imap_attachment_to_s3.rst + python-module: airflow.providers.amazon.aws.transfers.imap_attachment_to_s3 + - source-integration-name: MongoDB + target-integration-name: Amazon Simple Storage Service (S3) + python-module: airflow.providers.amazon.aws.transfers.mongo_to_s3 + - source-integration-name: MySQL + target-integration-name: Amazon Simple Storage Service (S3) + python-module: airflow.providers.amazon.aws.transfers.mysql_to_s3 + - source-integration-name: Amazon Redshift + target-integration-name: Amazon Simple Storage Service (S3) + python-module: airflow.providers.amazon.aws.transfers.redshift_to_s3 + - source-integration-name: Amazon Simple Storage Service (S3) + target-integration-name: Amazon Redshift + how-to-guide: /docs/apache-airflow-providers-amazon/operators/s3_to_redshift.rst + python-module: airflow.providers.amazon.aws.transfers.s3_to_redshift + - source-integration-name: Amazon Simple Storage Service (S3) + target-integration-name: SSH File Transfer Protocol (SFTP) + python-module: airflow.providers.amazon.aws.transfers.s3_to_sftp + - source-integration-name: SSH File Transfer Protocol (SFTP) + target-integration-name: Amazon Simple Storage Service (S3) + python-module: 
airflow.providers.amazon.aws.transfers.sftp_to_s3 + - source-integration-name: Amazon Simple Storage Service (S3) + target-integration-name: File Transfer Protocol (FTP) + python-module: airflow.providers.amazon.aws.transfers.s3_to_ftp + - source-integration-name: Exasol + target-integration-name: Amazon Simple Storage Service (S3) + python-module: airflow.providers.amazon.aws.transfers.exasol_to_s3 + - source-integration-name: File Transfer Protocol (FTP) + target-integration-name: Amazon Simple Storage Service (S3) + python-module: airflow.providers.amazon.aws.transfers.ftp_to_s3 + +hook-class-names: + - airflow.providers.amazon.aws.hooks.s3.S3Hook + - airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook + - airflow.providers.amazon.aws.hooks.emr.EmrHook diff --git a/reference/providers/apache/__init__.py b/reference/providers/apache/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/beam/CHANGELOG.rst b/reference/providers/apache/beam/CHANGELOG.rst new file mode 100644 index 0000000..c3129b4 --- /dev/null +++ b/reference/providers/apache/beam/CHANGELOG.rst @@ -0,0 +1,35 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Bug fixes +~~~~~~~~~ + +* ``Improve Apache Beam operators - refactor operator - common Dataflow logic (#14094)`` +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` +* ``Remove WARNINGs from BeamHook (#14554)`` + +1.0.0 +..... + +Initial version of the provider. 
diff --git a/reference/providers/apache/beam/README.md b/reference/providers/apache/beam/README.md
new file mode 100644
index 0000000..b5a6183
--- /dev/null
+++ b/reference/providers/apache/beam/README.md
@@ -0,0 +1,92 @@
+
+
+# Package apache-airflow-providers-apache-beam
+
+Release: 1.0.0
+
+**Table of contents**
+
+- [Provider package](#provider-package)
+- [Installation](#installation)
+- [PIP requirements](#pip-requirements)
+- [Cross provider package dependencies](#cross-provider-package-dependencies)
+- [Provider classes summary](#provider-classes-summary)
+  - [Operators](#operators)
+  - [Transfer operators](#transfer-operators)
+  - [Hooks](#hooks)
+- [Releases](#releases)
+
+## Provider package
+
+This is a provider package for the `apache.beam` provider. All classes for this provider package
+are in the `airflow.providers.apache.beam` Python package.
+
+## Installation
+
+NOTE!
+
+In November 2020, a new version of pip (20.3) was released with a new, 2020 resolver. This resolver
+does not yet work with Apache Airflow and might lead to errors in installation, depending on your choice
+of extras. To install Airflow you need to either downgrade pip to version 20.2.4
+(`pip install --upgrade pip==20.2.4`) or, if you use pip 20.3, add the option
+`--use-deprecated legacy-resolver` to your pip install command.
+
+You can install this package on top of an existing Airflow 2.\* installation via
+`pip install apache-airflow-providers-apache-beam`
+
+## Cross provider package dependencies
+
+These are dependencies that might be needed in order to use all the features of the package.
+You need to install the specified provider packages in order to use them.
+
+You can install such cross-provider dependencies when installing from PyPI. For example:
+
+```bash
+pip install apache-airflow-providers-apache-beam[google]
+```
+
+| Dependent package                                                                            | Extra  |
+| :------------------------------------------------------------------------------------------ | :----- |
+| [apache-airflow-providers-google](https://pypi.org/project/apache-airflow-providers-google) | google |
+
+# Provider classes summary
+
+In Airflow 2.0, all operators, transfers, hooks, sensors, secrets for the `apache.beam` provider
+are in the `airflow.providers.apache.beam` package.
You can read more about the naming conventions used +in [Naming conventions for provider packages](github.com/apache/airflow/blob/master/CONTRIBUTING.rst#naming-conventions-for-provider-packages) + +## Operators + +### New operators + +| New Airflow 2.0 operators: `airflow.providers.apache.beam` package | +| :------------------------------------------------------------------------------------------------------------------------------------ | +| [operators.beam.BeamRunJavaPipelineOperator](github.com/apache/airflow/blob/master/airflow/providers/apache/beam/operators/beam.py) | +| [operators.beam.BeamRunPythonPipelineOperator](github.com/apache/airflow/blob/master/airflow/providers/apache/beam/operators/beam.py) | + +## Hooks + +### New hooks + +| New Airflow 2.0 hooks: `airflow.providers.apache.beam` package | +| :------------------------------------------------------------------------------------------------------- | +| [hooks.beam.BeamHook](github.com/apache/airflow/blob/master/airflow/providers/apache/beam/hooks/beam.py) | + +## Releases diff --git a/reference/providers/apache/beam/__init__.py b/reference/providers/apache/beam/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/beam/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/beam/example_dags/__init__.py b/reference/providers/apache/beam/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/beam/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/beam/example_dags/example_beam.py b/reference/providers/apache/beam/example_dags/example_beam.py new file mode 100644 index 0000000..1b2d5ed --- /dev/null +++ b/reference/providers/apache/beam/example_dags/example_beam.py @@ -0,0 +1,329 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Apache Beam operators +""" +import os +from urllib.parse import urlparse + +from airflow import models +from airflow.providers.apache.beam.operators.beam import ( + BeamRunJavaPipelineOperator, + BeamRunPythonPipelineOperator, +) +from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus +from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration +from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor +from airflow.providers.google.cloud.transfers.gcs_to_local import ( + GCSToLocalFilesystemOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCS_INPUT = os.environ.get( + "APACHE_BEAM_PYTHON", "gs://apache-beam-samples/shakespeare/kinglear.txt" +) +GCS_TMP = os.environ.get("APACHE_BEAM_GCS_TMP", "gs://test-dataflow-example/temp/") +GCS_STAGING = os.environ.get( + "APACHE_BEAM_GCS_STAGING", "gs://test-dataflow-example/staging/" +) +GCS_OUTPUT = os.environ.get( + "APACHE_BEAM_GCS_OUTPUT", "gs://test-dataflow-example/output" +) +GCS_PYTHON = os.environ.get( + "APACHE_BEAM_PYTHON", "gs://test-dataflow-example/wordcount_debugging.py" +) +GCS_PYTHON_DATAFLOW_ASYNC = os.environ.get( + "APACHE_BEAM_PYTHON_DATAFLOW_ASYNC", + "gs://test-dataflow-example/wordcount_debugging.py", +) + +GCS_JAR_DIRECT_RUNNER = os.environ.get( + "APACHE_BEAM_DIRECT_RUNNER_JAR", + "gs://test-dataflow-example/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-DirectRunner.jar", +) +GCS_JAR_DATAFLOW_RUNNER = os.environ.get( + "APACHE_BEAM_DATAFLOW_RUNNER_JAR", + "gs://test-dataflow-example/word-count-beam-bundled-0.1.jar", +) +GCS_JAR_SPARK_RUNNER = os.environ.get( + "APACHE_BEAM_SPARK_RUNNER_JAR", + "gs://test-dataflow-example/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-SparkRunner.jar", +) +GCS_JAR_FLINK_RUNNER = os.environ.get( + "APACHE_BEAM_FLINK_RUNNER_JAR", + "gs://test-dataflow-example/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-FlinkRunner.jar", +) + +GCS_JAR_DIRECT_RUNNER_PARTS = urlparse(GCS_JAR_DIRECT_RUNNER) +GCS_JAR_DIRECT_RUNNER_BUCKET_NAME = GCS_JAR_DIRECT_RUNNER_PARTS.netloc +GCS_JAR_DIRECT_RUNNER_OBJECT_NAME = GCS_JAR_DIRECT_RUNNER_PARTS.path[1:] +GCS_JAR_DATAFLOW_RUNNER_PARTS = urlparse(GCS_JAR_DATAFLOW_RUNNER) +GCS_JAR_DATAFLOW_RUNNER_BUCKET_NAME = GCS_JAR_DATAFLOW_RUNNER_PARTS.netloc +GCS_JAR_DATAFLOW_RUNNER_OBJECT_NAME = GCS_JAR_DATAFLOW_RUNNER_PARTS.path[1:] +GCS_JAR_SPARK_RUNNER_PARTS = urlparse(GCS_JAR_SPARK_RUNNER) +GCS_JAR_SPARK_RUNNER_BUCKET_NAME = GCS_JAR_SPARK_RUNNER_PARTS.netloc +GCS_JAR_SPARK_RUNNER_OBJECT_NAME = GCS_JAR_SPARK_RUNNER_PARTS.path[1:] +GCS_JAR_FLINK_RUNNER_PARTS = urlparse(GCS_JAR_FLINK_RUNNER) +GCS_JAR_FLINK_RUNNER_BUCKET_NAME = GCS_JAR_FLINK_RUNNER_PARTS.netloc +GCS_JAR_FLINK_RUNNER_OBJECT_NAME = 
GCS_JAR_FLINK_RUNNER_PARTS.path[1:] + + +default_args = { + "default_pipeline_options": { + "output": "/tmp/example_beam", + }, + "trigger_rule": "all_done", +} + + +with models.DAG( + "example_beam_native_java_direct_runner", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag_native_java_direct_runner: + + # [START howto_operator_start_java_direct_runner_pipeline] + jar_to_local_direct_runner = GCSToLocalFilesystemOperator( + task_id="jar_to_local_direct_runner", + bucket=GCS_JAR_DIRECT_RUNNER_BUCKET_NAME, + object_name=GCS_JAR_DIRECT_RUNNER_OBJECT_NAME, + filename="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar", + ) + + start_java_pipeline_direct_runner = BeamRunJavaPipelineOperator( + task_id="start_java_pipeline_direct_runner", + jar="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar", + pipeline_options={ + "output": "/tmp/start_java_pipeline_direct_runner", + "inputFile": GCS_INPUT, + }, + job_class="org.apache.beam.examples.WordCount", + ) + + jar_to_local_direct_runner >> start_java_pipeline_direct_runner + # [END howto_operator_start_java_direct_runner_pipeline] + +with models.DAG( + "example_beam_native_java_dataflow_runner", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag_native_java_dataflow_runner: + # [START howto_operator_start_java_dataflow_runner_pipeline] + jar_to_local_dataflow_runner = GCSToLocalFilesystemOperator( + task_id="jar_to_local_dataflow_runner", + bucket=GCS_JAR_DATAFLOW_RUNNER_BUCKET_NAME, + object_name=GCS_JAR_DATAFLOW_RUNNER_OBJECT_NAME, + filename="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar", + ) + + start_java_pipeline_dataflow = BeamRunJavaPipelineOperator( + task_id="start_java_pipeline_dataflow", + runner="DataflowRunner", + jar="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar", + pipeline_options={ + "tempLocation": GCS_TMP, + "stagingLocation": GCS_STAGING, + "output": GCS_OUTPUT, + }, + job_class="org.apache.beam.examples.WordCount", + dataflow_config={"job_name": "{{task.task_id}}", "location": "us-central1"}, + ) + + jar_to_local_dataflow_runner >> start_java_pipeline_dataflow + # [END howto_operator_start_java_dataflow_runner_pipeline] + +with models.DAG( + "example_beam_native_java_spark_runner", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag_native_java_spark_runner: + + jar_to_local_spark_runner = GCSToLocalFilesystemOperator( + task_id="jar_to_local_spark_runner", + bucket=GCS_JAR_SPARK_RUNNER_BUCKET_NAME, + object_name=GCS_JAR_SPARK_RUNNER_OBJECT_NAME, + filename="/tmp/beam_wordcount_spark_runner_{{ ds_nodash }}.jar", + ) + + start_java_pipeline_spark_runner = BeamRunJavaPipelineOperator( + task_id="start_java_pipeline_spark_runner", + runner="SparkRunner", + jar="/tmp/beam_wordcount_spark_runner_{{ ds_nodash }}.jar", + pipeline_options={ + "output": "/tmp/start_java_pipeline_spark_runner", + "inputFile": GCS_INPUT, + }, + job_class="org.apache.beam.examples.WordCount", + ) + + jar_to_local_spark_runner >> start_java_pipeline_spark_runner + +with models.DAG( + "example_beam_native_java_flink_runner", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag_native_java_flink_runner: + + jar_to_local_flink_runner = GCSToLocalFilesystemOperator( + task_id="jar_to_local_flink_runner", + bucket=GCS_JAR_FLINK_RUNNER_BUCKET_NAME, + 
object_name=GCS_JAR_FLINK_RUNNER_OBJECT_NAME, + filename="/tmp/beam_wordcount_flink_runner_{{ ds_nodash }}.jar", + ) + + start_java_pipeline_flink_runner = BeamRunJavaPipelineOperator( + task_id="start_java_pipeline_flink_runner", + runner="FlinkRunner", + jar="/tmp/beam_wordcount_flink_runner_{{ ds_nodash }}.jar", + pipeline_options={ + "output": "/tmp/start_java_pipeline_flink_runner", + "inputFile": GCS_INPUT, + }, + job_class="org.apache.beam.examples.WordCount", + ) + + jar_to_local_flink_runner >> start_java_pipeline_flink_runner + + +with models.DAG( + "example_beam_native_python", + default_args=default_args, + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag_native_python: + + # [START howto_operator_start_python_direct_runner_pipeline_local_file] + start_python_pipeline_local_direct_runner = BeamRunPythonPipelineOperator( + task_id="start_python_pipeline_local_direct_runner", + py_file="apache_beam.examples.wordcount", + py_options=["-m"], + py_requirements=["apache-beam[gcp]==2.26.0"], + py_interpreter="python3", + py_system_site_packages=False, + ) + # [END howto_operator_start_python_direct_runner_pipeline_local_file] + + # [START howto_operator_start_python_direct_runner_pipeline_gcs_file] + start_python_pipeline_direct_runner = BeamRunPythonPipelineOperator( + task_id="start_python_pipeline_direct_runner", + py_file=GCS_PYTHON, + py_options=[], + pipeline_options={"output": GCS_OUTPUT}, + py_requirements=["apache-beam[gcp]==2.26.0"], + py_interpreter="python3", + py_system_site_packages=False, + ) + # [END howto_operator_start_python_direct_runner_pipeline_gcs_file] + + # [START howto_operator_start_python_dataflow_runner_pipeline_gcs_file] + start_python_pipeline_dataflow_runner = BeamRunPythonPipelineOperator( + task_id="start_python_pipeline_dataflow_runner", + runner="DataflowRunner", + py_file=GCS_PYTHON, + pipeline_options={ + "tempLocation": GCS_TMP, + "stagingLocation": GCS_STAGING, + "output": GCS_OUTPUT, + }, + py_options=[], + py_requirements=["apache-beam[gcp]==2.26.0"], + py_interpreter="python3", + py_system_site_packages=False, + dataflow_config=DataflowConfiguration( + job_name="{{task.task_id}}", + project_id=GCP_PROJECT_ID, + location="us-central1", + ), + ) + # [END howto_operator_start_python_dataflow_runner_pipeline_gcs_file] + + start_python_pipeline_local_spark_runner = BeamRunPythonPipelineOperator( + task_id="start_python_pipeline_local_spark_runner", + py_file="apache_beam.examples.wordcount", + runner="SparkRunner", + py_options=["-m"], + py_requirements=["apache-beam[gcp]==2.26.0"], + py_interpreter="python3", + py_system_site_packages=False, + ) + + start_python_pipeline_local_flink_runner = BeamRunPythonPipelineOperator( + task_id="start_python_pipeline_local_flink_runner", + py_file="apache_beam.examples.wordcount", + runner="FlinkRunner", + py_options=["-m"], + pipeline_options={ + "output": "/tmp/start_python_pipeline_local_flink_runner", + }, + py_requirements=["apache-beam[gcp]==2.26.0"], + py_interpreter="python3", + py_system_site_packages=False, + ) + + [ + start_python_pipeline_local_direct_runner, + start_python_pipeline_direct_runner, + ] >> start_python_pipeline_local_flink_runner >> start_python_pipeline_local_spark_runner + + +with models.DAG( + "example_beam_native_python_dataflow_async", + default_args=default_args, + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag_native_python_dataflow_async: + # [START 
howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file] + start_python_job_dataflow_runner_async = BeamRunPythonPipelineOperator( + task_id="start_python_job_dataflow_runner_async", + runner="DataflowRunner", + py_file=GCS_PYTHON_DATAFLOW_ASYNC, + pipeline_options={ + "tempLocation": GCS_TMP, + "stagingLocation": GCS_STAGING, + "output": GCS_OUTPUT, + }, + py_options=[], + py_requirements=["apache-beam[gcp]==2.26.0"], + py_interpreter="python3", + py_system_site_packages=False, + dataflow_config=DataflowConfiguration( + job_name="{{task.task_id}}", + project_id=GCP_PROJECT_ID, + location="us-central1", + wait_until_finished=False, + ), + ) + + wait_for_python_job_dataflow_runner_async_done = DataflowJobStatusSensor( + task_id="wait-for-python-job-async-done", + job_id="{{task_instance.xcom_pull('start_python_job_dataflow_runner_async')['dataflow_job_id']}}", + expected_statuses={DataflowJobStatus.JOB_STATE_DONE}, + project_id=GCP_PROJECT_ID, + location="us-central1", + ) + + start_python_job_dataflow_runner_async >> wait_for_python_job_dataflow_runner_async_done + # [END howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file] diff --git a/reference/providers/apache/beam/hooks/__init__.py b/reference/providers/apache/beam/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/beam/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/beam/hooks/beam.py b/reference/providers/apache/beam/hooks/beam.py new file mode 100644 index 0000000..5f7783f --- /dev/null +++ b/reference/providers/apache/beam/hooks/beam.py @@ -0,0 +1,298 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains a Apache Beam Hook.""" +import json +import select +import shlex +import subprocess +import textwrap +from tempfile import TemporaryDirectory +from typing import Callable, List, Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.python_virtualenv import prepare_virtualenv + + +class BeamRunnerType: + """ + Helper class for listing runner types. + For more information about runners see: + https://beam.apache.org/documentation/ + """ + + DataflowRunner = "DataflowRunner" + DirectRunner = "DirectRunner" + SparkRunner = "SparkRunner" + FlinkRunner = "FlinkRunner" + SamzaRunner = "SamzaRunner" + NemoRunner = "NemoRunner" + JetRunner = "JetRunner" + Twister2Runner = "Twister2Runner" + + +def beam_options_to_args(options: dict) -> List[str]: + """ + Returns a formatted pipeline options from a dictionary of arguments + + The logic of this method should be compatible with Apache Beam: + https://github.com/apache/beam/blob/b56740f0e8cd80c2873412847d0b336837429fb9/sdks/python/ + apache_beam/options/pipeline_options.py#L230-L251 + + :param options: Dictionary with options + :type options: dict + :return: List of arguments + :rtype: List[str] + """ + if not options: + return [] + + args: List[str] = [] + for attr, value in options.items(): + if value is None or (isinstance(value, bool) and value): + args.append(f"--{attr}") + elif isinstance(value, list): + args.extend([f"--{attr}={v}" for v in value]) + else: + args.append(f"--{attr}={value}") + return args + + +class BeamCommandRunner(LoggingMixin): + """ + Class responsible for running pipeline command in subprocess + + :param cmd: Parts of the command to be run in subprocess + :type cmd: List[str] + :param process_line_callback: Optional callback which can be used to process + stdout and stderr to detect job id + :type process_line_callback: Optional[Callable[[str], None]] + """ + + def __init__( + self, + cmd: List[str], + process_line_callback: Optional[Callable[[str], None]] = None, + ) -> None: + super().__init__() + self.log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd)) + self.process_line_callback = process_line_callback + self.job_id: Optional[str] = None + self._proc = subprocess.Popen( + cmd, + shell=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=True, + ) + + def _process_fd(self, fd): + """ + Prints output to logs. + + :param fd: File descriptor. + """ + if fd not in (self._proc.stdout, self._proc.stderr): + raise Exception("No data in stderr or in stdout.") + + fd_to_log = { + self._proc.stderr: self.log.warning, + self._proc.stdout: self.log.info, + } + func_log = fd_to_log[fd] + + while True: + line = fd.readline().decode() + if not line: + return + if self.process_line_callback: + self.process_line_callback(line) + func_log(line.rstrip("\n")) + + def wait_for_done(self) -> None: + """Waits for Apache Beam pipeline to complete.""" + self.log.info("Start waiting for Apache Beam process to complete.") + reads = [self._proc.stderr, self._proc.stdout] + while True: + # Wait for at least one available fd. 
+ readable_fds, _, _ = select.select(reads, [], [], 5) + if readable_fds is None: + self.log.info("Waiting for Apache Beam process to complete.") + continue + + for readable_fd in readable_fds: + self._process_fd(readable_fd) + + if self._proc.poll() is not None: + break + + # Corner case: check if more output was created between the last read and the process termination + for readable_fd in reads: + self._process_fd(readable_fd) + + self.log.info("Process exited with return code: %s", self._proc.returncode) + + if self._proc.returncode != 0: + raise AirflowException( + f"Apache Beam process failed with return code {self._proc.returncode}" + ) + + +class BeamHook(BaseHook): + """ + Hook for Apache Beam. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param runner: Runner type + :type runner: str + """ + + def __init__( + self, + runner: str, + ) -> None: + self.runner = runner + super().__init__() + + def _start_pipeline( + self, + variables: dict, + command_prefix: List[str], + process_line_callback: Optional[Callable[[str], None]] = None, + ) -> None: + cmd = command_prefix + [ + f"--runner={self.runner}", + ] + if variables: + cmd.extend(beam_options_to_args(variables)) + cmd_runner = BeamCommandRunner( + cmd=cmd, + process_line_callback=process_line_callback, + ) + cmd_runner.wait_for_done() + + def start_python_pipeline( # pylint: disable=too-many-arguments + self, + variables: dict, + py_file: str, + py_options: List[str], + py_interpreter: str = "python3", + py_requirements: Optional[List[str]] = None, + py_system_site_packages: bool = False, + process_line_callback: Optional[Callable[[str], None]] = None, + ): + """ + Starts Apache Beam python pipeline. + + :param variables: Variables passed to the pipeline. + :type variables: Dict + :param py_options: Additional options. + :type py_options: List[str] + :param py_interpreter: Python version of the Apache Beam pipeline. + If None, this defaults to the python3. + To track python versions supported by beam and related + issues check: https://issues.apache.org/jira/browse/BEAM-1251 + :type py_interpreter: str + :param py_requirements: Additional python package(s) to install. + If a value is passed to this parameter, a new virtual environment has been created with + additional packages installed. + + You could also install the apache-beam package if it is not installed on your system or you want + to use a different version. + :type py_requirements: List[str] + :param py_system_site_packages: Whether to include system_site_packages in your virtualenv. + See virtualenv documentation for more information. + + This option is only relevant if the ``py_requirements`` parameter is not None. + :type py_system_site_packages: bool + :param on_new_job_id_callback: Callback called when the job ID is known. + :type on_new_job_id_callback: callable + """ + if "labels" in variables: + variables["labels"] = [ + f"{key}={value}" for key, value in variables["labels"].items() + ] + + if py_requirements is not None: + if not py_requirements and not py_system_site_packages: + warning_invalid_environment = textwrap.dedent( + """\ + Invalid method invocation. You have disabled inclusion of system packages and empty list + required for installation, so it is not possible to create a valid virtual environment. + In the virtual environment, apache-beam package must be installed for your job to be \ + executed. 
To fix this problem: + * install apache-beam on the system, then set parameter py_system_site_packages to True, + * add apache-beam to the list of required packages in parameter py_requirements. + """ + ) + raise AirflowException(warning_invalid_environment) + + with TemporaryDirectory(prefix="apache-beam-venv") as tmp_dir: + py_interpreter = prepare_virtualenv( + venv_directory=tmp_dir, + python_bin=py_interpreter, + system_site_packages=py_system_site_packages, + requirements=py_requirements, + ) + command_prefix = [py_interpreter] + py_options + [py_file] + + self._start_pipeline( + variables=variables, + command_prefix=command_prefix, + process_line_callback=process_line_callback, + ) + else: + command_prefix = [py_interpreter] + py_options + [py_file] + + self._start_pipeline( + variables=variables, + command_prefix=command_prefix, + process_line_callback=process_line_callback, + ) + + def start_java_pipeline( + self, + variables: dict, + jar: str, + job_class: Optional[str] = None, + process_line_callback: Optional[Callable[[str], None]] = None, + ) -> None: + """ + Starts Apache Beam Java pipeline. + + :param variables: Variables passed to the job. + :type variables: dict + :param jar: Name of the jar for the pipeline + :type job_class: str + :param job_class: Name of the java class for the pipeline. + :type job_class: str + """ + if "labels" in variables: + variables["labels"] = json.dumps(variables["labels"], separators=(",", ":")) + + command_prefix = ( + ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar] + ) + self._start_pipeline( + variables=variables, + command_prefix=command_prefix, + process_line_callback=process_line_callback, + ) diff --git a/reference/providers/apache/beam/operators/__init__.py b/reference/providers/apache/beam/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/beam/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/beam/operators/beam.py b/reference/providers/apache/beam/operators/beam.py new file mode 100644 index 0000000..2917f5a --- /dev/null +++ b/reference/providers/apache/beam/operators/beam.py @@ -0,0 +1,497 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Apache Beam operators.""" +import copy +from abc import ABCMeta +from contextlib import ExitStack +from typing import Callable, List, Optional, Tuple, Union + +from airflow.models import BaseOperator +from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType +from airflow.providers.google.cloud.hooks.dataflow import ( + DataflowHook, + process_line_and_extract_dataflow_job_id_callback, +) +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.cloud.operators.dataflow import ( + CheckJobRunning, + DataflowConfiguration, +) +from airflow.utils.decorators import apply_defaults +from airflow.utils.helpers import convert_camel_to_snake +from airflow.version import version + + +class BeamDataflowMixin(metaclass=ABCMeta): + """ + Helper class to store common, Dataflow specific logic for both + :class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator` and + :class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`. + """ + + dataflow_hook: Optional[DataflowHook] + dataflow_config: Optional[DataflowConfiguration] + + def _set_dataflow( + self, pipeline_options: dict, job_name_variable_key: Optional[str] = None + ) -> Tuple[str, dict, Callable[[str], None]]: + self.dataflow_hook = self.__set_dataflow_hook() + self.dataflow_config.project_id = ( + self.dataflow_config.project_id or self.dataflow_hook.project_id + ) + dataflow_job_name = self.__get_dataflow_job_name() + pipeline_options = self.__get_dataflow_pipeline_options( + pipeline_options, dataflow_job_name, job_name_variable_key + ) + process_line_callback = self.__get_dataflow_process_callback() + return dataflow_job_name, pipeline_options, process_line_callback + + def __set_dataflow_hook(self) -> DataflowHook: + self.dataflow_hook = DataflowHook( + gcp_conn_id=self.dataflow_config.gcp_conn_id or self.gcp_conn_id, + delegate_to=self.dataflow_config.delegate_to or self.delegate_to, + poll_sleep=self.dataflow_config.poll_sleep, + impersonation_chain=self.dataflow_config.impersonation_chain, + drain_pipeline=self.dataflow_config.drain_pipeline, + cancel_timeout=self.dataflow_config.cancel_timeout, + wait_until_finished=self.dataflow_config.wait_until_finished, + ) + return self.dataflow_hook + + def __get_dataflow_job_name(self) -> str: + return DataflowHook.build_dataflow_job_name( + self.dataflow_config.job_name, self.dataflow_config.append_job_name + ) + + def __get_dataflow_pipeline_options( + self, pipeline_options: dict, job_name: str, job_name_key: Optional[str] = None + ) -> dict: + pipeline_options = copy.deepcopy(pipeline_options) + if job_name_key is not None: + pipeline_options[job_name_key] = job_name + pipeline_options["project"] = self.dataflow_config.project_id + pipeline_options["region"] = self.dataflow_config.location + pipeline_options.setdefault("labels", {}).update( + {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")} + ) + return pipeline_options + + def __get_dataflow_process_callback(self) -> Callable[[str], None]: + def set_current_dataflow_job_id(job_id): + 
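+            # Store the Dataflow job id on the operator instance; execute() later uses it
+            # to wait for the job, and on_kill() uses it to cancel a running job.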
self.dataflow_job_id = job_id + + return process_line_and_extract_dataflow_job_id_callback( + on_new_job_id_callback=set_current_dataflow_job_id + ) + + +class BeamRunPythonPipelineOperator(BaseOperator, BeamDataflowMixin): + """ + Launching Apache Beam pipelines written in Python. Note that both + ``default_pipeline_options`` and ``pipeline_options`` will be merged to specify pipeline + execution parameter, and ``default_pipeline_options`` is expected to save + high-level options, for instances, project and zone information, which + apply to all beam operators in the DAG. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BeamRunPythonPipelineOperator` + + .. seealso:: + For more detail on Apache Beam have a look at the reference: + https://beam.apache.org/documentation/ + + :param py_file: Reference to the python Apache Beam pipeline file.py, e.g., + /some/local/file/path/to/your/python/pipeline/file. (templated) + :type py_file: str + :param runner: Runner on which pipeline will be run. By default "DirectRunner" is being used. + Other possible options: DataflowRunner, SparkRunner, FlinkRunner. + See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType` + See: https://beam.apache.org/documentation/runners/capability-matrix/ + + :type runner: str + :param py_options: Additional python options, e.g., ["-m", "-v"]. + :type py_options: list[str] + :param default_pipeline_options: Map of default pipeline options. + :type default_pipeline_options: dict + :param pipeline_options: Map of pipeline options.The key must be a dictionary. + The value can contain different types: + + * If the value is None, the single option - ``--key`` (without value) will be added. + * If the value is False, this option will be skipped + * If the value is True, the single option - ``--key`` (without value) will be added. + * If the value is list, the many options will be added for each key. + If the value is ``['A', 'B']`` and the key is ``key`` then the ``--key=A --key-B`` options + will be left + * Other value types will be replaced with the Python textual representation. + + When defining labels (``labels`` option), you can also provide a dictionary. + :type pipeline_options: dict + :param py_interpreter: Python version of the beam pipeline. + If None, this defaults to the python3. + To track python versions supported by beam and related + issues check: https://issues.apache.org/jira/browse/BEAM-1251 + :type py_interpreter: str + :param py_requirements: Additional python package(s) to install. + If a value is passed to this parameter, a new virtual environment has been created with + additional packages installed. + + You could also install the apache_beam package if it is not installed on your system or you want + to use a different version. + :type py_requirements: List[str] + :param py_system_site_packages: Whether to include system_site_packages in your virtualenv. + See virtualenv documentation for more information. + + This option is only relevant if the ``py_requirements`` parameter is not None. + :param gcp_conn_id: Optional. + The connection ID to use connecting to Google Cloud Storage if python file is on GCS. + :type gcp_conn_id: str + :param delegate_to: Optional. + The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. 
+ :type delegate_to: str + :param dataflow_config: Dataflow configuration, used when runner type is set to DataflowRunner + :type dataflow_config: Union[dict, providers.google.cloud.operators.dataflow.DataflowConfiguration] + """ + + template_fields = [ + "py_file", + "runner", + "pipeline_options", + "default_pipeline_options", + "dataflow_config", + ] + template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"} + + @apply_defaults + def __init__( + self, + *, + py_file: str, + runner: str = "DirectRunner", + default_pipeline_options: Optional[dict] = None, + pipeline_options: Optional[dict] = None, + py_interpreter: str = "python3", + py_options: Optional[List[str]] = None, + py_requirements: Optional[List[str]] = None, + py_system_site_packages: bool = False, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.py_file = py_file + self.runner = runner + self.py_options = py_options or [] + self.default_pipeline_options = default_pipeline_options or {} + self.pipeline_options = pipeline_options or {} + self.pipeline_options.setdefault("labels", {}).update( + {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")} + ) + self.py_interpreter = py_interpreter + self.py_requirements = py_requirements + self.py_system_site_packages = py_system_site_packages + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.dataflow_config = dataflow_config or {} + self.beam_hook: Optional[BeamHook] = None + self.dataflow_hook: Optional[DataflowHook] = None + self.dataflow_job_id: Optional[str] = None + + if ( + self.dataflow_config + and self.runner.lower() != BeamRunnerType.DataflowRunner.lower() + ): + self.log.warning( + "dataflow_config is defined but runner is different than DataflowRunner (%s)", + self.runner, + ) + + def execute(self, context): + """Execute the Apache Beam Pipeline.""" + self.beam_hook = BeamHook(runner=self.runner) + pipeline_options = self.default_pipeline_options.copy() + process_line_callback: Optional[Callable] = None + is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower() + dataflow_job_name: Optional[str] = None + + if isinstance(self.dataflow_config, dict): + self.dataflow_config = DataflowConfiguration(**self.dataflow_config) + + if is_dataflow: + ( + dataflow_job_name, + pipeline_options, + process_line_callback, + ) = self._set_dataflow( + pipeline_options=pipeline_options, job_name_variable_key="job_name" + ) + + pipeline_options.update(self.pipeline_options) + + # Convert argument names from lowerCamelCase to snake case. 
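+        # e.g. a "tempLocation" key in pipeline_options becomes "temp_location"; the
+        # converted dict is what gets passed to the Beam hook as pipeline variables.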
+ formatted_pipeline_options = { + convert_camel_to_snake(key): pipeline_options[key] + for key in pipeline_options + } + + with ExitStack() as exit_stack: + if self.py_file.lower().startswith("gs://"): + gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to) + tmp_gcs_file = exit_stack.enter_context( # pylint: disable=no-member + gcs_hook.provide_file(object_url=self.py_file) + ) + self.py_file = tmp_gcs_file.name + + self.beam_hook.start_python_pipeline( + variables=formatted_pipeline_options, + py_file=self.py_file, + py_options=self.py_options, + py_interpreter=self.py_interpreter, + py_requirements=self.py_requirements, + py_system_site_packages=self.py_system_site_packages, + process_line_callback=process_line_callback, + ) + + if is_dataflow: + self.dataflow_hook.wait_for_done( # pylint: disable=no-value-for-parameter + job_name=dataflow_job_name, + location=self.dataflow_config.location, + job_id=self.dataflow_job_id, + multiple_jobs=False, + ) + + return {"dataflow_job_id": self.dataflow_job_id} + + def on_kill(self) -> None: + if self.dataflow_hook and self.dataflow_job_id: + self.log.info( + "Dataflow job with id: `%s` was requested to be cancelled.", + self.dataflow_job_id, + ) + self.dataflow_hook.cancel_job( + job_id=self.dataflow_job_id, + project_id=self.dataflow_config.project_id, + ) + + +# pylint: disable=too-many-instance-attributes +class BeamRunJavaPipelineOperator(BaseOperator, BeamDataflowMixin): + """ + Launching Apache Beam pipelines written in Java. + + Note that both + ``default_pipeline_options`` and ``pipeline_options`` will be merged to specify pipeline + execution parameter, and ``default_pipeline_options`` is expected to save + high-level pipeline_options, for instances, project and zone information, which + apply to all Apache Beam operators in the DAG. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BeamRunJavaPipelineOperator` + + .. seealso:: + For more detail on Apache Beam have a look at the reference: + https://beam.apache.org/documentation/ + + You need to pass the path to your jar file as a file reference with the ``jar`` + parameter, the jar needs to be a self executing jar (see documentation here: + https://beam.apache.org/documentation/runners/dataflow/#self-executing-jar). + Use ``pipeline_options`` to pass on pipeline_options to your job. + + :param jar: The reference to a self executing Apache Beam jar (templated). + :type jar: str + :param runner: Runner on which pipeline will be run. By default "DirectRunner" is being used. + See: + https://beam.apache.org/documentation/runners/capability-matrix/ + :type runner: str + :param job_class: The name of the Apache Beam pipeline class to be executed, it + is often not the main class configured in the pipeline jar file. + :type job_class: str + :param default_pipeline_options: Map of default job pipeline_options. + :type default_pipeline_options: dict + :param pipeline_options: Map of job specific pipeline_options.The key must be a dictionary. + The value can contain different types: + + * If the value is None, the single option - ``--key`` (without value) will be added. + * If the value is False, this option will be skipped + * If the value is True, the single option - ``--key`` (without value) will be added. + * If the value is list, the many pipeline_options will be added for each key. 
+ If the value is ``['A', 'B']`` and the key is ``key`` then the ``--key=A --key-B`` pipeline_options + will be left + * Other value types will be replaced with the Python textual representation. + + When defining labels (``labels`` option), you can also provide a dictionary. + :type pipeline_options: dict + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Storage if jar is on GCS + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param dataflow_config: Dataflow configuration, used when runner type is set to DataflowRunner + :type dataflow_config: Union[dict, providers.google.cloud.operators.dataflow.DataflowConfiguration] + """ + + template_fields = [ + "jar", + "runner", + "job_class", + "pipeline_options", + "default_pipeline_options", + "dataflow_config", + ] + template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"} + ui_color = "#0273d4" + + @apply_defaults + def __init__( + self, + *, + jar: str, + runner: str = "DirectRunner", + job_class: Optional[str] = None, + default_pipeline_options: Optional[dict] = None, + pipeline_options: Optional[dict] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.jar = jar + self.runner = runner + self.default_pipeline_options = default_pipeline_options or {} + self.pipeline_options = pipeline_options or {} + self.job_class = job_class + self.dataflow_config = dataflow_config or {} + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.dataflow_job_id = None + self.dataflow_hook: Optional[DataflowHook] = None + self.beam_hook: Optional[BeamHook] = None + self._dataflow_job_name: Optional[str] = None + + if ( + self.dataflow_config + and self.runner.lower() != BeamRunnerType.DataflowRunner.lower() + ): + self.log.warning( + "dataflow_config is defined but runner is different than DataflowRunner (%s)", + self.runner, + ) + + def execute(self, context): + """Execute the Apache Beam Pipeline.""" + self.beam_hook = BeamHook(runner=self.runner) + pipeline_options = self.default_pipeline_options.copy() + process_line_callback: Optional[Callable] = None + is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower() + dataflow_job_name: Optional[str] = None + + if isinstance(self.dataflow_config, dict): + self.dataflow_config = DataflowConfiguration(**self.dataflow_config) + + if is_dataflow: + ( + dataflow_job_name, + pipeline_options, + process_line_callback, + ) = self._set_dataflow( + pipeline_options=pipeline_options, job_name_variable_key=None + ) + + pipeline_options.update(self.pipeline_options) + + with ExitStack() as exit_stack: + if self.jar.lower().startswith("gs://"): + gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to) + tmp_gcs_file = exit_stack.enter_context( # pylint: disable=no-member + gcs_hook.provide_file(object_url=self.jar) + ) + self.jar = tmp_gcs_file.name + + if is_dataflow: + is_running = False + if self.dataflow_config.check_if_running != CheckJobRunning.IgnoreJob: + is_running = ( + # The reason for disable=no-value-for-parameter is that project_id parameter is + # required but here is not passed, moreover it cannot be passed here. 
+ # This method is wrapped by @_fallback_to_project_id_from_variables decorator which + # fallback project_id value from variables and raise error if project_id is + # defined both in variables and as parameter (here is already defined in variables) + self.dataflow_hook.is_job_dataflow_running( # pylint: disable=no-value-for-parameter + name=self.dataflow_config.job_name, + variables=pipeline_options, + ) + ) + while ( + is_running + and self.dataflow_config.check_if_running + == CheckJobRunning.WaitForRun + ): + # The reason for disable=no-value-for-parameter is that project_id parameter is + # required but here is not passed, moreover it cannot be passed here. + # This method is wrapped by @_fallback_to_project_id_from_variables decorator which + # fallback project_id value from variables and raise error if project_id is + # defined both in variables and as parameter (here is already defined in variables) + # pylint: disable=no-value-for-parameter + is_running = self.dataflow_hook.is_job_dataflow_running( + name=self.dataflow_config.job_name, + variables=pipeline_options, + ) + if not is_running: + pipeline_options["jobName"] = dataflow_job_name + self.beam_hook.start_java_pipeline( + variables=pipeline_options, + jar=self.jar, + job_class=self.job_class, + process_line_callback=process_line_callback, + ) + self.dataflow_hook.wait_for_done( + job_name=dataflow_job_name, + location=self.dataflow_config.location, + job_id=self.dataflow_job_id, + multiple_jobs=self.dataflow_config.multiple_jobs, + project_id=self.dataflow_config.project_id, + ) + + else: + self.beam_hook.start_java_pipeline( + variables=pipeline_options, + jar=self.jar, + job_class=self.job_class, + process_line_callback=process_line_callback, + ) + + return {"dataflow_job_id": self.dataflow_job_id} + + def on_kill(self) -> None: + if self.dataflow_hook and self.dataflow_job_id: + self.log.info( + "Dataflow job with id: `%s` was requested to be cancelled.", + self.dataflow_job_id, + ) + self.dataflow_hook.cancel_job( + job_id=self.dataflow_job_id, + project_id=self.dataflow_config.project_id, + ) diff --git a/reference/providers/apache/beam/provider.yaml b/reference/providers/apache/beam/provider.yaml new file mode 100644 index 0000000..8634f89 --- /dev/null +++ b/reference/providers/apache/beam/provider.yaml @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-apache-beam +name: Apache Beam +description: | + `Apache Beam `__. 
+ +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Apache Beam + external-doc-url: https://beam.apache.org/ + how-to-guide: + - /docs/apache-airflow-providers-apache-beam/operators.rst + tags: [apache] + +operators: + - integration-name: Apache Beam + python-modules: + - airflow.providers.apache.beam.operators.beam + +hooks: + - integration-name: Apache Beam + python-modules: + - airflow.providers.apache.beam.hooks.beam diff --git a/reference/providers/apache/cassandra/CHANGELOG.rst b/reference/providers/apache/cassandra/CHANGELOG.rst new file mode 100644 index 0000000..c95472c --- /dev/null +++ b/reference/providers/apache/cassandra/CHANGELOG.rst @@ -0,0 +1,31 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/apache/cassandra/__init__.py b/reference/providers/apache/cassandra/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/cassandra/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/cassandra/example_dags/__init__.py b/reference/providers/apache/cassandra/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/cassandra/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/cassandra/example_dags/example_cassandra_dag.py b/reference/providers/apache/cassandra/example_dags/example_cassandra_dag.py new file mode 100644 index 0000000..2b90277 --- /dev/null +++ b/reference/providers/apache/cassandra/example_dags/example_cassandra_dag.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG to check if a Cassandra Table and a Records exists +or not using `CassandraTableSensor` and `CassandraRecordSensor`. +""" +from airflow.models import DAG +from airflow.providers.apache.cassandra.sensors.record import CassandraRecordSensor +from airflow.providers.apache.cassandra.sensors.table import CassandraTableSensor +from airflow.utils.dates import days_ago + +args = { + "owner": "Airflow", +} + +with DAG( + dag_id="example_cassandra_operator", + default_args=args, + schedule_interval=None, + start_date=days_ago(2), + tags=["example"], +) as dag: + # [START howto_operator_cassandra_table_sensor] + table_sensor = CassandraTableSensor( + task_id="cassandra_table_sensor", + cassandra_conn_id="cassandra_default", + table="keyspace_name.table_name", + ) + # [END howto_operator_cassandra_table_sensor] + + # [START howto_operator_cassandra_record_sensor] + record_sensor = CassandraRecordSensor( + task_id="cassandra_record_sensor", + cassandra_conn_id="cassandra_default", + table="keyspace_name.table_name", + keys={"p1": "v1", "p2": "v2"}, + ) + # [END howto_operator_cassandra_record_sensor] diff --git a/reference/providers/apache/cassandra/hooks/__init__.py b/reference/providers/apache/cassandra/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/cassandra/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/cassandra/hooks/cassandra.py b/reference/providers/apache/cassandra/hooks/cassandra.py new file mode 100644 index 0000000..9ea859a --- /dev/null +++ b/reference/providers/apache/cassandra/hooks/cassandra.py @@ -0,0 +1,229 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains hook to integrate with Apache Cassandra.""" + +from typing import Any, Dict, Union + +from airflow.hooks.base import BaseHook +from airflow.utils.log.logging_mixin import LoggingMixin +from cassandra.auth import PlainTextAuthProvider +from cassandra.cluster import Cluster, Session +from cassandra.policies import ( + DCAwareRoundRobinPolicy, + RoundRobinPolicy, + TokenAwarePolicy, + WhiteListRoundRobinPolicy, +) + +Policy = Union[ + DCAwareRoundRobinPolicy, + RoundRobinPolicy, + TokenAwarePolicy, + WhiteListRoundRobinPolicy, +] + + +class CassandraHook(BaseHook, LoggingMixin): + """ + Hook used to interact with Cassandra + + Contact points can be specified as a comma-separated string in the 'hosts' + field of the connection. + + Port can be specified in the port field of the connection. + + If SSL is enabled in Cassandra, pass in a dict in the extra field as kwargs for + ``ssl.wrap_socket()``. For example:: + + { + 'ssl_options' : { + 'ca_certs' : PATH_TO_CA_CERTS + } + } + + Default load balancing policy is RoundRobinPolicy. To specify a different + LB policy:: + + - DCAwareRoundRobinPolicy + { + 'load_balancing_policy': 'DCAwareRoundRobinPolicy', + 'load_balancing_policy_args': { + 'local_dc': LOCAL_DC_NAME, // optional + 'used_hosts_per_remote_dc': SOME_INT_VALUE, // optional + } + } + - WhiteListRoundRobinPolicy + { + 'load_balancing_policy': 'WhiteListRoundRobinPolicy', + 'load_balancing_policy_args': { + 'hosts': ['HOST1', 'HOST2', 'HOST3'] + } + } + - TokenAwarePolicy + { + 'load_balancing_policy': 'TokenAwarePolicy', + 'load_balancing_policy_args': { + 'child_load_balancing_policy': CHILD_POLICY_NAME, // optional + 'child_load_balancing_policy_args': { ... } // optional + } + } + + For details of the Cluster config, see cassandra.cluster. 
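+
+    All of these settings are read from the connection ``extra`` field and can
+    be combined. An illustrative example (hypothetical values)::
+
+        {
+            'load_balancing_policy': 'DCAwareRoundRobinPolicy',
+            'load_balancing_policy_args': {'local_dc': 'dc1'},
+            'ssl_options': {'ca_certs': '/path/to/ca_certs.pem'},
+            'cql_version': '3.4.4',
+            'protocol_version': 4
+        }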
+ """ + + conn_name_attr = "cassandra_conn_id" + default_conn_name = "cassandra_default" + conn_type = "cassandra" + hook_name = "Cassandra" + + def __init__(self, cassandra_conn_id: str = default_conn_name): + super().__init__() + conn = self.get_connection(cassandra_conn_id) + + conn_config = {} + if conn.host: + conn_config["contact_points"] = conn.host.split(",") + + if conn.port: + conn_config["port"] = int(conn.port) + + if conn.login: + conn_config["auth_provider"] = PlainTextAuthProvider( + username=conn.login, password=conn.password + ) + + policy_name = conn.extra_dejson.get("load_balancing_policy", None) + policy_args = conn.extra_dejson.get("load_balancing_policy_args", {}) + lb_policy = self.get_lb_policy(policy_name, policy_args) + if lb_policy: + conn_config["load_balancing_policy"] = lb_policy + + cql_version = conn.extra_dejson.get("cql_version", None) + if cql_version: + conn_config["cql_version"] = cql_version + + ssl_options = conn.extra_dejson.get("ssl_options", None) + if ssl_options: + conn_config["ssl_options"] = ssl_options + + protocol_version = conn.extra_dejson.get("protocol_version", None) + if protocol_version: + conn_config["protocol_version"] = protocol_version + + self.cluster = Cluster(**conn_config) + self.keyspace = conn.schema + self.session = None + + def get_conn(self) -> Session: + """Returns a cassandra Session object""" + if self.session and not self.session.is_shutdown: + return self.session + self.session = self.cluster.connect(self.keyspace) + return self.session + + def get_cluster(self) -> Cluster: + """Returns Cassandra cluster.""" + return self.cluster + + def shutdown_cluster(self) -> None: + """Closes all sessions and connections associated with this Cluster.""" + if not self.cluster.is_shutdown: + self.cluster.shutdown() + + @staticmethod + def get_lb_policy(policy_name: str, policy_args: Dict[str, Any]) -> Policy: + """ + Creates load balancing policy. + + :param policy_name: Name of the policy to use. + :type policy_name: str + :param policy_args: Parameters for the policy. + :type policy_args: Dict + """ + if policy_name == "DCAwareRoundRobinPolicy": + local_dc = policy_args.get("local_dc", "") + used_hosts_per_remote_dc = int( + policy_args.get("used_hosts_per_remote_dc", 0) + ) + return DCAwareRoundRobinPolicy(local_dc, used_hosts_per_remote_dc) + + if policy_name == "WhiteListRoundRobinPolicy": + hosts = policy_args.get("hosts") + if not hosts: + raise Exception("Hosts must be specified for WhiteListRoundRobinPolicy") + return WhiteListRoundRobinPolicy(hosts) + + if policy_name == "TokenAwarePolicy": + allowed_child_policies = ( + "RoundRobinPolicy", + "DCAwareRoundRobinPolicy", + "WhiteListRoundRobinPolicy", + ) + child_policy_name = policy_args.get( + "child_load_balancing_policy", "RoundRobinPolicy" + ) + child_policy_args = policy_args.get("child_load_balancing_policy_args", {}) + if child_policy_name not in allowed_child_policies: + return TokenAwarePolicy(RoundRobinPolicy()) + else: + child_policy = CassandraHook.get_lb_policy( + child_policy_name, child_policy_args + ) + return TokenAwarePolicy(child_policy) + + # Fallback to default RoundRobinPolicy + return RoundRobinPolicy() + + def table_exists(self, table: str) -> bool: + """ + Checks if a table exists in Cassandra + + :param table: Target Cassandra table. + Use dot notation to target a specific keyspace. + :type table: str + """ + keyspace = self.keyspace + if "." 
in table: + keyspace, table = table.split(".", 1) + cluster_metadata = self.get_conn().cluster.metadata + return ( + keyspace in cluster_metadata.keyspaces + and table in cluster_metadata.keyspaces[keyspace].tables + ) + + def record_exists(self, table: str, keys: Dict[str, str]) -> bool: + """ + Checks if a record exists in Cassandra + + :param table: Target Cassandra table. + Use dot notation to target a specific keyspace. + :type table: str + :param keys: The keys and their values to check the existence. + :type keys: dict + """ + keyspace = self.keyspace + if "." in table: + keyspace, table = table.split(".", 1) + ks_str = " AND ".join(f"{key}=%({key})s" for key in keys.keys()) + query = f"SELECT * FROM {keyspace}.{table} WHERE {ks_str}" + try: + result = self.get_conn().execute(query, keys) + return result.one() is not None + except Exception: # pylint: disable=broad-except + return False diff --git a/reference/providers/apache/cassandra/provider.yaml b/reference/providers/apache/cassandra/provider.yaml new file mode 100644 index 0000000..dd36233 --- /dev/null +++ b/reference/providers/apache/cassandra/provider.yaml @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-apache-cassandra +name: Apache Cassandra +description: | + `Apache Cassandra `__. + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Apache Cassandra + external-doc-url: http://cassandra.apache.org/ + how-to-guide: + - /docs/apache-airflow-providers-apache-cassandra/operators.rst + logo: /integration-logos/apache/cassandra-3.png + tags: [apache] + +sensors: + - integration-name: Apache Cassandra + python-modules: + - airflow.providers.apache.cassandra.sensors.record + - airflow.providers.apache.cassandra.sensors.table + +hooks: + - integration-name: Apache Cassandra + python-modules: + - airflow.providers.apache.cassandra.hooks.cassandra + +hook-class-names: + - airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook diff --git a/reference/providers/apache/cassandra/sensors/__init__.py b/reference/providers/apache/cassandra/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/cassandra/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/cassandra/sensors/record.py b/reference/providers/apache/cassandra/sensors/record.py new file mode 100644 index 0000000..e42a31a --- /dev/null +++ b/reference/providers/apache/cassandra/sensors/record.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains sensor that check the existence +of a record in a Cassandra cluster. +""" + +from typing import Any, Dict + +from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class CassandraRecordSensor(BaseSensorOperator): + """ + Checks for the existence of a record in a Cassandra cluster. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CassandraRecordSensor` + + For example, if you want to wait for a record that has values 'v1' and 'v2' for each + primary keys 'p1' and 'p2' to be populated in keyspace 'k' and table 't', + instantiate it as follows: + + >>> cassandra_sensor = CassandraRecordSensor(table="k.t", + ... keys={"p1": "v1", "p2": "v2"}, + ... cassandra_conn_id="cassandra_default", + ... task_id="cassandra_sensor") + + :param table: Target Cassandra table. + Use dot notation to target a specific keyspace. 
+ :type table: str + :param keys: The keys and their values to be monitored + :type keys: dict + :param cassandra_conn_id: The connection ID to use + when connecting to Cassandra cluster + :type cassandra_conn_id: str + """ + + template_fields = ("table", "keys") + + @apply_defaults + def __init__( + self, *, table: str, keys: Dict[str, str], cassandra_conn_id: str, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.cassandra_conn_id = cassandra_conn_id + self.table = table + self.keys = keys + + def poke(self, context: Dict[str, str]) -> bool: + self.log.info("Sensor check existence of record: %s", self.keys) + hook = CassandraHook(self.cassandra_conn_id) + return hook.record_exists(self.table, self.keys) diff --git a/reference/providers/apache/cassandra/sensors/table.py b/reference/providers/apache/cassandra/sensors/table.py new file mode 100644 index 0000000..3a4818f --- /dev/null +++ b/reference/providers/apache/cassandra/sensors/table.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This module contains sensor that check the existence +of a table in a Cassandra cluster. +""" + +from typing import Any, Dict + +from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class CassandraTableSensor(BaseSensorOperator): + """ + Checks for the existence of a table in a Cassandra cluster. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CassandraTableSensor` + + + For example, if you want to wait for a table called 't' to be created + in a keyspace 'k', instantiate it as follows: + + >>> cassandra_sensor = CassandraTableSensor(table="k.t", + ... cassandra_conn_id="cassandra_default", + ... task_id="cassandra_sensor") + + :param table: Target Cassandra table. + Use dot notation to target a specific keyspace. + :type table: str + :param cassandra_conn_id: The connection ID to use + when connecting to Cassandra cluster + :type cassandra_conn_id: str + """ + + template_fields = ("table",) + + @apply_defaults + def __init__(self, *, table: str, cassandra_conn_id: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.cassandra_conn_id = cassandra_conn_id + self.table = table + + def poke(self, context: Dict[Any, Any]) -> bool: + self.log.info("Sensor check existence of table: %s", self.table) + hook = CassandraHook(self.cassandra_conn_id) + return hook.table_exists(self.table) diff --git a/reference/providers/apache/druid/CHANGELOG.rst b/reference/providers/apache/druid/CHANGELOG.rst new file mode 100644 index 0000000..9a081c3 --- /dev/null +++ b/reference/providers/apache/druid/CHANGELOG.rst @@ -0,0 +1,44 @@ + .. 
Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.1.0 +..... + +Features +~~~~~~~~ + +* ``Refactor SQL/BigQuery/Qubole/Druid Check operators (#12677)`` + +Bugfixes +~~~~~~~~ + +* ``Bugfix: DruidOperator fails to submit ingestion tasks (#14418)`` + +1.0.1 +..... + +Updated documentation and readme files. + + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/apache/druid/__init__.py b/reference/providers/apache/druid/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/druid/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/druid/hooks/__init__.py b/reference/providers/apache/druid/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/druid/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/druid/hooks/druid.py b/reference/providers/apache/druid/hooks/druid.py new file mode 100644 index 0000000..84784d4 --- /dev/null +++ b/reference/providers/apache/druid/hooks/druid.py @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import time +from typing import Any, Dict, Iterable, Optional, Tuple + +import requests +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.hooks.dbapi import DbApiHook +from pydruid.db import connect + + +class DruidHook(BaseHook): + """ + Connection to Druid overlord for ingestion + + To connect to a Druid cluster that is secured with the druid-basic-security + extension, add the username and password to the druid ingestion connection. + + :param druid_ingest_conn_id: The connection id to the Druid overlord machine + which accepts index jobs + :type druid_ingest_conn_id: str + :param timeout: The interval between polling + the Druid job for the status of the ingestion job. + Must be greater than or equal to 1 + :type timeout: int + :param max_ingestion_time: The maximum ingestion time before assuming the job failed + :type max_ingestion_time: int + """ + + def __init__( + self, + druid_ingest_conn_id: str = "druid_ingest_default", + timeout: int = 1, + max_ingestion_time: Optional[int] = None, + ) -> None: + + super().__init__() + self.druid_ingest_conn_id = druid_ingest_conn_id + self.timeout = timeout + self.max_ingestion_time = max_ingestion_time + self.header = {"content-type": "application/json"} + + if self.timeout < 1: + raise ValueError("Druid timeout should be equal or greater than 1") + + def get_conn_url(self) -> str: + """Get Druid connection url""" + conn = self.get_connection(self.druid_ingest_conn_id) + host = conn.host + port = conn.port + conn_type = "http" if not conn.conn_type else conn.conn_type + endpoint = conn.extra_dejson.get("endpoint", "") + return f"{conn_type}://{host}:{port}/{endpoint}" + + def get_auth(self) -> Optional[requests.auth.HTTPBasicAuth]: + """ + Return username and password from connections tab as requests.auth.HTTPBasicAuth object. + + If these details have not been set then returns None. 
+ """ + conn = self.get_connection(self.druid_ingest_conn_id) + user = conn.login + password = conn.password + if user is not None and password is not None: + return requests.auth.HTTPBasicAuth(user, password) + else: + return None + + def submit_indexing_job(self, json_index_spec: Dict[str, Any]) -> None: + """Submit Druid ingestion job""" + url = self.get_conn_url() + + self.log.info("Druid ingestion spec: %s", json_index_spec) + req_index = requests.post( + url, data=json_index_spec, headers=self.header, auth=self.get_auth() + ) + if req_index.status_code != 200: + raise AirflowException( + f"Did not get 200 when submitting the Druid job to {url}" + ) + + req_json = req_index.json() + # Wait until the job is completed + druid_task_id = req_json["task"] + self.log.info("Druid indexing task-id: %s", druid_task_id) + + running = True + + sec = 0 + while running: + req_status = requests.get( + f"{url}/{druid_task_id}/status", auth=self.get_auth() + ) + + self.log.info("Job still running for %s seconds...", sec) + + if self.max_ingestion_time and sec > self.max_ingestion_time: + # ensure that the job gets killed if the max ingestion time is exceeded + requests.post(f"{url}/{druid_task_id}/shutdown", auth=self.get_auth()) + raise AirflowException( + "Druid ingestion took more than " + f"{self.max_ingestion_time} seconds" + ) + + time.sleep(self.timeout) + + sec += self.timeout + + status = req_status.json()["status"]["status"] + if status == "RUNNING": + running = True + elif status == "SUCCESS": + running = False # Great success! + elif status == "FAILED": + raise AirflowException( + "Druid indexing job failed, check console for more info" + ) + else: + raise AirflowException(f"Could not get status of the job, got {status}") + + self.log.info("Successful index") + + +class DruidDbApiHook(DbApiHook): + """ + Interact with Druid broker + + This hook is purely for users to query druid broker. + For ingestion, please use druidHook. + """ + + conn_name_attr = "druid_broker_conn_id" + default_conn_name = "druid_broker_default" + conn_type = "druid" + hook_name = "Druid" + supports_autocommit = False + + def get_conn(self) -> connect: + """Establish a connection to druid broker.""" + conn = self.get_connection(self.conn_name_attr) + druid_broker_conn = connect( + host=conn.host, + port=conn.port, + path=conn.extra_dejson.get("endpoint", "/druid/v2/sql"), + scheme=conn.extra_dejson.get("schema", "http"), + user=conn.login, + password=conn.password, + ) + self.log.info( + "Get the connection to druid broker on %s using user %s", + conn.host, + conn.login, + ) + return druid_broker_conn + + def get_uri(self) -> str: + """ + Get the connection uri for druid broker. 
+ + e.g: druid://localhost:8082/druid/v2/sql/ + """ + conn = self.get_connection(getattr(self, self.conn_name_attr)) + host = conn.host + if conn.port is not None: + host += f":{conn.port}" + conn_type = "druid" if not conn.conn_type else conn.conn_type + endpoint = conn.extra_dejson.get("endpoint", "druid/v2/sql") + return f"{conn_type}://{host}/{endpoint}" + + def set_autocommit(self, conn: connect, autocommit: bool) -> NotImplemented: + raise NotImplementedError() + + def insert_rows( + self, + table: str, + rows: Iterable[Tuple[str]], + target_fields: Optional[Iterable[str]] = None, + commit_every: int = 1000, + replace: bool = False, + **kwargs: Any, + ) -> NotImplemented: + raise NotImplementedError() diff --git a/reference/providers/apache/druid/operators/__init__.py b/reference/providers/apache/druid/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/druid/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/druid/operators/druid.py b/reference/providers/apache/druid/operators/druid.py new file mode 100644 index 0000000..8d63836 --- /dev/null +++ b/reference/providers/apache/druid/operators/druid.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
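The DruidHook above posts an ingestion spec to the overlord and then polls the task's status endpoint every ``timeout`` seconds until Druid reports SUCCESS or FAILED, shutting the task down if ``max_ingestion_time`` is exceeded. A minimal sketch of driving it directly from a task; the connection id and spec file are illustrative:

from airflow.providers.apache.druid.hooks.druid import DruidHook

def ingest_events() -> None:
    # The hook posts whatever body it is given, so a rendered JSON spec string
    # is the simplest input. "ingestion_spec.json" is a hypothetical file.
    with open("ingestion_spec.json") as spec_file:
        spec = spec_file.read()

    hook = DruidHook(
        druid_ingest_conn_id="druid_ingest_default",
        timeout=5,                # poll the task status every 5 seconds
        max_ingestion_time=3600,  # shut the task down and fail after an hour
    )
    # Blocks until the ingestion task finishes; raises AirflowException on failure.
    hook.submit_indexing_job(spec)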
+ +from typing import Any, Dict, Optional + +from airflow.models import BaseOperator +from airflow.providers.apache.druid.hooks.druid import DruidHook +from airflow.utils.decorators import apply_defaults + + +class DruidOperator(BaseOperator): + """ + Allows to submit a task directly to druid + + :param json_index_file: The filepath to the druid index specification + :type json_index_file: str + :param druid_ingest_conn_id: The connection id of the Druid overlord which + accepts index jobs + :type druid_ingest_conn_id: str + """ + + template_fields = ("json_index_file",) + template_ext = (".json",) + + @apply_defaults + def __init__( + self, + *, + json_index_file: str, + druid_ingest_conn_id: str = "druid_ingest_default", + max_ingestion_time: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.json_index_file = json_index_file + self.conn_id = druid_ingest_conn_id + self.max_ingestion_time = max_ingestion_time + + def execute(self, context: Dict[Any, Any]) -> None: + hook = DruidHook( + druid_ingest_conn_id=self.conn_id, + max_ingestion_time=self.max_ingestion_time, + ) + self.log.info("Submitting %s", self.json_index_file) + hook.submit_indexing_job(self.json_index_file) diff --git a/reference/providers/apache/druid/operators/druid_check.py b/reference/providers/apache/druid/operators/druid_check.py new file mode 100644 index 0000000..65b1c5e --- /dev/null +++ b/reference/providers/apache/druid/operators/druid_check.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import warnings + +from airflow.operators.sql import SQLCheckOperator + + +class DruidCheckOperator(SQLCheckOperator): + """ + This class is deprecated. + Please use `airflow.operators.sql.SQLCheckOperator`. + """ + + def __init__(self, druid_broker_conn_id: str = "druid_broker_default", **kwargs): + warnings.warn( + """This class is deprecated. + Please use `airflow.operators.sql.SQLCheckOperator`.""", + DeprecationWarning, + stacklevel=3, + ) + super().__init__(conn_id=druid_broker_conn_id, **kwargs) diff --git a/reference/providers/apache/druid/provider.yaml b/reference/providers/apache/druid/provider.yaml new file mode 100644 index 0000000..4e1eab5 --- /dev/null +++ b/reference/providers/apache/druid/provider.yaml @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-apache-druid +name: Apache Druid +description: | + `Apache Druid `__. + +versions: + - 1.1.0 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Apache Druid + external-doc-url: https://druid.apache.org/ + logo: /integration-logos/apache/druid-1.png + tags: [apache] + +operators: + - integration-name: Apache Druid + python-modules: + - airflow.providers.apache.druid.operators.druid + - airflow.providers.apache.druid.operators.druid_check + +hooks: + - integration-name: Apache Druid + python-modules: + - airflow.providers.apache.druid.hooks.druid + +hook-class-names: + - airflow.providers.apache.druid.hooks.druid.DruidDbApiHook + +transfers: + - source-integration-name: Apache Hive + target-integration-name: Apache Druid + python-module: airflow.providers.apache.druid.transfers.hive_to_druid diff --git a/reference/providers/apache/druid/transfers/__init__.py b/reference/providers/apache/druid/transfers/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/druid/transfers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/druid/transfers/hive_to_druid.py b/reference/providers/apache/druid/transfers/hive_to_druid.py new file mode 100644 index 0000000..8e0649d --- /dev/null +++ b/reference/providers/apache/druid/transfers/hive_to_druid.py @@ -0,0 +1,256 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
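For the DruidOperator defined above, the index spec normally lives in a Jinja-templated ``.json`` file next to the DAG (``.json`` is listed in ``template_ext``). A short sketch, with the file name and connection id as illustrative values:

from airflow.models import DAG
from airflow.providers.apache.druid.operators.druid import DruidOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_druid_ingest",
    start_date=days_ago(1),
    schedule_interval=None,
) as dag:
    # "ingestion_spec.json" is read and rendered before being submitted to the overlord.
    submit_job = DruidOperator(
        task_id="submit_ingestion",
        json_index_file="ingestion_spec.json",
        druid_ingest_conn_id="druid_ingest_default",
        max_ingestion_time=3600,
    )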
+
+"""This module contains operator to move data from Hive to Druid."""
+
+from typing import Any, Dict, List, Optional
+
+from airflow.models import BaseOperator
+from airflow.providers.apache.druid.hooks.druid import DruidHook
+from airflow.providers.apache.hive.hooks.hive import HiveCliHook, HiveMetastoreHook
+from airflow.utils.decorators import apply_defaults
+
+LOAD_CHECK_INTERVAL = 5
+DEFAULT_TARGET_PARTITION_SIZE = 5000000
+
+
+class HiveToDruidOperator(BaseOperator):
+    """
+    Moves data from Hive to Druid. Note that for now the data is loaded
+    into memory before being pushed to Druid, so this operator should
+    be used for smallish amounts of data.
+
+    :param sql: SQL query to execute against the Druid database. (templated)
+    :type sql: str
+    :param druid_datasource: the datasource you want to ingest into in druid
+    :type druid_datasource: str
+    :param ts_dim: the timestamp dimension
+    :type ts_dim: str
+    :param metric_spec: the metrics you want to define for your data
+    :type metric_spec: list
+    :param hive_cli_conn_id: the hive connection id
+    :type hive_cli_conn_id: str
+    :param druid_ingest_conn_id: the druid ingest connection id
+    :type druid_ingest_conn_id: str
+    :param metastore_conn_id: the metastore connection id
+    :type metastore_conn_id: str
+    :param hadoop_dependency_coordinates: list of coordinates to squeeze
+        into the ingest json
+    :type hadoop_dependency_coordinates: list[str]
+    :param intervals: list of time intervals that define segments,
+        this is passed as is to the json object. (templated)
+    :type intervals: list
+    :param num_shards: Directly specify the number of shards to create.
+    :type num_shards: float
+    :param target_partition_size: Target number of rows to include in a partition.
+    :type target_partition_size: int
+    :param query_granularity: The minimum granularity to be able to query results at and the granularity of
+        the data inside the segment. E.g. a value of "minute" will mean that data is aggregated at minutely
+        granularity. That is, if there are collisions in the tuple (minute(timestamp), dimensions), then it
+        will aggregate values together using the aggregators instead of storing individual rows.
+        A granularity of 'NONE' means millisecond granularity.
+    :type query_granularity: str
+    :param segment_granularity: The granularity to create time chunks at. Multiple segments can be created per
+        time chunk. For example, with 'DAY' segmentGranularity, the events of the same day fall into the
+        same time chunk which can be optionally further partitioned into multiple segments based on other
+        configurations and input size.
+    :type segment_granularity: str
+    :param hive_tblproperties: additional properties for tblproperties in
+        hive for the staging table
+    :type hive_tblproperties: dict
+    :param job_properties: additional properties for job
+    :type job_properties: dict
+    """
+
+    template_fields = ("sql", "intervals")
+    template_ext = (".sql",)
+
+    @apply_defaults
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        *,
+        sql: str,
+        druid_datasource: str,
+        ts_dim: str,
+        metric_spec: Optional[List[Any]] = None,
+        hive_cli_conn_id: str = "hive_cli_default",
+        druid_ingest_conn_id: str = "druid_ingest_default",
+        metastore_conn_id: str = "metastore_default",
+        hadoop_dependency_coordinates: Optional[List[str]] = None,
+        intervals: Optional[List[Any]] = None,
+        num_shards: float = -1,
+        target_partition_size: int = -1,
+        query_granularity: str = "NONE",
+        segment_granularity: str = "DAY",
+        hive_tblproperties: Optional[Dict[Any, Any]] = None,
+        job_properties: Optional[Dict[Any, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.sql = sql
+        self.druid_datasource = druid_datasource
+        self.ts_dim = ts_dim
+        self.intervals = intervals or ["{{ ds }}/{{ tomorrow_ds }}"]
+        self.num_shards = num_shards
+        self.target_partition_size = target_partition_size
+        self.query_granularity = query_granularity
+        self.segment_granularity = segment_granularity
+        self.metric_spec = metric_spec or [{"name": "count", "type": "count"}]
+        self.hive_cli_conn_id = hive_cli_conn_id
+        self.hadoop_dependency_coordinates = hadoop_dependency_coordinates
+        self.druid_ingest_conn_id = druid_ingest_conn_id
+        self.metastore_conn_id = metastore_conn_id
+        self.hive_tblproperties = hive_tblproperties or {}
+        self.job_properties = job_properties
+
+    def execute(self, context: Dict[str, Any]) -> None:
+        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
+        self.log.info("Extracting data from Hive")
+        hive_table = "druid." + context["task_instance_key_str"].replace(".", "_")
+        sql = self.sql.strip().strip(";")
+        tblproperties = "".join(
+            [f", '{k}' = '{v}'" for k, v in self.hive_tblproperties.items()]
+        )
+        hql = f"""\
+        SET mapred.output.compress=false;
+        SET hive.exec.compress.output=false;
+        DROP TABLE IF EXISTS {hive_table};
+        CREATE TABLE {hive_table}
+        ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
+        STORED AS TEXTFILE
+        TBLPROPERTIES ('serialization.null.format' = ''{tblproperties})
+        AS
+        {sql}
+        """
+        self.log.info("Running command:\n %s", hql)
+        hive.run_cli(hql)
+
+        meta_hook = HiveMetastoreHook(self.metastore_conn_id)
+
+        # Get the Hive table and extract the columns
+        table = meta_hook.get_table(hive_table)
+        columns = [col.name for col in table.sd.cols]
+
+        # Get the path on hdfs
+        static_path = meta_hook.get_table(hive_table).sd.location
+
+        druid = DruidHook(druid_ingest_conn_id=self.druid_ingest_conn_id)
+
+        try:
+            index_spec = self.construct_ingest_query(
+                static_path=static_path,
+                columns=columns,
+            )
+
+            self.log.info("Inserting rows into Druid, hdfs path: %s", static_path)
+
+            druid.submit_indexing_job(index_spec)
+
+            self.log.info("Load seems to have succeeded!")
+        finally:
+            self.log.info("Cleaning up by dropping the temp Hive table %s", hive_table)
+            hql = f"DROP TABLE IF EXISTS {hive_table}"
+            hive.run_cli(hql)
+
+    def construct_ingest_query(
+        self, static_path: str, columns: List[str]
+    ) -> Dict[str, Any]:
+        """
+        Builds an ingest query for an HDFS TSV load.
+ + :param static_path: The path on hdfs where the data is + :type static_path: str + :param columns: List of all the columns that are available + :type columns: list + """ + # backward compatibility for num_shards, + # but target_partition_size is the default setting + # and overwrites the num_shards + num_shards = self.num_shards + target_partition_size = self.target_partition_size + if self.target_partition_size == -1: + if self.num_shards == -1: + target_partition_size = DEFAULT_TARGET_PARTITION_SIZE + else: + num_shards = -1 + + metric_names = [ + m["fieldName"] for m in self.metric_spec if m["type"] != "count" + ] + + # Take all the columns, which are not the time dimension + # or a metric, as the dimension columns + dimensions = [c for c in columns if c not in metric_names and c != self.ts_dim] + + ingest_query_dict: Dict[str, Any] = { + "type": "index_hadoop", + "spec": { + "dataSchema": { + "metricsSpec": self.metric_spec, + "granularitySpec": { + "queryGranularity": self.query_granularity, + "intervals": self.intervals, + "type": "uniform", + "segmentGranularity": self.segment_granularity, + }, + "parser": { + "type": "string", + "parseSpec": { + "columns": columns, + "dimensionsSpec": { + "dimensionExclusions": [], + "dimensions": dimensions, # list of names + "spatialDimensions": [], + }, + "timestampSpec": {"column": self.ts_dim, "format": "auto"}, + "format": "tsv", + }, + }, + "dataSource": self.druid_datasource, + }, + "tuningConfig": { + "type": "hadoop", + "jobProperties": { + "mapreduce.job.user.classpath.first": "false", + "mapreduce.map.output.compress": "false", + "mapreduce.output.fileoutputformat.compress": "false", + }, + "partitionsSpec": { + "type": "hashed", + "targetPartitionSize": target_partition_size, + "numShards": num_shards, + }, + }, + "ioConfig": { + "inputSpec": {"paths": static_path, "type": "static"}, + "type": "hadoop", + }, + }, + } + + if self.job_properties: + ingest_query_dict["spec"]["tuningConfig"]["jobProperties"].update( + self.job_properties + ) + + if self.hadoop_dependency_coordinates: + ingest_query_dict[ + "hadoopDependencyCoordinates" + ] = self.hadoop_dependency_coordinates + + return ingest_query_dict diff --git a/reference/providers/apache/hdfs/CHANGELOG.rst b/reference/providers/apache/hdfs/CHANGELOG.rst new file mode 100644 index 0000000..c95472c --- /dev/null +++ b/reference/providers/apache/hdfs/CHANGELOG.rst @@ -0,0 +1,31 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + + +1.0.0 +..... + +Initial version of the provider. 
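Putting the transfer operator above together, a task definition might look like the following sketch; the Hive table, datasource, and metric are made-up values, and the connection ids fall back to the operator defaults:

from airflow.models import DAG
from airflow.providers.apache.druid.transfers.hive_to_druid import HiveToDruidOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_hive_to_druid",
    start_date=days_ago(1),
    schedule_interval="@daily",
) as dag:
    # Materialises the query into a temporary TSV-backed Hive table, then hands
    # the table's HDFS location to Druid as an index_hadoop ingestion job.
    load_segments = HiveToDruidOperator(
        task_id="hive_to_druid",
        sql="SELECT ts, country, clicks FROM events WHERE ds = '{{ ds }}'",
        druid_datasource="events",
        ts_dim="ts",
        metric_spec=[{"name": "clicks_sum", "type": "longSum", "fieldName": "clicks"}],
        intervals=["{{ ds }}/{{ tomorrow_ds }}"],
    )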
diff --git a/reference/providers/apache/hdfs/__init__.py b/reference/providers/apache/hdfs/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/hdfs/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/hdfs/hooks/__init__.py b/reference/providers/apache/hdfs/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/hdfs/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/hdfs/hooks/hdfs.py b/reference/providers/apache/hdfs/hooks/hdfs.py new file mode 100644 index 0000000..9c1fe15 --- /dev/null +++ b/reference/providers/apache/hdfs/hooks/hdfs.py @@ -0,0 +1,123 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Hook for HDFS operations""" +from typing import Any, Optional + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + +try: + from snakebite.client import ( # pylint: disable=syntax-error + AutoConfigClient, + Client, + HAClient, + Namenode, + ) + + snakebite_loaded = True +except ImportError: + snakebite_loaded = False + + +class HDFSHookException(AirflowException): + """Exception specific for HDFS""" + + +class HDFSHook(BaseHook): + """ + Interact with HDFS. This class is a wrapper around the snakebite library. + + :param hdfs_conn_id: Connection id to fetch connection info + :type hdfs_conn_id: str + :param proxy_user: effective user for HDFS operations + :type proxy_user: str + :param autoconfig: use snakebite's automatically configured client + :type autoconfig: bool + """ + + conn_name_attr = "hdfs_conn_id" + default_conn_name = "hdfs_default" + conn_type = "hdfs" + hook_name = "HDFS" + + def __init__( + self, + hdfs_conn_id: str = "hdfs_default", + proxy_user: Optional[str] = None, + autoconfig: bool = False, + ): + super().__init__() + if not snakebite_loaded: + raise ImportError( + "This HDFSHook implementation requires snakebite, but " + "snakebite is not compatible with Python 3 " + "(as of August 2015). Please use Python 2 if you require " + "this hook -- or help by submitting a PR!" + ) + self.hdfs_conn_id = hdfs_conn_id + self.proxy_user = proxy_user + self.autoconfig = autoconfig + + def get_conn(self) -> Any: + """Returns a snakebite HDFSClient object.""" + # When using HAClient, proxy_user must be the same, so is ok to always + # take the first. + effective_user = self.proxy_user + autoconfig = self.autoconfig + use_sasl = conf.get("core", "security") == "kerberos" + + try: + connections = self.get_connections(self.hdfs_conn_id) + + if not effective_user: + effective_user = connections[0].login + if not autoconfig: + autoconfig = connections[0].extra_dejson.get("autoconfig", False) + hdfs_namenode_principal = connections[0].extra_dejson.get( + "hdfs_namenode_principal" + ) + except AirflowException: + if not autoconfig: + raise + + if autoconfig: + # will read config info from $HADOOP_HOME conf files + client = AutoConfigClient(effective_user=effective_user, use_sasl=use_sasl) + elif len(connections) == 1: + client = Client( + connections[0].host, + connections[0].port, + effective_user=effective_user, + use_sasl=use_sasl, + hdfs_namenode_principal=hdfs_namenode_principal, + ) + elif len(connections) > 1: + name_node = [Namenode(conn.host, conn.port) for conn in connections] + client = HAClient( + name_node, + effective_user=effective_user, + use_sasl=use_sasl, + hdfs_namenode_principal=hdfs_namenode_principal, + ) + else: + raise HDFSHookException( + "conn_id doesn't exist in the repository and autoconfig is not specified" + ) + + return client diff --git a/reference/providers/apache/hdfs/hooks/webhdfs.py b/reference/providers/apache/hdfs/hooks/webhdfs.py new file mode 100644 index 0000000..0358b43 --- /dev/null +++ b/reference/providers/apache/hdfs/hooks/webhdfs.py @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Web HDFS""" +import logging +import socket +from typing import Any, Optional + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.models.connection import Connection +from hdfs import HdfsError, InsecureClient + +log = logging.getLogger(__name__) + +_kerberos_security_mode = conf.get("core", "security") == "kerberos" +if _kerberos_security_mode: + try: + from hdfs.ext.kerberos import ( # pylint: disable=ungrouped-imports + KerberosClient, + ) + except ImportError: + log.error("Could not load the Kerberos extension for the WebHDFSHook.") + raise + + +class AirflowWebHDFSHookException(AirflowException): + """Exception specific for WebHDFS hook""" + + +class WebHDFSHook(BaseHook): + """ + Interact with HDFS. This class is a wrapper around the hdfscli library. + + :param webhdfs_conn_id: The connection id for the webhdfs client to connect to. + :type webhdfs_conn_id: str + :param proxy_user: The user used to authenticate. + :type proxy_user: str + """ + + def __init__( + self, webhdfs_conn_id: str = "webhdfs_default", proxy_user: Optional[str] = None + ): + super().__init__() + self.webhdfs_conn_id = webhdfs_conn_id + self.proxy_user = proxy_user + + def get_conn(self) -> Any: + """ + Establishes a connection depending on the security mode set via config or environment variable. + :return: a hdfscli InsecureClient or KerberosClient object. + :rtype: hdfs.InsecureClient or hdfs.ext.kerberos.KerberosClient + """ + connection = self._find_valid_server() + if connection is None: + raise AirflowWebHDFSHookException("Failed to locate the valid server.") + return connection + + def _find_valid_server(self) -> Any: + connections = self.get_connections(self.webhdfs_conn_id) + for connection in connections: + host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.log.info( + "Trying to connect to %s:%s", connection.host, connection.port + ) + try: + conn_check = host_socket.connect_ex((connection.host, connection.port)) + if conn_check == 0: + self.log.info("Trying namenode %s", connection.host) + client = self._get_client(connection) + client.status("/") + self.log.info("Using namenode %s for hook", connection.host) + host_socket.close() + return client + else: + self.log.error( + "Could not connect to %s:%s", connection.host, connection.port + ) + host_socket.close() + except HdfsError as hdfs_error: + self.log.error( + "Read operation on namenode %s failed with error: %s", + connection.host, + hdfs_error, + ) + return None + + def _get_client(self, connection: Connection) -> Any: + connection_str = f"http://{connection.host}:{connection.port}" + + if _kerberos_security_mode: + client = KerberosClient(connection_str) + else: + proxy_user = self.proxy_user or connection.login + client = InsecureClient(connection_str, user=proxy_user) + + return client + + def check_for_path(self, hdfs_path: str) -> bool: + """ + Check for the existence of a path in HDFS by querying FileStatus. + + :param hdfs_path: The path to check. 
+ :type hdfs_path: str +:return: True if the path exists and False if not. +:rtype: bool +""" +conn = self.get_conn() + +status = conn.status(hdfs_path, strict=False) +return bool(status) + +def load_file( +self, +source: str, +destination: str, +overwrite: bool = True, +parallelism: int = 1, +**kwargs: Any, +) -> None: +r""" +Uploads a file to HDFS. + +:param source: Local path to file or folder. +If it's a folder, all the files inside of it will be uploaded. +.. note:: This implies that folders empty of files will not be created remotely. + +:type source: str +:param destination: Target HDFS path. +If it already exists and is a directory, files will be uploaded inside. +:type destination: str +:param overwrite: Overwrite any existing file or directory. +:type overwrite: bool +:param parallelism: Number of threads to use for parallelization. +A value of `0` (or negative) uses as many threads as there are files. +:type parallelism: int +:param kwargs: Keyword arguments forwarded to :meth:`hdfs.client.Client.upload`. +""" +conn = self.get_conn() + +conn.upload( +hdfs_path=destination, +local_path=source, +overwrite=overwrite, +n_threads=parallelism, +**kwargs, +) +self.log.debug("Uploaded file %s to %s", source, destination) diff --git a/reference/providers/apache/hdfs/provider.yaml b/reference/providers/apache/hdfs/provider.yaml new file mode 100644 index 0000000..afba361 --- /dev/null +++ b/reference/providers/apache/hdfs/provider.yaml @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-apache-hdfs +name: Apache HDFS +description: | + `Hadoop Distributed File System (HDFS) <https://hadoop.apache.org/docs/r1.2.1/hdfs_design.html>`__ + and `WebHDFS <https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-hdfs/WebHDFS.html>`__.
+ +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Hadoop Distributed File System (HDFS) + external-doc-url: https://hadoop.apache.org/docs/r1.2.1/hdfs_design.html + how-to-guide: + - /docs/apache-airflow-providers-apache-hdfs/operators.rst + logo: /integration-logos/apache/hadoop.png + tags: [apache] + - integration-name: WebHDFS + external-doc-url: https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-hdfs/WebHDFS.html + logo: /integration-logos/apache/hadoop.png + tags: [apache] + +sensors: + - integration-name: Hadoop Distributed File System (HDFS) + python-modules: + - airflow.providers.apache.hdfs.sensors.hdfs + - integration-name: WebHDFS + python-modules: + - airflow.providers.apache.hdfs.sensors.web_hdfs + +hooks: + - integration-name: Hadoop Distributed File System (HDFS) + python-modules: + - airflow.providers.apache.hdfs.hooks.hdfs + - integration-name: WebHDFS + python-modules: + - airflow.providers.apache.hdfs.hooks.webhdfs + +hook-class-names: + - airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook diff --git a/reference/providers/apache/hdfs/sensors/__init__.py b/reference/providers/apache/hdfs/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/hdfs/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/hdfs/sensors/hdfs.py b/reference/providers/apache/hdfs/sensors/hdfs.py new file mode 100644 index 0000000..9dc9254 --- /dev/null +++ b/reference/providers/apache/hdfs/sensors/hdfs.py @@ -0,0 +1,211 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import logging +import re +import sys +from typing import Any, Dict, List, Optional, Pattern, Type + +from airflow import settings +from airflow.providers.apache.hdfs.hooks.hdfs import HDFSHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + +log = logging.getLogger(__name__) + + +class HdfsSensor(BaseSensorOperator): + """ + Waits for a file or folder to land in HDFS + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:HdfsSensor` + """ + + template_fields = ("filepath",) + ui_color = settings.WEB_COLORS["LIGHTBLUE"] + + @apply_defaults + def __init__( + self, + *, + filepath: str, + hdfs_conn_id: str = "hdfs_default", + ignored_ext: Optional[List[str]] = None, + ignore_copying: bool = True, + file_size: Optional[int] = None, + hook: Type[HDFSHook] = HDFSHook, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if ignored_ext is None: + ignored_ext = ["_COPYING_"] + self.filepath = filepath + self.hdfs_conn_id = hdfs_conn_id + self.file_size = file_size + self.ignored_ext = ignored_ext + self.ignore_copying = ignore_copying + self.hook = hook + + @staticmethod + def filter_for_filesize( + result: List[Dict[Any, Any]], size: Optional[int] = None + ) -> List[Dict[Any, Any]]: + """ + Will test the filepath result and test if its size is at least self.filesize + + :param result: a list of dicts returned by Snakebite ls + :param size: the file size in MB a file should be at least to trigger True + :return: (bool) depending on the matching criteria + """ + if size: + log.debug( + "Filtering for file size >= %s in files: %s", + size, + map(lambda x: x["path"], result), + ) + size *= settings.MEGABYTE + result = [x for x in result if x["length"] >= size] + log.debug("HdfsSensor.poke: after size filter result is %s", result) + return result + + @staticmethod + def filter_for_ignored_ext( + result: List[Dict[Any, Any]], ignored_ext: List[str], ignore_copying: bool + ) -> List[Dict[Any, Any]]: + """ + Will filter if instructed to do so the result to remove matching criteria + + :param result: list of dicts returned by Snakebite ls + :type result: list[dict] + :param ignored_ext: list of ignored extensions + :type ignored_ext: list + :param ignore_copying: shall we ignore ? + :type ignore_copying: bool + :return: list of dicts which were not removed + :rtype: list[dict] + """ + if ignore_copying: + regex_builder = r"^.*\.(%s$)$" % "$|".join(ignored_ext) + ignored_extensions_regex = re.compile(regex_builder) + log.debug( + "Filtering result for ignored extensions: %s in files %s", + ignored_extensions_regex.pattern, + map(lambda x: x["path"], result), + ) + result = [ + x for x in result if not ignored_extensions_regex.match(x["path"]) + ] + log.debug("HdfsSensor.poke: after ext filter result is %s", result) + return result + + def poke(self, context: Dict[Any, Any]) -> bool: + """Get a snakebite client connection and check for file.""" + sb_client = self.hook(self.hdfs_conn_id).get_conn() + self.log.info("Poking for file %s", self.filepath) + try: + # IMOO it's not right here, as there is no raise of any kind. 
+ # if the filepath is let's say '/data/mydirectory', + # it's correct but if it is '/data/mydirectory/*', + # it's not correct as the directory exists and sb_client does not raise any error + # here is a quick fix + result = sb_client.ls([self.filepath], include_toplevel=False) + self.log.debug("HdfsSensor.poke: result is %s", result) + result = self.filter_for_ignored_ext( + result, self.ignored_ext, self.ignore_copying + ) + result = self.filter_for_filesize(result, self.file_size) + return bool(result) + except Exception: # pylint: disable=broad-except + e = sys.exc_info() + self.log.debug("Caught an exception !: %s", str(e)) + return False + + +class HdfsRegexSensor(HdfsSensor): + """ + Waits for matching files by matching on regex + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:HdfsRegexSensor` + """ + + def __init__(self, regex: Pattern[str], *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.regex = regex + + def poke(self, context: Dict[Any, Any]) -> bool: + """ + Poke matching files in a directory with self.regex + + :return: Bool depending on the search criteria + """ + sb_client = self.hook(self.hdfs_conn_id).get_conn() + self.log.info( + "Poking for %s to be a directory with files matching %s", + self.filepath, + self.regex.pattern, + ) + result = [ + f + for f in sb_client.ls([self.filepath], include_toplevel=False) + if f["file_type"] == "f" + and self.regex.match(f["path"].replace(f"{self.filepath}/", "")) + ] + result = self.filter_for_ignored_ext( + result, self.ignored_ext, self.ignore_copying + ) + result = self.filter_for_filesize(result, self.file_size) + return bool(result) + + +class HdfsFolderSensor(HdfsSensor): + """ + Waits for a non-empty directory + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:HdfsFolderSensor` + """ + + def __init__(self, be_empty: bool = False, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.be_empty = be_empty + + def poke(self, context: Dict[str, Any]) -> bool: + """ + Poke for a non empty directory + + :return: Bool depending on the search criteria + """ + sb_client = self.hook(self.hdfs_conn_id).get_conn() + result = sb_client.ls([self.filepath], include_toplevel=True) + result = self.filter_for_ignored_ext( + result, self.ignored_ext, self.ignore_copying + ) + result = self.filter_for_filesize(result, self.file_size) + if self.be_empty: + self.log.info("Poking for filepath %s to a empty directory", self.filepath) + return len(result) == 1 and result[0]["path"] == self.filepath + else: + self.log.info( + "Poking for filepath %s to a non empty directory", self.filepath + ) + result.pop(0) + return bool(result) and result[0]["file_type"] == "f" diff --git a/reference/providers/apache/hdfs/sensors/web_hdfs.py b/reference/providers/apache/hdfs/sensors/web_hdfs.py new file mode 100644 index 0000000..9a4449d --- /dev/null +++ b/reference/providers/apache/hdfs/sensors/web_hdfs.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict + +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class WebHdfsSensor(BaseSensorOperator): + """Waits for a file or folder to land in HDFS""" + + template_fields = ("filepath",) + + @apply_defaults + def __init__( + self, *, filepath: str, webhdfs_conn_id: str = "webhdfs_default", **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.filepath = filepath + self.webhdfs_conn_id = webhdfs_conn_id + + def poke(self, context: Dict[Any, Any]) -> bool: + from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook + + hook = WebHDFSHook(self.webhdfs_conn_id) + self.log.info("Poking for file %s", self.filepath) + return hook.check_for_path(hdfs_path=self.filepath) diff --git a/reference/providers/apache/hive/CHANGELOG.rst b/reference/providers/apache/hive/CHANGELOG.rst new file mode 100644 index 0000000..e6d4837 --- /dev/null +++ b/reference/providers/apache/hive/CHANGELOG.rst @@ -0,0 +1,44 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + + +1.0.1 +..... + +Updated documentation and readme files. + +Bug fixes +~~~~~~~~~ + +* ``Remove password if in LDAP or CUSTOM mode HiveServer2Hook (#11767)`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/apache/hive/__init__.py b/reference/providers/apache/hive/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/hive/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/hive/example_dags/__init__.py b/reference/providers/apache/hive/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/hive/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/hive/example_dags/example_twitter_README.md b/reference/providers/apache/hive/example_dags/example_twitter_README.md new file mode 100644 index 0000000..5298947 --- /dev/null +++ b/reference/providers/apache/hive/example_dags/example_twitter_README.md @@ -0,0 +1,57 @@ + + +# Example Twitter DAG + +**_Introduction:_** This example DAG depicts a typical ETL process and is a perfect use case automation scenario for Airflow. Please note that the main scripts associated with the tasks simply return None. The purpose of this DAG is to demonstrate how to write a functional DAG within Airflow. + +**Background:** Twitter is a social networking platform that enables users to send or broadcast short messages (140 characters). A user has a user ID, i.e. JohnDoe, which is also known as a Twitter Handle. A short message, or tweet, can either be directed at another user using the @ symbol (i.e. @JohnDoe) or can be broadcast with a hashtag # followed by the topic name. _As most of the data on twitter is public, and twitter provides a generous API to retrieve these data, Twitter is a so-called gold mine for text-mining-based data analytics._ This example DAG grew out of our real use case, where we used the SEARCH API from twitter to retrieve tweets from yesterday. The DAG is scheduled to run each day, and therefore works in an ETL fashion. + +**_Overview:_** First, we need tasks that will get the tweets of our interest and save them to disk. Then, we need subsequent tasks that will clean and analyze the tweets. Then we want to store these files in HDFS, and load them into a data warehousing platform like Hive or HBase. The main reason we have selected Hive here is that it gives us a familiar SQL-like interface and makes writing different queries a lot easier. Finally, the DAG needs to store a summarized result in a traditional database, i.e. MySQL or PostgreSQL, which is used by a reporting or business intelligence application. In other words, we basically want to achieve the following steps: + +1. Fetch Tweets +1. Clean Tweets +1. Analyze Tweets +1. Put Tweets to HDFS +1. Load data to Hive +1. Save Summary to MySQL + +**_Screenshot:_** + + +**_Example Structure:_** In this example DAG, we are collecting tweets for four user accounts, or twitter handles.
Each twitter handle has two channels, incoming tweets and outgoing tweets. Hence, in this example, by running the fetch_tweet task, we should have eight output files. For better management, each of the eight output files should be saved with yesterday's date (we are collecting tweets from yesterday), i.e. toTwitter_A_2016-03-21.csv. We are using three kinds of operators: PythonOperator, BashOperator, and HiveOperator. However, for this example only the Python scripts are stored externally. Hence this example DAG only has the following directory structure: + +The Python functions here are just placeholders. In case you are interested in actually making this DAG fully functional, first start by filling out the scripts as separate files and importing them into the DAG with absolute or relative imports. My approach was to store the retrieved data in memory using a Pandas dataframe first, and then use the built-in method to save the CSV file to disk. +The eight different CSV files are then put into eight different folders within HDFS. Each of the newly inserted files is then loaded into eight different Hive tables. Hive tables can be external or internal. In this case, we are inserting the data right into the table, and so we are making our tables internal. Each file is inserted into the respective Hive table named after the twitter channel, i.e. toTwitter_A or fromTwitter_A. It is also important to note that when we created the tables, we facilitated partitioning by date using the variable dt and declared comma as the row delimiter. The partitioning is very handy and ensures our query execution time remains constant even with a growing volume of data. +As these folders and Hive tables most probably don't exist in your system, you will get an error for these tasks within the DAG. If you rebuild a functional DAG from this example, make sure those folders and Hive tables exist. When you create the tables, keep table partitioning and the comma row delimiter in mind. Furthermore, you may also need to skip headers on each read and ensure that the user under which Airflow is running has the right access permissions. Below is a sample HQL snippet for creating such a table: + +``` +CREATE TABLE toTwitter_A(id BIGINT, id_str STRING, + created_at STRING, text STRING) + PARTITIONED BY (dt STRING) + ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + STORED AS TEXTFILE; + alter table toTwitter_A SET serdeproperties ('skip.header.line.count' = '1'); +``` + +When you review the code for the DAG, you will notice that these tasks are generated using a for loop. These two for loops could be combined into one loop. However, in most cases, you will be running different analyses on your incoming and outgoing tweets, and hence they are kept separate in this example. +The final step is running the broker script, brokerapi.py, which will run the queries in Hive and, in our case, store the summarized data in MySQL. To connect to Hive, the pyhs2 library is extremely useful and easy to use. To insert data into MySQL from Python, sqlalchemy is also a good one to use; a minimal, hypothetical sketch of this broker step is included at the end of this README. +I hope you find this tutorial useful. If you have questions, feel free to ask me on [Twitter](https://twitter.com/EkhtiarSyed).
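+
+**_Broker step sketch:_** Since brokerapi.py itself is not included in this repository, the following is only a rough, hypothetical sketch of what that step could look like, written with pyhive (the library wrapped by the HiveServer2Hook elsewhere in this repo) plus pandas and SQLAlchemy instead of pyhs2; the host names, credentials, and the tweet_summary table are placeholders, not part of the example DAG.
+
+```
+# Hypothetical broker step: summarize a Hive table and store the result in MySQL.
+import pandas as pd
+from pyhive import hive
+from sqlalchemy import create_engine
+
+
+def summarize_to_mysql() -> None:
+    # Query one of the partitioned Hive tables loaded by the DAG.
+    hive_conn = hive.connect(host="hive-server", port=10000, username="airflow")
+    cursor = hive_conn.cursor()
+    cursor.execute("SELECT dt, COUNT(*) AS tweet_count FROM toTwitter_A GROUP BY dt")
+    summary = pd.DataFrame(cursor.fetchall(), columns=["dt", "tweet_count"])
+    # Write the daily summary to the reporting database (placeholder connection string).
+    engine = create_engine("mysql+pymysql://user:password@mysql-host/reporting")
+    summary.to_sql("tweet_summary", engine, if_exists="append", index=False)
+```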

+-Ekhtiar Syed +Last Update: 8-April-2016 diff --git a/reference/providers/apache/hive/example_dags/example_twitter_dag.py b/reference/providers/apache/hive/example_dags/example_twitter_dag.py new file mode 100644 index 0000000..ba6b0d3 --- /dev/null +++ b/reference/providers/apache/hive/example_dags/example_twitter_dag.py @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# -------------------------------------------------------------------------------- +# Written By: Ekhtiar Syed +# Last Update: 8th April 2016 +# Caveat: This Dag will not run because of missing scripts. +# The purpose of this is to give you a sample of a real world example DAG! +# -------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------- +# Load The Dependencies +# -------------------------------------------------------------------------------- +""" +This is an example dag for managing twitter data. +""" +from datetime import date, timedelta + +from airflow import DAG +from airflow.operators.bash import BashOperator +from airflow.operators.python import PythonOperator +from airflow.providers.apache.hive.operators.hive import HiveOperator +from airflow.utils.dates import days_ago + +# -------------------------------------------------------------------------------- +# Create a few placeholder scripts. In practice these would be different python +# script files, which are imported in this section with absolute or relative imports +# -------------------------------------------------------------------------------- + + +def fetchtweets(): + """ + This is a placeholder for fetchtweets. + """ + + +def cleantweets(): + """ + This is a placeholder for cleantweets. + """ + + +def analyzetweets(): + """ + This is a placeholder for analyzetweets. + """ + + +def transfertodb(): + """ + This is a placeholder for transfertodb. 
+ """ + + +# -------------------------------------------------------------------------------- +# set default arguments +# -------------------------------------------------------------------------------- + +default_args = { + "owner": "Ekhtiar", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), + # 'queue': 'bash_queue', + # 'pool': 'backfill', + # 'priority_weight': 10, + # 'end_date': datetime(2016, 1, 1), +} + +with DAG( + dag_id="example_twitter_dag", + default_args=default_args, + schedule_interval="@daily", + start_date=days_ago(5), + tags=["example"], +) as dag: + + # -------------------------------------------------------------------------------- + # This task should call Twitter API and retrieve tweets from yesterday from and to + # for the four twitter users (Twitter_A,..,Twitter_D) There should be eight csv + # output files generated by this task and naming convention + # is direction(from or to)_twitterHandle_date.csv + # -------------------------------------------------------------------------------- + + fetch_tweets = PythonOperator(task_id="fetch_tweets", python_callable=fetchtweets) + + # -------------------------------------------------------------------------------- + # Clean the eight files. In this step you can get rid of or cherry pick columns + # and different parts of the text + # -------------------------------------------------------------------------------- + + clean_tweets = PythonOperator(task_id="clean_tweets", python_callable=cleantweets) + + clean_tweets << fetch_tweets + + # -------------------------------------------------------------------------------- + # In this section you can use a script to analyze the twitter data. Could simply + # be a sentiment analysis through algorithms like bag of words or something more + # complicated. You can also take a look at Web Services to do such tasks + # -------------------------------------------------------------------------------- + + analyze_tweets = PythonOperator( + task_id="analyze_tweets", python_callable=analyzetweets + ) + + analyze_tweets << clean_tweets + + # -------------------------------------------------------------------------------- + # Although this is the last task, we need to declare it before the next tasks as we + # will use set_downstream This task will extract summary from Hive data and store + # it to MySQL + # -------------------------------------------------------------------------------- + + hive_to_mysql = PythonOperator( + task_id="hive_to_mysql", python_callable=transfertodb + ) + + # -------------------------------------------------------------------------------- + # The following tasks are generated using for loop. The first task puts the eight + # csv files to HDFS. The second task loads these files from HDFS to respected Hive + # tables. These two for loops could be combined into one loop. However, in most cases, + # you will be running different analysis on your incoming and outgoing tweets, + # and hence they are kept separated in this example. 
+ # -------------------------------------------------------------------------------- + + from_channels = ["fromTwitter_A", "fromTwitter_B", "fromTwitter_C", "fromTwitter_D"] + to_channels = ["toTwitter_A", "toTwitter_B", "toTwitter_C", "toTwitter_D"] + yesterday = date.today() - timedelta(days=1) + dt = yesterday.strftime("%Y-%m-%d") + # define where you want to store the tweets csv file in your local directory + local_dir = "/tmp/" + # define the location where you want to store in HDFS + hdfs_dir = " /tmp/" + + for channel in to_channels: + + file_name = "to_" + channel + "_" + yesterday.strftime("%Y-%m-%d") + ".csv" + + load_to_hdfs = BashOperator( + task_id="put_" + channel + "_to_hdfs", + bash_command="HADOOP_USER_NAME=hdfs hadoop fs -put -f " + + local_dir + + file_name + + hdfs_dir + + channel + + "/", + ) + + load_to_hdfs << analyze_tweets + + load_to_hive = HiveOperator( + task_id="load_" + channel + "_to_hive", + hql="LOAD DATA INPATH '" + hdfs_dir + channel + "/" + file_name + "' " + "INTO TABLE " + channel + " " + "PARTITION(dt='" + dt + "')", + ) + load_to_hive << load_to_hdfs + load_to_hive >> hive_to_mysql + + for channel in from_channels: + file_name = "from_" + channel + "_" + yesterday.strftime("%Y-%m-%d") + ".csv" + load_to_hdfs = BashOperator( + task_id="put_" + channel + "_to_hdfs", + bash_command="HADOOP_USER_NAME=hdfs hadoop fs -put -f " + + local_dir + + file_name + + hdfs_dir + + channel + + "/", + ) + + load_to_hdfs << analyze_tweets + + load_to_hive = HiveOperator( + task_id="load_" + channel + "_to_hive", + hql="LOAD DATA INPATH '" + hdfs_dir + channel + "/" + file_name + "' " + "INTO TABLE " + channel + " " + "PARTITION(dt='" + dt + "')", + ) + + load_to_hive << load_to_hdfs + load_to_hive >> hive_to_mysql diff --git a/reference/providers/apache/hive/hooks/__init__.py b/reference/providers/apache/hive/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/hive/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/hive/hooks/hive.py b/reference/providers/apache/hive/hooks/hive.py new file mode 100644 index 0000000..d9f3828 --- /dev/null +++ b/reference/providers/apache/hive/hooks/hive.py @@ -0,0 +1,1117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import contextlib +import os +import re +import socket +import subprocess +import time +from collections import OrderedDict +from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import Any, Dict, List, Optional, Union + +import pandas +import unicodecsv as csv +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.hooks.dbapi import DbApiHook +from airflow.security import utils +from airflow.utils.helpers import as_flattened_list +from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING + +HIVE_QUEUE_PRIORITIES = ["VERY_HIGH", "HIGH", "NORMAL", "LOW", "VERY_LOW"] + + +def get_context_from_env_var() -> Dict[Any, Any]: + """ + Extract context from env variable, e.g. dag_id, task_id and execution_date, + so that they can be used inside BashOperator and PythonOperator. + + :return: The context of interest. + """ + return { + format_map["default"]: os.environ.get(format_map["env_var_format"], "") + for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values() + } + + +class HiveCliHook(BaseHook): + """Simple wrapper around the hive CLI. + + It also supports the ``beeline`` + a lighter CLI that runs JDBC and is replacing the heavier + traditional CLI. To enable ``beeline``, set the use_beeline param in the + extra field of your connection as in ``{ "use_beeline": true }`` + + Note that you can also set default hive CLI parameters using the + ``hive_cli_params`` to be used in your connection as in + ``{"hive_cli_params": "-hiveconf mapred.job.tracker=some.jobtracker:444"}`` + Parameters passed here can be overridden by run_cli's hive_conf param + + The extra connection parameter ``auth`` gets passed as in the ``jdbc`` + connection string as is. + + :param mapred_queue: queue used by the Hadoop Scheduler (Capacity or Fair) + :type mapred_queue: str + :param mapred_queue_priority: priority within the job queue. + Possible settings include: VERY_HIGH, HIGH, NORMAL, LOW, VERY_LOW + :type mapred_queue_priority: str + :param mapred_job_name: This name will appear in the jobtracker. + This can make monitoring easier. + :type mapred_job_name: str + """ + + conn_name_attr = "hive_cli_conn_id" + default_conn_name = "hive_cli_default" + conn_type = "hive_cli" + hook_name = "Hive Client Wrapper" + + def __init__( + self, + hive_cli_conn_id: str = default_conn_name, + run_as: Optional[str] = None, + mapred_queue: Optional[str] = None, + mapred_queue_priority: Optional[str] = None, + mapred_job_name: Optional[str] = None, + ) -> None: + super().__init__() + conn = self.get_connection(hive_cli_conn_id) + self.hive_cli_params: str = conn.extra_dejson.get("hive_cli_params", "") + self.use_beeline: bool = conn.extra_dejson.get("use_beeline", False) + self.auth = conn.extra_dejson.get("auth", "noSasl") + self.conn = conn + self.run_as = run_as + self.sub_process: Any = None + + if mapred_queue_priority: + mapred_queue_priority = mapred_queue_priority.upper() + if mapred_queue_priority not in HIVE_QUEUE_PRIORITIES: + raise AirflowException( + "Invalid Mapred Queue Priority. 
Valid values are: " + "{}".format(", ".join(HIVE_QUEUE_PRIORITIES)) + ) + + self.mapred_queue = mapred_queue or conf.get( + "hive", "default_hive_mapred_queue" + ) + self.mapred_queue_priority = mapred_queue_priority + self.mapred_job_name = mapred_job_name + + def _get_proxy_user(self) -> str: + """This function set the proper proxy_user value in case the user overwrite the default.""" + conn = self.conn + + proxy_user_value: str = conn.extra_dejson.get("proxy_user", "") + if proxy_user_value == "login" and conn.login: + return f"hive.server2.proxy.user={conn.login}" + if proxy_user_value == "owner" and self.run_as: + return f"hive.server2.proxy.user={self.run_as}" + if proxy_user_value != "": # There is a custom proxy user + return f"hive.server2.proxy.user={proxy_user_value}" + return proxy_user_value # The default proxy user (undefined) + + def _prepare_cli_cmd(self) -> List[Any]: + """This function creates the command list from available information""" + conn = self.conn + hive_bin = "hive" + cmd_extra = [] + + if self.use_beeline: + hive_bin = "beeline" + jdbc_url = f"jdbc:hive2://{conn.host}:{conn.port}/{conn.schema}" + if conf.get("core", "security") == "kerberos": + template = conn.extra_dejson.get("principal", "hive/_HOST@EXAMPLE.COM") + if "_HOST" in template: + template = utils.replace_hostname_pattern( + utils.get_components(template) + ) + + proxy_user = self._get_proxy_user() + + jdbc_url += f";principal={template};{proxy_user}" + elif self.auth: + jdbc_url += ";auth=" + self.auth + + jdbc_url = f'"{jdbc_url}"' + + cmd_extra += ["-u", jdbc_url] + if conn.login: + cmd_extra += ["-n", conn.login] + if conn.password: + cmd_extra += ["-p", conn.password] + + hive_params_list = self.hive_cli_params.split() + + return [hive_bin] + cmd_extra + hive_params_list + + @staticmethod + def _prepare_hiveconf(d: Dict[Any, Any]) -> List[Any]: + """ + This function prepares a list of hiveconf params + from a dictionary of key value pairs. + + :param d: + :type d: dict + + >>> hh = HiveCliHook() + >>> hive_conf = {"hive.exec.dynamic.partition": "true", + ... "hive.exec.dynamic.partition.mode": "nonstrict"} + >>> hh._prepare_hiveconf(hive_conf) + ["-hiveconf", "hive.exec.dynamic.partition=true",\ + "-hiveconf", "hive.exec.dynamic.partition.mode=nonstrict"] + """ + if not d: + return [] + return as_flattened_list( + zip(["-hiveconf"] * len(d), [f"{k}={v}" for k, v in d.items()]) + ) + + def run_cli( + self, + hql: Union[str, str], + schema: Optional[str] = None, + verbose: bool = True, + hive_conf: Optional[Dict[Any, Any]] = None, + ) -> Any: + """ + Run an hql statement using the hive cli. If hive_conf is specified + it should be a dict and the entries will be set as key/value pairs + in HiveConf + + + :param hive_conf: if specified these key value pairs will be passed + to hive as ``-hiveconf "key"="value"``. Note that they will be + passed after the ``hive_cli_params`` and thus will override + whatever values are specified in the database. + :type hive_conf: dict + + >>> hh = HiveCliHook() + >>> result = hh.run_cli("USE airflow;") + >>> ("OK" in result) + True + """ + conn = self.conn + schema = schema or conn.schema + if schema: + hql = f"USE {schema};\n{hql}" + + with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir: + with NamedTemporaryFile(dir=tmp_dir) as f: + hql += "\n" + f.write(hql.encode("UTF-8")) + f.flush() + hive_cmd = self._prepare_cli_cmd() + env_context = get_context_from_env_var() + # Only extend the hive_conf if it is defined. 
+ if hive_conf: + env_context.update(hive_conf) + hive_conf_params = self._prepare_hiveconf(env_context) + if self.mapred_queue: + hive_conf_params.extend( + [ + "-hiveconf", + f"mapreduce.job.queuename={self.mapred_queue}", + "-hiveconf", + f"mapred.job.queue.name={self.mapred_queue}", + "-hiveconf", + f"tez.queue.name={self.mapred_queue}", + ] + ) + + if self.mapred_queue_priority: + hive_conf_params.extend( + [ + "-hiveconf", + f"mapreduce.job.priority={self.mapred_queue_priority}", + ] + ) + + if self.mapred_job_name: + hive_conf_params.extend( + ["-hiveconf", f"mapred.job.name={self.mapred_job_name}"] + ) + + hive_cmd.extend(hive_conf_params) + hive_cmd.extend(["-f", f.name]) + + if verbose: + self.log.info("%s", " ".join(hive_cmd)) + sub_process: Any = subprocess.Popen( + hive_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd=tmp_dir, + close_fds=True, + ) + self.sub_process = sub_process + stdout = "" + while True: + line = sub_process.stdout.readline() + if not line: + break + stdout += line.decode("UTF-8") + if verbose: + self.log.info(line.decode("UTF-8").strip()) + sub_process.wait() + + if sub_process.returncode: + raise AirflowException(stdout) + + return stdout + + def test_hql(self, hql: Union[str, str]) -> None: + """Test an hql statement using the hive cli and EXPLAIN""" + create, insert, other = [], [], [] + for query in hql.split(";"): # naive + query_original = query + query = query.lower().strip() + + if query.startswith("create table"): + create.append(query_original) + elif query.startswith(("set ", "add jar ", "create temporary function")): + other.append(query_original) + elif query.startswith("insert"): + insert.append(query_original) + other_ = ";".join(other) + for query_set in [create, insert]: + for query in query_set: + + query_preview = " ".join(query.split())[:50] + self.log.info("Testing HQL [%s (...)]", query_preview) + if query_set == insert: + query = other_ + "; explain " + query + else: + query = "explain " + query + try: + self.run_cli(query, verbose=False) + except AirflowException as e: + message = e.args[0].split("\n")[-2] + self.log.info(message) + error_loc = re.search(r"(\d+):(\d+)", message) + if error_loc and error_loc.group(1).isdigit(): + lst = int(error_loc.group(1)) + begin = max(lst - 2, 0) + end = min(lst + 3, len(query.split("\n"))) + context = "\n".join(query.split("\n")[begin:end]) + self.log.info("Context :\n %s", context) + else: + self.log.info("SUCCESS") + + def load_df( + self, + df: pandas.DataFrame, + table: str, + field_dict: Optional[Dict[Any, Any]] = None, + delimiter: str = ",", + encoding: str = "utf8", + pandas_kwargs: Any = None, + **kwargs: Any, + ) -> None: + """ + Loads a pandas DataFrame into hive. + + Hive data types will be inferred if not passed but column names will + not be sanitized. + + :param df: DataFrame to load into a Hive table + :type df: pandas.DataFrame + :param table: target Hive table, use dot notation to target a + specific database + :type table: str + :param field_dict: mapping from column name to hive data type. + Note that it must be OrderedDict so as to keep columns' order. 
+ :type field_dict: collections.OrderedDict + :param delimiter: field delimiter in the file + :type delimiter: str + :param encoding: str encoding to use when writing DataFrame to file + :type encoding: str + :param pandas_kwargs: passed to DataFrame.to_csv + :type pandas_kwargs: dict + :param kwargs: passed to self.load_file + """ + + def _infer_field_types_from_df(df: pandas.DataFrame) -> Dict[Any, Any]: + dtype_kind_hive_type = { + "b": "BOOLEAN", # boolean + "i": "BIGINT", # signed integer + "u": "BIGINT", # unsigned integer + "f": "DOUBLE", # floating-point + "c": "STRING", # complex floating-point + "M": "TIMESTAMP", # datetime + "O": "STRING", # object + "S": "STRING", # (byte-)string + "U": "STRING", # Unicode + "V": "STRING", # void + } + + order_type = OrderedDict() + for col, dtype in df.dtypes.iteritems(): + order_type[col] = dtype_kind_hive_type[dtype.kind] + return order_type + + if pandas_kwargs is None: + pandas_kwargs = {} + + with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir: + with NamedTemporaryFile(dir=tmp_dir, mode="w") as f: + if field_dict is None: + field_dict = _infer_field_types_from_df(df) + + df.to_csv( + path_or_buf=f, + sep=delimiter, + header=False, + index=False, + encoding=encoding, + date_format="%Y-%m-%d %H:%M:%S", + **pandas_kwargs, + ) + f.flush() + + return self.load_file( + filepath=f.name, + table=table, + delimiter=delimiter, + field_dict=field_dict, + **kwargs, + ) + + def load_file( + self, + filepath: str, + table: str, + delimiter: str = ",", + field_dict: Optional[Dict[Any, Any]] = None, + create: bool = True, + overwrite: bool = True, + partition: Optional[Dict[str, Any]] = None, + recreate: bool = False, + tblproperties: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Loads a local file into Hive + + Note that the table generated in Hive uses ``STORED AS textfile`` + which isn't the most efficient serialization format. If a + large amount of data is loaded and/or if the tables gets + queried considerably, you may want to use this operator only to + stage the data into a temporary table before loading it into its + final destination using a ``HiveOperator``. + + :param filepath: local filepath of the file to load + :type filepath: str + :param table: target Hive table, use dot notation to target a + specific database + :type table: str + :param delimiter: field delimiter in the file + :type delimiter: str + :param field_dict: A dictionary of the fields name in the file + as keys and their Hive types as values. + Note that it must be OrderedDict so as to keep columns' order. 
+ :type field_dict: collections.OrderedDict + :param create: whether to create the table if it doesn't exist + :type create: bool + :param overwrite: whether to overwrite the data in table or partition + :type overwrite: bool + :param partition: target partition as a dict of partition columns + and values + :type partition: dict + :param recreate: whether to drop and recreate the table at every + execution + :type recreate: bool + :param tblproperties: TBLPROPERTIES of the hive table being created + :type tblproperties: dict + """ + hql = "" + if recreate: + hql += f"DROP TABLE IF EXISTS {table};\n" + if create or recreate: + if field_dict is None: + raise ValueError("Must provide a field dict when creating a table") + fields = ",\n ".join( + ["`{k}` {v}".format(k=k.strip("`"), v=v) for k, v in field_dict.items()] + ) + hql += f"CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n" + if partition: + pfields = ",\n ".join([p + " STRING" for p in partition]) + hql += f"PARTITIONED BY ({pfields})\n" + hql += "ROW FORMAT DELIMITED\n" + hql += f"FIELDS TERMINATED BY '{delimiter}'\n" + hql += "STORED AS textfile\n" + if tblproperties is not None: + tprops = ", ".join([f"'{k}'='{v}'" for k, v in tblproperties.items()]) + hql += f"TBLPROPERTIES({tprops})\n" + hql += ";" + self.log.info(hql) + self.run_cli(hql) + hql = f"LOAD DATA LOCAL INPATH '{filepath}' " + if overwrite: + hql += "OVERWRITE " + hql += f"INTO TABLE {table} " + if partition: + pvals = ", ".join([f"{k}='{v}'" for k, v in partition.items()]) + hql += f"PARTITION ({pvals})" + + # As a workaround for HIVE-10541, add a newline character + # at the end of hql (AIRFLOW-2412). + hql += ";\n" + + self.log.info(hql) + self.run_cli(hql) + + def kill(self) -> None: + """Kill Hive cli command""" + if hasattr(self, "sub_process"): + if self.sub_process.poll() is None: + print("Killing the Hive job") + self.sub_process.terminate() + time.sleep(60) + self.sub_process.kill() + + +class HiveMetastoreHook(BaseHook): + """Wrapper to interact with the Hive Metastore""" + + # java short max val + MAX_PART_COUNT = 32767 + + conn_name_attr = "metastore_conn_id" + default_conn_name = "metastore_default" + conn_type = "hive_metastore" + hook_name = "Hive Metastore Thrift" + + def __init__(self, metastore_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = metastore_conn_id + self.metastore = self.get_metastore_client() + + def __getstate__(self) -> Dict[str, Any]: + # This is for pickling to work despite the thrift hive client not + # being pickable + state = dict(self.__dict__) + del state["metastore"] + return state + + def __setstate__(self, d: Dict[str, Any]) -> None: + self.__dict__.update(d) + self.__dict__["metastore"] = self.get_metastore_client() + + def get_metastore_client(self) -> Any: + """Returns a Hive thrift client.""" + import hmsclient + from thrift.protocol import TBinaryProtocol + from thrift.transport import TSocket, TTransport + + conn = self._find_valid_server() + + if not conn: + raise AirflowException("Failed to locate the valid server.") + + auth_mechanism = conn.extra_dejson.get("authMechanism", "NOSASL") + + if conf.get("core", "security") == "kerberos": + auth_mechanism = conn.extra_dejson.get("authMechanism", "GSSAPI") + kerberos_service_name = conn.extra_dejson.get( + "kerberos_service_name", "hive" + ) + + conn_socket = TSocket.TSocket(conn.host, conn.port) + + if conf.get("core", "security") == "kerberos" and auth_mechanism == "GSSAPI": + try: + import saslwrapper as sasl + except ImportError: + 
import sasl + + def sasl_factory() -> sasl.Client: + sasl_client = sasl.Client() + sasl_client.setAttr("host", conn.host) + sasl_client.setAttr("service", kerberos_service_name) + sasl_client.init() + return sasl_client + + from thrift_sasl import TSaslClientTransport + + transport = TSaslClientTransport(sasl_factory, "GSSAPI", conn_socket) + else: + transport = TTransport.TBufferedTransport(conn_socket) + + protocol = TBinaryProtocol.TBinaryProtocol(transport) + + return hmsclient.HMSClient(iprot=protocol) + + def _find_valid_server(self) -> Any: + conns = self.get_connections(self.conn_id) + for conn in conns: + host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.log.info("Trying to connect to %s:%s", conn.host, conn.port) + if host_socket.connect_ex((conn.host, conn.port)) == 0: + self.log.info("Connected to %s:%s", conn.host, conn.port) + host_socket.close() + return conn + else: + self.log.error("Could not connect to %s:%s", conn.host, conn.port) + return None + + def get_conn(self) -> Any: + return self.metastore + + def check_for_partition(self, schema: str, table: str, partition: str) -> bool: + """ + Checks whether a partition exists + + :param schema: Name of hive schema (database) @table belongs to + :type schema: str + :param table: Name of hive table @partition belongs to + :type schema: str + :partition: Expression that matches the partitions to check for + (eg `a = 'b' AND c = 'd'`) + :type schema: str + :rtype: bool + + >>> hh = HiveMetastoreHook() + >>> t = 'static_babynames_partitioned' + >>> hh.check_for_partition('airflow', t, "ds='2015-01-01'") + True + """ + with self.metastore as client: + partitions = client.get_partitions_by_filter(schema, table, partition, 1) + + return bool(partitions) + + def check_for_named_partition( + self, schema: str, table: str, partition_name: str + ) -> Any: + """ + Checks whether a partition with a given name exists + + :param schema: Name of hive schema (database) @table belongs to + :type schema: str + :param table: Name of hive table @partition belongs to + :type table: str + :partition: Name of the partitions to check for (eg `a=b/c=d`) + :type table: str + :rtype: bool + + >>> hh = HiveMetastoreHook() + >>> t = 'static_babynames_partitioned' + >>> hh.check_for_named_partition('airflow', t, "ds=2015-01-01") + True + >>> hh.check_for_named_partition('airflow', t, "ds=xxx") + False + """ + with self.metastore as client: + return client.check_for_named_partition(schema, table, partition_name) + + def get_table(self, table_name: str, db: str = "default") -> Any: + """Get a metastore table object + + >>> hh = HiveMetastoreHook() + >>> t = hh.get_table(db='airflow', table_name='static_babynames') + >>> t.tableName + 'static_babynames' + >>> [col.name for col in t.sd.cols] + ['state', 'year', 'name', 'gender', 'num'] + """ + if db == "default" and "." 
in table_name: + db, table_name = table_name.split(".")[:2] + with self.metastore as client: + return client.get_table(dbname=db, tbl_name=table_name) + + def get_tables(self, db: str, pattern: str = "*") -> Any: + """Get a metastore table object""" + with self.metastore as client: + tables = client.get_tables(db_name=db, pattern=pattern) + return client.get_table_objects_by_name(db, tables) + + def get_databases(self, pattern: str = "*") -> Any: + """Get a metastore table object""" + with self.metastore as client: + return client.get_databases(pattern) + + def get_partitions( + self, schema: str, table_name: str, partition_filter: Optional[str] = None + ) -> List[Any]: + """ + Returns a list of all partitions in a table. Works only + for tables with less than 32767 (java short max val). + For subpartitioned table, the number might easily exceed this. + + >>> hh = HiveMetastoreHook() + >>> t = 'static_babynames_partitioned' + >>> parts = hh.get_partitions(schema='airflow', table_name=t) + >>> len(parts) + 1 + >>> parts + [{'ds': '2015-01-01'}] + """ + with self.metastore as client: + table = client.get_table(dbname=schema, tbl_name=table_name) + if len(table.partitionKeys) == 0: + raise AirflowException("The table isn't partitioned") + else: + if partition_filter: + parts = client.get_partitions_by_filter( + db_name=schema, + tbl_name=table_name, + filter=partition_filter, + max_parts=HiveMetastoreHook.MAX_PART_COUNT, + ) + else: + parts = client.get_partitions( + db_name=schema, + tbl_name=table_name, + max_parts=HiveMetastoreHook.MAX_PART_COUNT, + ) + + pnames = [p.name for p in table.partitionKeys] + return [dict(zip(pnames, p.values)) for p in parts] + + @staticmethod + def _get_max_partition_from_part_specs( + part_specs: List[Any], + partition_key: Optional[str], + filter_map: Optional[Dict[str, Any]], + ) -> Any: + """ + Helper method to get max partition of partitions with partition_key + from part specs. key:value pair in filter_map will be used to + filter out partitions. + + :param part_specs: list of partition specs. + :type part_specs: list + :param partition_key: partition key name. + :type partition_key: str + :param filter_map: partition_key:partition_value map used for partition filtering, + e.g. {'key1': 'value1', 'key2': 'value2'}. + Only partitions matching all partition_key:partition_value + pairs will be considered as candidates of max partition. + :type filter_map: map + :return: Max partition or None if part_specs is empty. + :rtype: basestring + """ + if not part_specs: + return None + + # Assuming all specs have the same keys. + if partition_key not in part_specs[0].keys(): + raise AirflowException( + f"Provided partition_key {partition_key} is not in part_specs." + ) + is_subset = None + if filter_map: + is_subset = set(filter_map.keys()).issubset(set(part_specs[0].keys())) + if filter_map and not is_subset: + raise AirflowException( + "Keys in provided filter_map {} " + "are not subset of part_spec keys: {}".format( + ", ".join(filter_map.keys()), ", ".join(part_specs[0].keys()) + ) + ) + + candidates = [ + p_dict[partition_key] + for p_dict in part_specs + if filter_map is None + or all(item in p_dict.items() for item in filter_map.items()) + ] + + if not candidates: + return None + else: + return max(candidates) + + def max_partition( + self, + schema: str, + table_name: str, + field: Optional[str] = None, + filter_map: Optional[Dict[Any, Any]] = None, + ) -> Any: + """ + Returns the maximum value for all partitions with given field in a table. 
+ If only one partition key exist in the table, the key will be used as field. + filter_map should be a partition_key:partition_value map and will be used to + filter out partitions. + + :param schema: schema name. + :type schema: str + :param table_name: table name. + :type table_name: str + :param field: partition key to get max partition from. + :type field: str + :param filter_map: partition_key:partition_value map used for partition filtering. + :type filter_map: map + + >>> hh = HiveMetastoreHook() + >>> filter_map = {'ds': '2015-01-01', 'ds': '2014-01-01'} + >>> t = 'static_babynames_partitioned' + >>> hh.max_partition(schema='airflow',\ + ... table_name=t, field='ds', filter_map=filter_map) + '2015-01-01' + """ + with self.metastore as client: + table = client.get_table(dbname=schema, tbl_name=table_name) + key_name_set = {key.name for key in table.partitionKeys} + if len(table.partitionKeys) == 1: + field = table.partitionKeys[0].name + elif not field: + raise AirflowException( + "Please specify the field you want the max value for." + ) + elif field not in key_name_set: + raise AirflowException("Provided field is not a partition key.") + + if filter_map and not set(filter_map.keys()).issubset(key_name_set): + raise AirflowException( + "Provided filter_map contains keys that are not partition key." + ) + + part_names = client.get_partition_names( + schema, table_name, max_parts=HiveMetastoreHook.MAX_PART_COUNT + ) + part_specs = [ + client.partition_name_to_spec(part_name) for part_name in part_names + ] + + return HiveMetastoreHook._get_max_partition_from_part_specs( + part_specs, field, filter_map + ) + + def table_exists(self, table_name: str, db: str = "default") -> bool: + """ + Check if table exists + + >>> hh = HiveMetastoreHook() + >>> hh.table_exists(db='airflow', table_name='static_babynames') + True + >>> hh.table_exists(db='airflow', table_name='does_not_exist') + False + """ + try: + self.get_table(table_name, db) + return True + except Exception: # pylint: disable=broad-except + return False + + def drop_partitions(self, table_name, part_vals, delete_data=False, db="default"): + """ + Drop partitions from the given table matching the part_vals input + + :param table_name: table name. + :type table_name: str + :param part_vals: list of partition specs. + :type part_vals: list + :param delete_data: Setting to control if underlying data have to deleted + in addition to dropping partitions. 
+ :type delete_data: bool + :param db: Name of hive schema (database) @table belongs to + :type db: str + + >>> hh = HiveMetastoreHook() + >>> hh.drop_partitions(db='airflow', table_name='static_babynames', + part_vals="['2020-05-01']") + True + """ + if self.table_exists(table_name, db): + with self.metastore as client: + self.log.info( + "Dropping partition of table %s.%s matching the spec: %s", + db, + table_name, + part_vals, + ) + return client.drop_partition(db, table_name, part_vals, delete_data) + else: + self.log.info("Table %s.%s does not exist!", db, table_name) + return False + + +class HiveServer2Hook(DbApiHook): + """ + Wrapper around the pyhive library + + Notes: + * the default authMechanism is PLAIN, to override it you + can specify it in the ``extra`` of your connection in the UI + * the default for run_set_variable_statements is true, if you + are using impala you may need to set it to false in the + ``extra`` of your connection in the UI + """ + + conn_name_attr = "hiveserver2_conn_id" + default_conn_name = "hiveserver2_default" + conn_type = "hiveserver2" + hook_name = "Hive Server 2 Thrift" + supports_autocommit = False + + def get_conn(self, schema: Optional[str] = None) -> Any: + """Returns a Hive connection object.""" + username: Optional[str] = None + password: Optional[str] = None + # pylint: disable=no-member + db = self.get_connection(self.hiveserver2_conn_id) # type: ignore + + auth_mechanism = db.extra_dejson.get("authMechanism", "NONE") + if auth_mechanism == "NONE" and db.login is None: + # we need to give a username + username = "airflow" + kerberos_service_name = None + if conf.get("core", "security") == "kerberos": + auth_mechanism = db.extra_dejson.get("authMechanism", "KERBEROS") + kerberos_service_name = db.extra_dejson.get("kerberos_service_name", "hive") + + # pyhive uses GSSAPI instead of KERBEROS as a auth_mechanism identifier + if auth_mechanism == "GSSAPI": + self.log.warning( + "Detected deprecated 'GSSAPI' for authMechanism for %s. Please use 'KERBEROS' instead", + self.hiveserver2_conn_id, # type: ignore + ) + auth_mechanism = "KERBEROS" + + # Password should be set if and only if in LDAP or CUSTOM mode + if auth_mechanism in ("LDAP", "CUSTOM"): + password = db.password + + from pyhive.hive import connect + + return connect( + host=db.host, + port=db.port, + auth=auth_mechanism, + kerberos_service_name=kerberos_service_name, + username=db.login or username, + password=password, + database=schema or db.schema or "default", + ) + + # pylint: enable=no-member + + def _get_results( + self, + hql: Union[str, str, List[str]], + schema: str = "default", + fetch_size: Optional[int] = None, + hive_conf: Optional[Dict[Any, Any]] = None, + ) -> Any: + from pyhive.exc import ProgrammingError + + if isinstance(hql, str): + hql = [hql] + previous_description = None + with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing( + conn.cursor() + ) as cur: + + cur.arraysize = fetch_size or 1000 + + # not all query services (e.g. 
impala AIRFLOW-4434) support the set command + # pylint: disable=no-member + db = self.get_connection(self.hiveserver2_conn_id) # type: ignore + # pylint: enable=no-member + if db.extra_dejson.get("run_set_variable_statements", True): + env_context = get_context_from_env_var() + if hive_conf: + env_context.update(hive_conf) + for k, v in env_context.items(): + cur.execute(f"set {k}={v}") + + for statement in hql: + cur.execute(statement) + # we only get results of statements that returns + lowered_statement = statement.lower().strip() + if ( + lowered_statement.startswith("select") + or lowered_statement.startswith("with") + or lowered_statement.startswith("show") + or ( + lowered_statement.startswith("set") + and "=" not in lowered_statement + ) + ): + description = cur.description + if previous_description and previous_description != description: + message = """The statements are producing different descriptions: + Current: {} + Previous: {}""".format( + repr(description), repr(previous_description) + ) + raise ValueError(message) + elif not previous_description: + previous_description = description + yield description + try: + # DB API 2 raises when no results are returned + # we're silencing here as some statements in the list + # may be `SET` or DDL + yield from cur + except ProgrammingError: + self.log.debug("get_results returned no records") + + def get_results( + self, + hql: Union[str, str], + schema: str = "default", + fetch_size: Optional[int] = None, + hive_conf: Optional[Dict[Any, Any]] = None, + ) -> Dict[str, Any]: + """ + Get results of the provided hql in target schema. + + :param hql: hql to be executed. + :type hql: str or list + :param schema: target schema, default to 'default'. + :type schema: str + :param fetch_size: max size of result to fetch. + :type fetch_size: int + :param hive_conf: hive_conf to execute alone with the hql. + :type hive_conf: dict + :return: results of hql execution, dict with data (list of results) and header + :rtype: dict + """ + results_iter = self._get_results( + hql, schema, fetch_size=fetch_size, hive_conf=hive_conf + ) + header = next(results_iter) + results = {"data": list(results_iter), "header": header} + return results + + def to_csv( + self, + hql: Union[str, str], + csv_filepath: str, + schema: str = "default", + delimiter: str = ",", + lineterminator: str = "\r\n", + output_header: bool = True, + fetch_size: int = 1000, + hive_conf: Optional[Dict[Any, Any]] = None, + ) -> None: + """ + Execute hql in target schema and write results to a csv file. + + :param hql: hql to be executed. + :type hql: str or list + :param csv_filepath: filepath of csv to write results into. + :type csv_filepath: str + :param schema: target schema, default to 'default'. + :type schema: str + :param delimiter: delimiter of the csv file, default to ','. + :type delimiter: str + :param lineterminator: lineterminator of the csv file. + :type lineterminator: str + :param output_header: header of the csv file, default to True. + :type output_header: bool + :param fetch_size: number of result rows to write into the csv file, default to 1000. + :type fetch_size: int + :param hive_conf: hive_conf to execute alone with the hql. 
+ :type hive_conf: dict + + """ + results_iter = self._get_results( + hql, schema, fetch_size=fetch_size, hive_conf=hive_conf + ) + header = next(results_iter) + message = None + + i = 0 + with open(csv_filepath, "wb") as file: + writer = csv.writer( + file, + delimiter=delimiter, + lineterminator=lineterminator, + encoding="utf-8", + ) + try: + if output_header: + self.log.debug("Cursor description is %s", header) + writer.writerow([c[0] for c in header]) + + for i, row in enumerate(results_iter, 1): + writer.writerow(row) + if i % fetch_size == 0: + self.log.info("Written %s rows so far.", i) + except ValueError as exception: + message = str(exception) + + if message: + # need to clean up the file first + os.remove(csv_filepath) + raise ValueError(message) + + self.log.info("Done. Loaded a total of %s rows.", i) + + def get_records( + self, + hql: Union[str, str], + schema: str = "default", + hive_conf: Optional[Dict[Any, Any]] = None, + ) -> Any: + """ + Get a set of records from a Hive query. + + :param hql: hql to be executed. + :type hql: str or list + :param schema: target schema, default to 'default'. + :type schema: str + :param hive_conf: hive_conf to execute alone with the hql. + :type hive_conf: dict + :return: result of hive execution + :rtype: list + + >>> hh = HiveServer2Hook() + >>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100" + >>> len(hh.get_records(sql)) + 100 + """ + return self.get_results(hql, schema=schema, hive_conf=hive_conf)["data"] + + def get_pandas_df( # type: ignore + self, + hql: Union[str, str], + schema: str = "default", + hive_conf: Optional[Dict[Any, Any]] = None, + **kwargs, + ) -> pandas.DataFrame: + """ + Get a pandas dataframe from a Hive query + + :param hql: hql to be executed. + :type hql: str or list + :param schema: target schema, default to 'default'. + :type schema: str + :param hive_conf: hive_conf to execute alone with the hql. + :type hive_conf: dict + :param kwargs: (optional) passed into pandas.DataFrame constructor + :type kwargs: dict + :return: result of hive execution + :rtype: DataFrame + + >>> hh = HiveServer2Hook() + >>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100" + >>> df = hh.get_pandas_df(sql) + >>> len(df.index) + 100 + + :return: pandas.DateFrame + """ + res = self.get_results(hql, schema=schema, hive_conf=hive_conf) + df = pandas.DataFrame(res["data"], **kwargs) + df.columns = [c[0] for c in res["header"]] + return df diff --git a/reference/providers/apache/hive/operators/__init__.py b/reference/providers/apache/hive/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/hive/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
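# Minimal usage sketch for the hooks defined above in
# reference/providers/apache/hive/hooks/hive.py. This is only an
# illustration of the API surface shown in that file; it assumes the
# provider package is importable as airflow.providers.apache.hive, that
# the connections "metastore_default" and "hiveserver2_default" point at
# reachable services, and that the airflow.static_babynames /
# static_babynames_partitioned fixtures referenced in the docstrings exist.
from airflow.providers.apache.hive.hooks.hive import (
    HiveMetastoreHook,
    HiveServer2Hook,
)

# Metastore side: check a partition and look up the latest partition value.
metastore = HiveMetastoreHook(metastore_conn_id="metastore_default")
if metastore.check_for_partition(
    "airflow", "static_babynames_partitioned", "ds='2015-01-01'"
):
    latest_ds = metastore.max_partition(
        schema="airflow", table_name="static_babynames_partitioned", field="ds"
    )
    print(f"Latest ds partition: {latest_ds}")

# HiveServer2 side: fetch rows, build a pandas DataFrame, or dump to CSV.
hs2 = HiveServer2Hook(hiveserver2_conn_id="hiveserver2_default")
sql = "SELECT * FROM airflow.static_babynames LIMIT 10"
rows = hs2.get_records(sql, schema="airflow")
df = hs2.get_pandas_df(sql, schema="airflow")
hs2.to_csv(sql, csv_filepath="/tmp/babynames.csv", schema="airflow", fetch_size=1000)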
diff --git a/reference/providers/apache/hive/operators/hive.py b/reference/providers/apache/hive/operators/hive.py new file mode 100644 index 0000000..dfb48fd --- /dev/null +++ b/reference/providers/apache/hive/operators/hive.py @@ -0,0 +1,180 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import re +from typing import Any, Dict, Optional + +from airflow.configuration import conf +from airflow.models import BaseOperator +from airflow.providers.apache.hive.hooks.hive import HiveCliHook +from airflow.utils import operator_helpers +from airflow.utils.decorators import apply_defaults +from airflow.utils.operator_helpers import context_to_airflow_vars + + +class HiveOperator(BaseOperator): + """ + Executes hql code or hive script in a specific Hive database. + + :param hql: the hql to be executed. Note that you may also use + a relative path from the dag file of a (template) hive + script. (templated) + :type hql: str + :param hive_cli_conn_id: reference to the Hive database. (templated) + :type hive_cli_conn_id: str + :param hiveconfs: if defined, these key value pairs will be passed + to hive as ``-hiveconf "key"="value"`` + :type hiveconfs: dict + :param hiveconf_jinja_translate: when True, hiveconf-type templating + ${var} gets translated into jinja-type templating {{ var }} and + ${hiveconf:var} gets translated into jinja-type templating {{ var }}. + Note that you may want to use this along with the + ``DAG(user_defined_macros=myargs)`` parameter. View the DAG + object documentation for more details. + :type hiveconf_jinja_translate: bool + :param script_begin_tag: If defined, the operator will get rid of the + part of the script before the first occurrence of `script_begin_tag` + :type script_begin_tag: str + :param run_as_owner: Run HQL code as a DAG's owner. + :type run_as_owner: bool + :param mapred_queue: queue used by the Hadoop CapacityScheduler. (templated) + :type mapred_queue: str + :param mapred_queue_priority: priority within CapacityScheduler queue. + Possible settings include: VERY_HIGH, HIGH, NORMAL, LOW, VERY_LOW + :type mapred_queue_priority: str + :param mapred_job_name: This name will appear in the jobtracker. + This can make monitoring easier. 
+ :type mapred_job_name: str + """ + + template_fields = ( + "hql", + "schema", + "hive_cli_conn_id", + "mapred_queue", + "hiveconfs", + "mapred_job_name", + "mapred_queue_priority", + ) + template_ext = ( + ".hql", + ".sql", + ) + ui_color = "#f0e4ec" + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + hql: str, + hive_cli_conn_id: str = "hive_cli_default", + schema: str = "default", + hiveconfs: Optional[Dict[Any, Any]] = None, + hiveconf_jinja_translate: bool = False, + script_begin_tag: Optional[str] = None, + run_as_owner: bool = False, + mapred_queue: Optional[str] = None, + mapred_queue_priority: Optional[str] = None, + mapred_job_name: Optional[str] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.hql = hql + self.hive_cli_conn_id = hive_cli_conn_id + self.schema = schema + self.hiveconfs = hiveconfs or {} + self.hiveconf_jinja_translate = hiveconf_jinja_translate + self.script_begin_tag = script_begin_tag + self.run_as = None + if run_as_owner: + self.run_as = self.dag.owner + self.mapred_queue = mapred_queue + self.mapred_queue_priority = mapred_queue_priority + self.mapred_job_name = mapred_job_name + self.mapred_job_name_template = conf.get( + "hive", + "mapred_job_name_template", + fallback="Airflow HiveOperator task for {hostname}.{dag_id}.{task_id}.{execution_date}", + ) + + # assigned lazily - just for consistency we can create the attribute with a + # `None` initial value, later it will be populated by the execute method. + # This also makes `on_kill` implementation consistent since it assumes `self.hook` + # is defined. + self.hook: Optional[HiveCliHook] = None + + def get_hook(self) -> HiveCliHook: + """Get Hive cli hook""" + return HiveCliHook( + hive_cli_conn_id=self.hive_cli_conn_id, + run_as=self.run_as, + mapred_queue=self.mapred_queue, + mapred_queue_priority=self.mapred_queue_priority, + mapred_job_name=self.mapred_job_name, + ) + + def prepare_template(self) -> None: + if self.hiveconf_jinja_translate: + self.hql = re.sub( + r"(\$\{(hiveconf:)?([ a-zA-Z0-9_]*)\})", r"{{ \g<3> }}", self.hql + ) + if self.script_begin_tag and self.script_begin_tag in self.hql: + self.hql = "\n".join(self.hql.split(self.script_begin_tag)[1:]) + + def execute(self, context: Dict[str, Any]) -> None: + self.log.info("Executing: %s", self.hql) + self.hook = self.get_hook() + + # set the mapred_job_name if it's not set with dag, task, execution time info + if not self.mapred_job_name: + ti = context["ti"] + self.hook.mapred_job_name = self.mapred_job_name_template.format( + dag_id=ti.dag_id, + task_id=ti.task_id, + execution_date=ti.execution_date.isoformat(), + hostname=ti.hostname.split(".")[0], + ) + + if self.hiveconf_jinja_translate: + self.hiveconfs = context_to_airflow_vars(context) + else: + self.hiveconfs.update(context_to_airflow_vars(context)) + + self.log.info("Passing HiveConf: %s", self.hiveconfs) + self.hook.run_cli(hql=self.hql, schema=self.schema, hive_conf=self.hiveconfs) + + def dry_run(self) -> None: + # Reset airflow environment variables to prevent + # existing env vars from impacting behavior. 
+ self.clear_airflow_vars() + + self.hook = self.get_hook() + self.hook.test_hql(hql=self.hql) + + def on_kill(self) -> None: + if self.hook: + self.hook.kill() + + def clear_airflow_vars(self) -> None: + """Reset airflow environment variables to prevent existing ones from impacting behavior.""" + blank_env_vars = { + value["env_var_format"]: "" + for value in operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values() + } + os.environ.update(blank_env_vars) diff --git a/reference/providers/apache/hive/operators/hive_stats.py b/reference/providers/apache/hive/operators/hive_stats.py new file mode 100644 index 0000000..a2a1a4d --- /dev/null +++ b/reference/providers/apache/hive/operators/hive_stats.py @@ -0,0 +1,192 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import warnings +from collections import OrderedDict +from typing import Any, Callable, Dict, List, Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.providers.presto.hooks.presto import PrestoHook +from airflow.utils.decorators import apply_defaults + + +class HiveStatsCollectionOperator(BaseOperator): + """ + Gathers partition statistics using a dynamically generated Presto + query, inserts the stats into a MySql table with this format. Stats + overwrite themselves if you rerun the same date/partition. :: + + CREATE TABLE hive_stats ( + ds VARCHAR(16), + table_name VARCHAR(500), + metric VARCHAR(200), + value BIGINT + ); + + :param table: the source table, in the format ``database.table_name``. (templated) + :type table: str + :param partition: the source partition. (templated) + :type partition: dict of {col:value} + :param extra_exprs: dict of expression to run against the table where + keys are metric names and values are Presto compatible expressions + :type extra_exprs: dict + :param excluded_columns: list of columns to exclude, consider + excluding blobs, large json columns, ... + :type excluded_columns: list + :param assignment_func: a function that receives a column name and + a type, and returns a dict of metric names and an Presto expressions. + If None is returned, the global defaults are applied. If an + empty dictionary is returned, no stats are computed for that + column. 
+ :type assignment_func: function + """ + + template_fields = ("table", "partition", "ds", "dttm") + ui_color = "#aff7a6" + + @apply_defaults + def __init__( + self, + *, + table: str, + partition: Any, + extra_exprs: Optional[Dict[str, Any]] = None, + excluded_columns: Optional[List[str]] = None, + assignment_func: Optional[ + Callable[[str, str], Optional[Dict[Any, Any]]] + ] = None, + metastore_conn_id: str = "metastore_default", + presto_conn_id: str = "presto_default", + mysql_conn_id: str = "airflow_db", + **kwargs: Any, + ) -> None: + if "col_blacklist" in kwargs: + warnings.warn( + "col_blacklist kwarg passed to {c} (task_id: {t}) is deprecated, please rename it to " + "excluded_columns instead".format( + c=self.__class__.__name__, t=kwargs.get("task_id") + ), + category=FutureWarning, + stacklevel=2, + ) + excluded_columns = kwargs.pop("col_blacklist") + super().__init__(**kwargs) + self.table = table + self.partition = partition + self.extra_exprs = extra_exprs or {} + self.excluded_columns = excluded_columns or [] # type: List[str] + self.metastore_conn_id = metastore_conn_id + self.presto_conn_id = presto_conn_id + self.mysql_conn_id = mysql_conn_id + self.assignment_func = assignment_func + self.ds = "{{ ds }}" + self.dttm = "{{ execution_date.isoformat() }}" + + def get_default_exprs(self, col: str, col_type: str) -> Dict[Any, Any]: + """Get default expressions""" + if col in self.excluded_columns: + return {} + exp = {(col, "non_null"): f"COUNT({col})"} + if col_type in ["double", "int", "bigint", "float"]: + exp[(col, "sum")] = f"SUM({col})" + exp[(col, "min")] = f"MIN({col})" + exp[(col, "max")] = f"MAX({col})" + exp[(col, "avg")] = f"AVG({col})" + elif col_type == "boolean": + exp[(col, "true")] = f"SUM(CASE WHEN {col} THEN 1 ELSE 0 END)" + exp[(col, "false")] = f"SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)" + elif col_type in ["string"]: + exp[(col, "len")] = f"SUM(CAST(LENGTH({col}) AS BIGINT))" + exp[(col, "approx_distinct")] = f"APPROX_DISTINCT({col})" + + return exp + + def execute(self, context: Optional[Dict[str, Any]] = None) -> None: + metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id) + table = metastore.get_table(table_name=self.table) + field_types = {col.name: col.type for col in table.sd.cols} + + exprs: Any = {("", "count"): "COUNT(*)"} + for col, col_type in list(field_types.items()): + if self.assignment_func: + assign_exprs = self.assignment_func(col, col_type) + if assign_exprs is None: + assign_exprs = self.get_default_exprs(col, col_type) + else: + assign_exprs = self.get_default_exprs(col, col_type) + exprs.update(assign_exprs) + exprs.update(self.extra_exprs) + exprs = OrderedDict(exprs) + exprs_str = ",\n ".join( + [v + " AS " + k[0] + "__" + k[1] for k, v in exprs.items()] + ) + + where_clause_ = [f"{k} = '{v}'" for k, v in self.partition.items()] + where_clause = " AND\n ".join(where_clause_) + sql = f"SELECT {exprs_str} FROM {self.table} WHERE {where_clause};" + + presto = PrestoHook(presto_conn_id=self.presto_conn_id) + self.log.info("Executing SQL check: %s", sql) + row = presto.get_first(hql=sql) + self.log.info("Record: %s", row) + if not row: + raise AirflowException("The query returned None") + + part_json = json.dumps(self.partition, sort_keys=True) + + self.log.info("Deleting rows from previous runs if they exist") + mysql = MySqlHook(self.mysql_conn_id) + sql = f""" + SELECT 1 FROM hive_stats + WHERE + table_name='{self.table}' AND + partition_repr='{part_json}' AND + dttm='{self.dttm}' + LIMIT 1; + """ + if 
mysql.get_records(sql): + sql = f""" + DELETE FROM hive_stats + WHERE + table_name='{self.table}' AND + partition_repr='{part_json}' AND + dttm='{self.dttm}'; + """ + mysql.run(sql) + + self.log.info("Pivoting and loading cells into the Airflow db") + rows = [ + (self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1]) + for r in zip(exprs, row) + ] + mysql.insert_rows( + table="hive_stats", + rows=rows, + target_fields=[ + "ds", + "dttm", + "table_name", + "partition_repr", + "col", + "metric", + "value", + ], + ) diff --git a/reference/providers/apache/hive/provider.yaml b/reference/providers/apache/hive/provider.yaml new file mode 100644 index 0000000..fbd360f --- /dev/null +++ b/reference/providers/apache/hive/provider.yaml @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-apache-hive +name: Apache Hive +description: | + `Apache Hive `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Apache Hive + external-doc-url: https://hive.apache.org/ + logo: /integration-logos/apache/hive.png + tags: [apache] + +operators: + - integration-name: Apache Hive + python-modules: + - airflow.providers.apache.hive.operators.hive + - airflow.providers.apache.hive.operators.hive_stats + +sensors: + - integration-name: Apache Hive + python-modules: + - airflow.providers.apache.hive.sensors.hive_partition + - airflow.providers.apache.hive.sensors.metastore_partition + - airflow.providers.apache.hive.sensors.named_hive_partition + +hooks: + - integration-name: Apache Hive + python-modules: + - airflow.providers.apache.hive.hooks.hive + +transfers: + - source-integration-name: Vertica + target-integration-name: Apache Hive + python-module: airflow.providers.apache.hive.transfers.vertica_to_hive + - source-integration-name: Apache Hive + target-integration-name: MySQL + python-module: airflow.providers.apache.hive.transfers.hive_to_mysql + - source-integration-name: Apache Hive + target-integration-name: Samba + python-module: airflow.providers.apache.hive.transfers.hive_to_samba + - source-integration-name: Amazon Simple Storage Service (S3) + target-integration-name: Apache Hive + python-module: airflow.providers.apache.hive.transfers.s3_to_hive + - source-integration-name: MySQL + target-integration-name: Apache Hive + python-module: airflow.providers.apache.hive.transfers.mysql_to_hive + - source-integration-name: Microsoft SQL Server (MSSQL) + target-integration-name: Apache Hive + python-module: airflow.providers.apache.hive.transfers.mssql_to_hive + +hook-class-names: + - airflow.providers.apache.hive.hooks.hive.HiveCliHook + - airflow.providers.apache.hive.hooks.hive.HiveServer2Hook + - airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook diff --git 
a/reference/providers/apache/hive/sensors/__init__.py b/reference/providers/apache/hive/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/hive/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/hive/sensors/hive_partition.py b/reference/providers/apache/hive/sensors/hive_partition.py new file mode 100644 index 0000000..ba1204e --- /dev/null +++ b/reference/providers/apache/hive/sensors/hive_partition.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, Optional + +from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class HivePartitionSensor(BaseSensorOperator): + """ + Waits for a partition to show up in Hive. + + Note: Because ``partition`` supports general logical operators, it + can be inefficient. Consider using NamedHivePartitionSensor instead if + you don't need the full flexibility of HivePartitionSensor. + + :param table: The name of the table to wait for, supports the dot + notation (my_database.my_table) + :type table: str + :param partition: The partition clause to wait for. 
This is passed as + is to the metastore Thrift client ``get_partitions_by_filter`` method, + and apparently supports SQL like notation as in ``ds='2015-01-01' + AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"`` + :type partition: str + :param metastore_conn_id: reference to the metastore thrift service + connection id + :type metastore_conn_id: str + """ + + template_fields = ( + "schema", + "table", + "partition", + ) + ui_color = "#C5CAE9" + + @apply_defaults + def __init__( + self, + *, + table: str, + partition: Optional[str] = "ds='{{ ds }}'", + metastore_conn_id: str = "metastore_default", + schema: str = "default", + poke_interval: int = 60 * 3, + **kwargs: Any, + ): + super().__init__(poke_interval=poke_interval, **kwargs) + if not partition: + partition = "ds='{{ ds }}'" + self.metastore_conn_id = metastore_conn_id + self.table = table + self.partition = partition + self.schema = schema + + def poke(self, context: Dict[str, Any]) -> bool: + if "." in self.table: + self.schema, self.table = self.table.split(".") + self.log.info( + "Poking for table %s.%s, partition %s", + self.schema, + self.table, + self.partition, + ) + if not hasattr(self, "hook"): + hook = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id) + return hook.check_for_partition(self.schema, self.table, self.partition) diff --git a/reference/providers/apache/hive/sensors/metastore_partition.py b/reference/providers/apache/hive/sensors/metastore_partition.py new file mode 100644 index 0000000..02a18d8 --- /dev/null +++ b/reference/providers/apache/hive/sensors/metastore_partition.py @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict + +from airflow.sensors.sql import SqlSensor +from airflow.utils.decorators import apply_defaults + + +class MetastorePartitionSensor(SqlSensor): + """ + An alternative to the HivePartitionSensor that talk directly to the + MySQL db. This was created as a result of observing sub optimal + queries generated by the Metastore thrift service when hitting + subpartitioned tables. The Thrift service's queries were written in a + way that wouldn't leverage the indexes. + + :param schema: the schema + :type schema: str + :param table: the table + :type table: str + :param partition_name: the partition name, as defined in the PARTITIONS + table of the Metastore. Order of the fields does matter. 
+ Examples: ``ds=2016-01-01`` or + ``ds=2016-01-01/sub=foo`` for a sub partitioned table + :type partition_name: str + :param mysql_conn_id: a reference to the MySQL conn_id for the metastore + :type mysql_conn_id: str + """ + + template_fields = ("partition_name", "table", "schema") + ui_color = "#8da7be" + poke_context_fields = ("partition_name", "table", "schema", "mysql_conn_id") + + @apply_defaults + def __init__( + self, + *, + table: str, + partition_name: str, + schema: str = "default", + mysql_conn_id: str = "metastore_mysql", + **kwargs: Any, + ): + + self.partition_name = partition_name + self.table = table + self.schema = schema + self.first_poke = True + self.conn_id = mysql_conn_id + # TODO(aoen): We shouldn't be using SqlSensor here but MetastorePartitionSensor. + # The problem is the way apply_defaults works isn't compatible with inheritance. + # The inheritance model needs to be reworked in order to support overriding args/ + # kwargs with arguments here, then 'conn_id' and 'sql' can be passed into the + # constructor below and apply_defaults will no longer throw an exception. + super().__init__(**kwargs) + + def poke(self, context: Dict[str, Any]) -> Any: + if self.first_poke: + self.first_poke = False + if "." in self.table: + self.schema, self.table = self.table.split(".") + self.sql = """ + SELECT 'X' + FROM PARTITIONS A0 + LEFT OUTER JOIN TBLS B0 ON A0.TBL_ID = B0.TBL_ID + LEFT OUTER JOIN DBS C0 ON B0.DB_ID = C0.DB_ID + WHERE + B0.TBL_NAME = '{self.table}' AND + C0.NAME = '{self.schema}' AND + A0.PART_NAME = '{self.partition_name}'; + """.format( + self=self + ) + return super().poke(context) diff --git a/reference/providers/apache/hive/sensors/named_hive_partition.py b/reference/providers/apache/hive/sensors/named_hive_partition.py new file mode 100644 index 0000000..d06a7a0 --- /dev/null +++ b/reference/providers/apache/hive/sensors/named_hive_partition.py @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, List, Tuple + +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class NamedHivePartitionSensor(BaseSensorOperator): + """ + Waits for a set of partitions to show up in Hive. + + :param partition_names: List of fully qualified names of the + partitions to wait for. A fully qualified name is of the + form ``schema.table/pk1=pv1/pk2=pv2``, for example, + default.users/ds=2016-01-01. This is passed as is to the metastore + Thrift client ``get_partitions_by_name`` method. Note that + you cannot use logical or comparison operators as in + HivePartitionSensor. 
+ :type partition_names: list[str] + :param metastore_conn_id: reference to the metastore thrift service + connection id + :type metastore_conn_id: str + """ + + template_fields = ("partition_names",) + ui_color = "#8d99ae" + poke_context_fields = ("partition_names", "metastore_conn_id") + + @apply_defaults + def __init__( + self, + *, + partition_names: List[str], + metastore_conn_id: str = "metastore_default", + poke_interval: int = 60 * 3, + hook: Any = None, + **kwargs: Any, + ): + super().__init__(poke_interval=poke_interval, **kwargs) + + self.next_index_to_poke = 0 + if isinstance(partition_names, str): + raise TypeError("partition_names must be an array of strings") + + self.metastore_conn_id = metastore_conn_id + self.partition_names = partition_names + self.hook = hook + if self.hook and metastore_conn_id != "metastore_default": + self.log.warning( + "A hook was passed but a non default metastore_conn_id=%s was used", + metastore_conn_id, + ) + + @staticmethod + def parse_partition_name(partition: str) -> Tuple[Any, ...]: + """Get schema, table, and partition info.""" + first_split = partition.split(".", 1) + if len(first_split) == 1: + schema = "default" + table_partition = max(first_split) # poor man first + else: + schema, table_partition = first_split + second_split = table_partition.split("/", 1) + if len(second_split) == 1: + raise ValueError("Could not parse " + partition + "into table, partition") + else: + table, partition = second_split + return schema, table, partition + + def poke_partition(self, partition: str) -> Any: + """Check for a named partition.""" + if not self.hook: + from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook + + self.hook = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id) + + schema, table, partition = self.parse_partition_name(partition) + + self.log.info("Poking for %s.%s/%s", schema, table, partition) + return self.hook.check_for_named_partition(schema, table, partition) + + def poke(self, context: Dict[str, Any]) -> bool: + + number_of_partitions = len(self.partition_names) + poke_index_start = self.next_index_to_poke + for i in range(number_of_partitions): + self.next_index_to_poke = (poke_index_start + i) % number_of_partitions + if not self.poke_partition(self.partition_names[self.next_index_to_poke]): + return False + + self.next_index_to_poke = 0 + return True + + def is_smart_sensor_compatible(self): + result = ( + not self.soft_fail + and not self.hook + and len(self.partition_names) <= 30 + and super().is_smart_sensor_compatible() + ) + return result diff --git a/reference/providers/apache/hive/transfers/__init__.py b/reference/providers/apache/hive/transfers/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/hive/transfers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/hive/transfers/hive_to_mysql.py b/reference/providers/apache/hive/transfers/hive_to_mysql.py new file mode 100644 index 0000000..094a5f5 --- /dev/null +++ b/reference/providers/apache/hive/transfers/hive_to_mysql.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains operator to move data from Hive to MySQL.""" +from tempfile import NamedTemporaryFile +from typing import Dict, Optional + +from airflow.models import BaseOperator +from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.utils.decorators import apply_defaults +from airflow.utils.operator_helpers import context_to_airflow_vars + + +class HiveToMySqlOperator(BaseOperator): + """ + Moves data from Hive to MySQL, note that for now the data is loaded + into memory before being pushed to MySQL, so this operator should + be used for smallish amount of data. + + :param sql: SQL query to execute against Hive server. (templated) + :type sql: str + :param mysql_table: target MySQL table, use dot notation to target a + specific database. (templated) + :type mysql_table: str + :param mysql_conn_id: source mysql connection + :type mysql_conn_id: str + :param hiveserver2_conn_id: destination hive connection + :type hiveserver2_conn_id: str + :param mysql_preoperator: sql statement to run against mysql prior to + import, typically use to truncate of delete in place + of the data coming in, allowing the task to be idempotent (running + the task twice won't double load data). (templated) + :type mysql_preoperator: str + :param mysql_postoperator: sql statement to run against mysql after the + import, typically used to move data from staging to + production and issue cleanup commands. (templated) + :type mysql_postoperator: str + :param bulk_load: flag to use bulk_load option. This loads mysql directly + from a tab-delimited text file using the LOAD DATA LOCAL INFILE command. + This option requires an extra connection parameter for the + destination MySQL connection: {'local_infile': true}. 
+ :type bulk_load: bool + :param hive_conf: + :type hive_conf: dict + """ + + template_fields = ("sql", "mysql_table", "mysql_preoperator", "mysql_postoperator") + template_ext = (".sql",) + ui_color = "#a0e08c" + + @apply_defaults + def __init__( + self, + *, + sql: str, + mysql_table: str, + hiveserver2_conn_id: str = "hiveserver2_default", + mysql_conn_id: str = "mysql_default", + mysql_preoperator: Optional[str] = None, + mysql_postoperator: Optional[str] = None, + bulk_load: bool = False, + hive_conf: Optional[Dict] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sql = sql + self.mysql_table = mysql_table + self.mysql_conn_id = mysql_conn_id + self.mysql_preoperator = mysql_preoperator + self.mysql_postoperator = mysql_postoperator + self.hiveserver2_conn_id = hiveserver2_conn_id + self.bulk_load = bulk_load + self.hive_conf = hive_conf + + def execute(self, context): + hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id) + + self.log.info("Extracting data from Hive: %s", self.sql) + hive_conf = context_to_airflow_vars(context) + if self.hive_conf: + hive_conf.update(self.hive_conf) + if self.bulk_load: + tmp_file = NamedTemporaryFile() + hive.to_csv( + self.sql, + tmp_file.name, + delimiter="\t", + lineterminator="\n", + output_header=False, + hive_conf=hive_conf, + ) + else: + hive_results = hive.get_records(self.sql, hive_conf=hive_conf) + + mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) + + if self.mysql_preoperator: + self.log.info("Running MySQL preoperator") + mysql.run(self.mysql_preoperator) + + self.log.info("Inserting rows into MySQL") + if self.bulk_load: + mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name) + tmp_file.close() + else: + mysql.insert_rows(table=self.mysql_table, rows=hive_results) + + if self.mysql_postoperator: + self.log.info("Running MySQL postoperator") + mysql.run(self.mysql_postoperator) + + self.log.info("Done.") diff --git a/reference/providers/apache/hive/transfers/hive_to_samba.py b/reference/providers/apache/hive/transfers/hive_to_samba.py new file mode 100644 index 0000000..864b9ae --- /dev/null +++ b/reference/providers/apache/hive/transfers/hive_to_samba.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains operator to move data from Hive to Samba.""" + +from tempfile import NamedTemporaryFile + +from airflow.models import BaseOperator +from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook +from airflow.providers.samba.hooks.samba import SambaHook +from airflow.utils.decorators import apply_defaults +from airflow.utils.operator_helpers import context_to_airflow_vars + + +class HiveToSambaOperator(BaseOperator): + """ + Executes hql code in a specific Hive database and loads the + results of the query as a csv to a Samba location. + + :param hql: the hql to be exported. (templated) + :type hql: str + :param destination_filepath: the file path to where the file will be pushed onto samba + :type destination_filepath: str + :param samba_conn_id: reference to the samba destination + :type samba_conn_id: str + :param hiveserver2_conn_id: reference to the hiveserver2 service + :type hiveserver2_conn_id: str + """ + + template_fields = ("hql", "destination_filepath") + template_ext = ( + ".hql", + ".sql", + ) + + @apply_defaults + def __init__( + self, + *, + hql: str, + destination_filepath: str, + samba_conn_id: str = "samba_default", + hiveserver2_conn_id: str = "hiveserver2_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.hiveserver2_conn_id = hiveserver2_conn_id + self.samba_conn_id = samba_conn_id + self.destination_filepath = destination_filepath + self.hql = hql.strip().rstrip(";") + + def execute(self, context): + with NamedTemporaryFile() as tmp_file: + self.log.info("Fetching file from Hive") + hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id) + hive.to_csv( + hql=self.hql, + csv_filepath=tmp_file.name, + hive_conf=context_to_airflow_vars(context), + ) + self.log.info("Pushing to samba") + samba = SambaHook(samba_conn_id=self.samba_conn_id) + samba.push_from_local(self.destination_filepath, tmp_file.name) diff --git a/reference/providers/apache/hive/transfers/mssql_to_hive.py b/reference/providers/apache/hive/transfers/mssql_to_hive.py new file mode 100644 index 0000000..7cf01b4 --- /dev/null +++ b/reference/providers/apache/hive/transfers/mssql_to_hive.py @@ -0,0 +1,144 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains operator to move data from MSSQL to Hive.""" + +from collections import OrderedDict +from tempfile import NamedTemporaryFile +from typing import Dict, Optional + +import pymssql +import unicodecsv as csv +from airflow.models import BaseOperator +from airflow.providers.apache.hive.hooks.hive import HiveCliHook +from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook +from airflow.utils.decorators import apply_defaults + + +class MsSqlToHiveOperator(BaseOperator): + """ + Moves data from Microsoft SQL Server to Hive. 
The operator runs + your query against Microsoft SQL Server, stores the file locally + before loading it into a Hive table. If the ``create`` or + ``recreate`` arguments are set to ``True``, + a ``CREATE TABLE`` and ``DROP TABLE`` statements are generated. + Hive data types are inferred from the cursor's metadata. + Note that the table generated in Hive uses ``STORED AS textfile`` + which isn't the most efficient serialization format. If a + large amount of data is loaded and/or if the table gets + queried considerably, you may want to use this operator only to + stage the data into a temporary table before loading it into its + final destination using a ``HiveOperator``. + + :param sql: SQL query to execute against the Microsoft SQL Server + database. (templated) + :type sql: str + :param hive_table: target Hive table, use dot notation to target a specific + database. (templated) + :type hive_table: str + :param create: whether to create the table if it doesn't exist + :type create: bool + :param recreate: whether to drop and recreate the table at every execution + :type recreate: bool + :param partition: target partition as a dict of partition columns and + values. (templated) + :type partition: dict + :param delimiter: field delimiter in the file + :type delimiter: str + :param mssql_conn_id: source Microsoft SQL Server connection + :type mssql_conn_id: str + :param hive_conn_id: destination hive connection + :type hive_conn_id: str + :param tblproperties: TBLPROPERTIES of the hive table being created + :type tblproperties: dict + """ + + template_fields = ("sql", "partition", "hive_table") + template_ext = (".sql",) + ui_color = "#a0e08c" + + @apply_defaults + def __init__( + self, + *, + sql: str, + hive_table: str, + create: bool = True, + recreate: bool = False, + partition: Optional[Dict] = None, + delimiter: str = chr(1), + mssql_conn_id: str = "mssql_default", + hive_cli_conn_id: str = "hive_cli_default", + tblproperties: Optional[Dict] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sql = sql + self.hive_table = hive_table + self.partition = partition + self.create = create + self.recreate = recreate + self.delimiter = delimiter + self.mssql_conn_id = mssql_conn_id + self.hive_cli_conn_id = hive_cli_conn_id + self.partition = partition or {} + self.tblproperties = tblproperties + + @classmethod + def type_map(cls, mssql_type: int) -> str: + """Maps MsSQL type to Hive type.""" + map_dict = { + pymssql.BINARY.value: "INT", # pylint: disable=c-extension-no-member + pymssql.DECIMAL.value: "FLOAT", # pylint: disable=c-extension-no-member + pymssql.NUMBER.value: "INT", # pylint: disable=c-extension-no-member + } + return map_dict.get(mssql_type, "STRING") + + def execute(self, context: Dict[str, str]): + mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id) + self.log.info("Dumping Microsoft SQL Server query results to local file") + with mssql.get_conn() as conn: + with conn.cursor() as cursor: + cursor.execute(self.sql) + with NamedTemporaryFile("w") as tmp_file: + csv_writer = csv.writer( + tmp_file, delimiter=self.delimiter, encoding="utf-8" + ) + field_dict = OrderedDict() + col_count = 0 + for field in cursor.description: + col_count += 1 + col_position = f"Column{col_count}" + field_dict[ + col_position if field[0] == "" else field[0] + ] = self.type_map(field[1]) + csv_writer.writerows(cursor) + tmp_file.flush() + + hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) + self.log.info("Loading file into Hive") + hive.load_file( + tmp_file.name, + 
self.hive_table, + field_dict=field_dict, + create=self.create, + partition=self.partition, + delimiter=self.delimiter, + recreate=self.recreate, + tblproperties=self.tblproperties, + ) diff --git a/reference/providers/apache/hive/transfers/mysql_to_hive.py b/reference/providers/apache/hive/transfers/mysql_to_hive.py new file mode 100644 index 0000000..fae85ba --- /dev/null +++ b/reference/providers/apache/hive/transfers/mysql_to_hive.py @@ -0,0 +1,170 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains operator to move data from MySQL to Druid.""" + +from collections import OrderedDict +from tempfile import NamedTemporaryFile +from typing import Dict, Optional + +import MySQLdb +import unicodecsv as csv +from airflow.models import BaseOperator +from airflow.providers.apache.hive.hooks.hive import HiveCliHook +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.utils.decorators import apply_defaults + + +class MySqlToHiveOperator(BaseOperator): + """ + Moves data from MySql to Hive. The operator runs your query against + MySQL, stores the file locally before loading it into a Hive table. + If the ``create`` or ``recreate`` arguments are set to ``True``, + a ``CREATE TABLE`` and ``DROP TABLE`` statements are generated. + Hive data types are inferred from the cursor's metadata. Note that the + table generated in Hive uses ``STORED AS textfile`` + which isn't the most efficient serialization format. If a + large amount of data is loaded and/or if the table gets + queried considerably, you may want to use this operator only to + stage the data into a temporary table before loading it into its + final destination using a ``HiveOperator``. + + :param sql: SQL query to execute against the MySQL database. (templated) + :type sql: str + :param hive_table: target Hive table, use dot notation to target a + specific database. (templated) + :type hive_table: str + :param create: whether to create the table if it doesn't exist + :type create: bool + :param recreate: whether to drop and recreate the table at every + execution + :type recreate: bool + :param partition: target partition as a dict of partition columns + and values. (templated) + :type partition: dict + :param delimiter: field delimiter in the file + :type delimiter: str + :param quoting: controls when quotes should be generated by csv writer, + It can take on any of the csv.QUOTE_* constants. + :type quoting: str + :param quotechar: one-character string used to quote fields + containing special characters. + :type quotechar: str + :param escapechar: one-character string used by csv writer to escape + the delimiter or quotechar. 
+ :type escapechar: str + :param mysql_conn_id: source mysql connection + :type mysql_conn_id: str + :param hive_conn_id: destination hive connection + :type hive_conn_id: str + :param tblproperties: TBLPROPERTIES of the hive table being created + :type tblproperties: dict + """ + + template_fields = ("sql", "partition", "hive_table") + template_ext = (".sql",) + ui_color = "#a0e08c" + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + sql: str, + hive_table: str, + create: bool = True, + recreate: bool = False, + partition: Optional[Dict] = None, + delimiter: str = chr(1), + quoting: Optional[str] = None, + quotechar: str = '"', + escapechar: Optional[str] = None, + mysql_conn_id: str = "mysql_default", + hive_cli_conn_id: str = "hive_cli_default", + tblproperties: Optional[Dict] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sql = sql + self.hive_table = hive_table + self.partition = partition + self.create = create + self.recreate = recreate + self.delimiter = str(delimiter) + self.quoting = quoting or csv.QUOTE_MINIMAL + self.quotechar = quotechar + self.escapechar = escapechar + self.mysql_conn_id = mysql_conn_id + self.hive_cli_conn_id = hive_cli_conn_id + self.partition = partition or {} + self.tblproperties = tblproperties + + @classmethod + def type_map(cls, mysql_type: int) -> str: + """Maps MySQL type to Hive type.""" + types = MySQLdb.constants.FIELD_TYPE + type_map = { + types.BIT: "INT", + types.DECIMAL: "DOUBLE", + types.NEWDECIMAL: "DOUBLE", + types.DOUBLE: "DOUBLE", + types.FLOAT: "DOUBLE", + types.INT24: "INT", + types.LONG: "BIGINT", + types.LONGLONG: "DECIMAL(38,0)", + types.SHORT: "INT", + types.TINY: "SMALLINT", + types.YEAR: "INT", + types.TIMESTAMP: "TIMESTAMP", + } + return type_map.get(mysql_type, "STRING") + + def execute(self, context: Dict[str, str]): + hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) + mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) + + self.log.info("Dumping MySQL query results to local file") + conn = mysql.get_conn() + cursor = conn.cursor() + cursor.execute(self.sql) + with NamedTemporaryFile("wb") as f: + csv_writer = csv.writer( + f, + delimiter=self.delimiter, + quoting=self.quoting, + quotechar=self.quotechar, + escapechar=self.escapechar, + encoding="utf-8", + ) + field_dict = OrderedDict() + for field in cursor.description: + field_dict[field[0]] = self.type_map(field[1]) + csv_writer.writerows(cursor) + f.flush() + cursor.close() + conn.close() + self.log.info("Loading file into Hive") + hive.load_file( + f.name, + self.hive_table, + field_dict=field_dict, + create=self.create, + partition=self.partition, + delimiter=self.delimiter, + recreate=self.recreate, + tblproperties=self.tblproperties, + ) diff --git a/reference/providers/apache/hive/transfers/s3_to_hive.py b/reference/providers/apache/hive/transfers/s3_to_hive.py new file mode 100644 index 0000000..959ad80 --- /dev/null +++ b/reference/providers/apache/hive/transfers/s3_to_hive.py @@ -0,0 +1,298 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains operator to move data from Hive to S3 bucket.""" + +import bz2 +import gzip +import os +import tempfile +from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import Dict, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.apache.hive.hooks.hive import HiveCliHook +from airflow.utils.compression import uncompress_file +from airflow.utils.decorators import apply_defaults + + +class S3ToHiveOperator(BaseOperator): # pylint: disable=too-many-instance-attributes + """ + Moves data from S3 to Hive. The operator downloads a file from S3, + stores the file locally before loading it into a Hive table. + If the ``create`` or ``recreate`` arguments are set to ``True``, + a ``CREATE TABLE`` and ``DROP TABLE`` statements are generated. + Hive data types are inferred from the cursor's metadata from. + + Note that the table generated in Hive uses ``STORED AS textfile`` + which isn't the most efficient serialization format. If a + large amount of data is loaded and/or if the tables gets + queried considerably, you may want to use this operator only to + stage the data into a temporary table before loading it into its + final destination using a ``HiveOperator``. + + :param s3_key: The key to be retrieved from S3. (templated) + :type s3_key: str + :param field_dict: A dictionary of the fields name in the file + as keys and their Hive types as values + :type field_dict: dict + :param hive_table: target Hive table, use dot notation to target a + specific database. (templated) + :type hive_table: str + :param delimiter: field delimiter in the file + :type delimiter: str + :param create: whether to create the table if it doesn't exist + :type create: bool + :param recreate: whether to drop and recreate the table at every + execution + :type recreate: bool + :param partition: target partition as a dict of partition columns + and values. (templated) + :type partition: dict + :param headers: whether the file contains column names on the first + line + :type headers: bool + :param check_headers: whether the column names on the first line should be + checked against the keys of field_dict + :type check_headers: bool + :param wildcard_match: whether the s3_key should be interpreted as a Unix + wildcard pattern + :type wildcard_match: bool + :param aws_conn_id: source s3 connection + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. 
+ :type verify: bool or str + :param hive_cli_conn_id: destination hive connection + :type hive_cli_conn_id: str + :param input_compressed: Boolean to determine if file decompression is + required to process headers + :type input_compressed: bool + :param tblproperties: TBLPROPERTIES of the hive table being created + :type tblproperties: dict + :param select_expression: S3 Select expression + :type select_expression: str + """ + + template_fields = ("s3_key", "partition", "hive_table") + template_ext = () + ui_color = "#a0e08c" + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + s3_key: str, + field_dict: Dict, + hive_table: str, + delimiter: str = ",", + create: bool = True, + recreate: bool = False, + partition: Optional[Dict] = None, + headers: bool = False, + check_headers: bool = False, + wildcard_match: bool = False, + aws_conn_id: str = "aws_default", + verify: Optional[Union[bool, str]] = None, + hive_cli_conn_id: str = "hive_cli_default", + input_compressed: bool = False, + tblproperties: Optional[Dict] = None, + select_expression: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.s3_key = s3_key + self.field_dict = field_dict + self.hive_table = hive_table + self.delimiter = delimiter + self.create = create + self.recreate = recreate + self.partition = partition + self.headers = headers + self.check_headers = check_headers + self.wildcard_match = wildcard_match + self.hive_cli_conn_id = hive_cli_conn_id + self.aws_conn_id = aws_conn_id + self.verify = verify + self.input_compressed = input_compressed + self.tblproperties = tblproperties + self.select_expression = select_expression + + if self.check_headers and not (self.field_dict is not None and self.headers): + raise AirflowException( + "To check_headers provide " + "field_dict and headers" + ) + + def execute(self, context): + # Downloading file from S3 + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + hive_hook = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) + self.log.info("Downloading S3 file") + + if self.wildcard_match: + if not s3_hook.check_for_wildcard_key(self.s3_key): + raise AirflowException(f"No key matches {self.s3_key}") + s3_key_object = s3_hook.get_wildcard_key(self.s3_key) + else: + if not s3_hook.check_for_key(self.s3_key): + raise AirflowException(f"The key {self.s3_key} does not exists") + s3_key_object = s3_hook.get_key(self.s3_key) + + _, file_ext = os.path.splitext(s3_key_object.key) + if ( + self.select_expression + and self.input_compressed + and file_ext.lower() != ".gz" + ): + raise AirflowException( + "GZIP is the only compression " + "format Amazon S3 Select supports" + ) + + with TemporaryDirectory(prefix="tmps32hive_") as tmp_dir, NamedTemporaryFile( + mode="wb", dir=tmp_dir, suffix=file_ext + ) as f: + self.log.info( + "Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name + ) + if self.select_expression: + option = {} + if self.headers: + option["FileHeaderInfo"] = "USE" + if self.delimiter: + option["FieldDelimiter"] = self.delimiter + + input_serialization = {"CSV": option} + if self.input_compressed: + input_serialization["CompressionType"] = "GZIP" + + content = s3_hook.select_key( + bucket_name=s3_key_object.bucket_name, + key=s3_key_object.key, + expression=self.select_expression, + input_serialization=input_serialization, + ) + f.write(content.encode("utf-8")) + else: + s3_key_object.download_fileobj(f) + f.flush() + + if self.select_expression or not self.headers: + 
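+                # Either the data came through S3 Select (any header row was consumed by
+                # FileHeaderInfo="USE") or the file has no header row, so it can be loaded
+                # into Hive as-is, without header checking or stripping.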
self.log.info("Loading file %s into Hive", f.name) + hive_hook.load_file( + f.name, + self.hive_table, + field_dict=self.field_dict, + create=self.create, + partition=self.partition, + delimiter=self.delimiter, + recreate=self.recreate, + tblproperties=self.tblproperties, + ) + else: + # Decompressing file + if self.input_compressed: + self.log.info("Uncompressing file %s", f.name) + fn_uncompressed = uncompress_file(f.name, file_ext, tmp_dir) + self.log.info("Uncompressed to %s", fn_uncompressed) + # uncompressed file available now so deleting + # compressed file to save disk space + f.close() + else: + fn_uncompressed = f.name + + # Testing if header matches field_dict + if self.check_headers: + self.log.info("Matching file header against field_dict") + header_list = self._get_top_row_as_list(fn_uncompressed) + if not self._match_headers(header_list): + raise AirflowException("Header check failed") + + # Deleting top header row + self.log.info("Removing header from file %s", fn_uncompressed) + headless_file = self._delete_top_row_and_compress( + fn_uncompressed, file_ext, tmp_dir + ) + self.log.info("Headless file %s", headless_file) + self.log.info("Loading file %s into Hive", headless_file) + hive_hook.load_file( + headless_file, + self.hive_table, + field_dict=self.field_dict, + create=self.create, + partition=self.partition, + delimiter=self.delimiter, + recreate=self.recreate, + tblproperties=self.tblproperties, + ) + + def _get_top_row_as_list(self, file_name): + with open(file_name) as file: + header_line = file.readline().strip() + header_list = header_line.split(self.delimiter) + return header_list + + def _match_headers(self, header_list): + if not header_list: + raise AirflowException("Unable to retrieve header row from file") + field_names = self.field_dict.keys() + if len(field_names) != len(header_list): + self.log.warning( + "Headers count mismatch File headers:\n %s\nField names: \n %s\n", + header_list, + field_names, + ) + return False + test_field_match = [ + h1.lower() == h2.lower() for h1, h2 in zip(header_list, field_names) + ] + if not all(test_field_match): + self.log.warning( + "Headers do not match field names File headers:\n %s\nField names: \n %s\n", + header_list, + field_names, + ) + return False + else: + return True + + @staticmethod + def _delete_top_row_and_compress(input_file_name, output_file_ext, dest_dir): + # When output_file_ext is not defined, file is not compressed + open_fn = open + if output_file_ext.lower() == ".gz": + open_fn = gzip.GzipFile + elif output_file_ext.lower() == ".bz2": + open_fn = bz2.BZ2File + + _, fn_output = tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir) + with open(input_file_name, "rb") as f_in, open_fn(fn_output, "wb") as f_out: + f_in.seek(0) + next(f_in) + for line in f_in: + f_out.write(line) + return fn_output diff --git a/reference/providers/apache/hive/transfers/vertica_to_hive.py b/reference/providers/apache/hive/transfers/vertica_to_hive.py new file mode 100644 index 0000000..2f3a1de --- /dev/null +++ b/reference/providers/apache/hive/transfers/vertica_to_hive.py @@ -0,0 +1,144 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains operator to move data from Vertica to Hive.""" + +from collections import OrderedDict +from tempfile import NamedTemporaryFile + +import unicodecsv as csv +from airflow.models import BaseOperator +from airflow.providers.apache.hive.hooks.hive import HiveCliHook +from airflow.providers.vertica.hooks.vertica import VerticaHook +from airflow.utils.decorators import apply_defaults + + +class VerticaToHiveOperator(BaseOperator): + """ + Moves data from Vertica to Hive. The operator runs + your query against Vertica, stores the file locally + before loading it into a Hive table. If the ``create`` or + ``recreate`` arguments are set to ``True``, + a ``CREATE TABLE`` and ``DROP TABLE`` statements are generated. + Hive data types are inferred from the cursor's metadata. + Note that the table generated in Hive uses ``STORED AS textfile`` + which isn't the most efficient serialization format. If a + large amount of data is loaded and/or if the table gets + queried considerably, you may want to use this operator only to + stage the data into a temporary table before loading it into its + final destination using a ``HiveOperator``. + + :param sql: SQL query to execute against the Vertica database. (templated) + :type sql: str + :param hive_table: target Hive table, use dot notation to target a + specific database. (templated) + :type hive_table: str + :param create: whether to create the table if it doesn't exist + :type create: bool + :param recreate: whether to drop and recreate the table at every execution + :type recreate: bool + :param partition: target partition as a dict of partition columns + and values. (templated) + :type partition: dict + :param delimiter: field delimiter in the file + :type delimiter: str + :param vertica_conn_id: source Vertica connection + :type vertica_conn_id: str + :param hive_conn_id: destination hive connection + :type hive_conn_id: str + + """ + + template_fields = ("sql", "partition", "hive_table") + template_ext = (".sql",) + ui_color = "#b4e0ff" + + @apply_defaults + def __init__( + self, + *, + sql, + hive_table, + create=True, + recreate=False, + partition=None, + delimiter=chr(1), + vertica_conn_id="vertica_default", + hive_cli_conn_id="hive_cli_default", + **kwargs, + ): + super().__init__(**kwargs) + self.sql = sql + self.hive_table = hive_table + self.partition = partition + self.create = create + self.recreate = recreate + self.delimiter = str(delimiter) + self.vertica_conn_id = vertica_conn_id + self.hive_cli_conn_id = hive_cli_conn_id + self.partition = partition or {} + + @classmethod + def type_map(cls, vertica_type): + """ + Vertica-python datatype.py does not provide the full type mapping access. + Manual hack. 
Reference: + https://github.com/uber/vertica-python/blob/master/vertica_python/vertica/column.py + """ + type_map = { + 5: "BOOLEAN", + 6: "INT", + 7: "FLOAT", + 8: "STRING", + 9: "STRING", + 16: "FLOAT", + } + return type_map.get(vertica_type, "STRING") + + def execute(self, context): + hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) + vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id) + + self.log.info("Dumping Vertica query results to local file") + conn = vertica.get_conn() + cursor = conn.cursor() + cursor.execute(self.sql) + with NamedTemporaryFile("w") as f: + csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8") + field_dict = OrderedDict() + col_count = 0 + for field in cursor.description: + col_count += 1 + col_position = f"Column{col_count}" + field_dict[ + col_position if field[0] == "" else field[0] + ] = self.type_map(field[1]) + csv_writer.writerows(cursor.iterate()) + f.flush() + cursor.close() + conn.close() + self.log.info("Loading file into Hive") + hive.load_file( + f.name, + self.hive_table, + field_dict=field_dict, + create=self.create, + partition=self.partition, + delimiter=self.delimiter, + recreate=self.recreate, + ) diff --git a/reference/providers/apache/kylin/CHANGELOG.rst b/reference/providers/apache/kylin/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/apache/kylin/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/apache/kylin/__init__.py b/reference/providers/apache/kylin/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/kylin/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
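The SQL-to-Hive transfer operators above (MsSQL, MySQL, Vertica) all follow the same pattern: run the source query, write the cursor rows to a local delimited temp file, map the cursor's type codes to Hive column types, and hand the file to HiveCliHook.load_file. A minimal usage sketch for VerticaToHiveOperator follows; the DAG id, connection ids, SQL, and table names are illustrative assumptions, not values taken from this repository.

from airflow import DAG
from airflow.providers.apache.hive.transfers.vertica_to_hive import VerticaToHiveOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_vertica_to_hive",  # hypothetical DAG id
    schedule_interval="@daily",
    start_date=days_ago(1),
    tags=["example"],
) as dag:
    # Dump the Vertica query result locally, then (re)create and load the Hive table.
    stage_sales = VerticaToHiveOperator(
        task_id="stage_sales",
        sql="SELECT * FROM public.sales WHERE sale_date = '{{ ds }}'",  # assumed source query
        hive_table="staging.vertica_sales",                             # assumed target table
        create=True,
        recreate=True,
        partition={"ds": "{{ ds }}"},
        vertica_conn_id="vertica_default",
        hive_cli_conn_id="hive_cli_default",
    )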
diff --git a/reference/providers/apache/kylin/example_dags/__init__.py b/reference/providers/apache/kylin/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/kylin/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/kylin/example_dags/example_kylin_dag.py b/reference/providers/apache/kylin/example_dags/example_kylin_dag.py new file mode 100644 index 0000000..fce2321 --- /dev/null +++ b/reference/providers/apache/kylin/example_dags/example_kylin_dag.py @@ -0,0 +1,134 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This is an example DAG which uses the KylinCubeOperator. +The tasks below include kylin build, refresh, merge operation. 
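+The first build task pulls its time range from the XCom values pushed by the gen_build_time task.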
+""" +from airflow import DAG +from airflow.operators.python import PythonOperator +from airflow.providers.apache.kylin.operators.kylin_cube import KylinCubeOperator +from airflow.utils.dates import days_ago + +args = { + "owner": "airflow", +} + +dag = DAG( + dag_id="example_kylin_operator", + default_args=args, + schedule_interval=None, + start_date=days_ago(1), + tags=["example"], +) + + +def gen_build_time(**kwargs): + """ + Gen build time and push to xcom + :param kwargs: + :return: + """ + ti = kwargs["ti"] + ti.xcom_push(key="date_start", value="1325347200000") + ti.xcom_push(key="date_end", value="1325433600000") + + +gen_build_time_task = PythonOperator( + python_callable=gen_build_time, task_id="gen_build_time", dag=dag +) + +build_task1 = KylinCubeOperator( + task_id="kylin_build_1", + kylin_conn_id="kylin_default", + project="learn_kylin", + cube="kylin_sales_cube", + command="build", + start_time="{{ task_instance.xcom_pull(task_ids='gen_build_time',key='date_start') }}", + end_time="{{ task_instance.xcom_pull(task_ids='gen_build_time',key='date_end') }}", + is_track_job=True, + dag=dag, +) + +build_task2 = KylinCubeOperator( + task_id="kylin_build_2", + kylin_conn_id="kylin_default", + project="learn_kylin", + cube="kylin_sales_cube", + command="build", + start_time="1325433600000", + end_time="1325520000000", + is_track_job=True, + dag=dag, +) + +refresh_task1 = KylinCubeOperator( + task_id="kylin_refresh_1", + kylin_conn_id="kylin_default", + project="learn_kylin", + cube="kylin_sales_cube", + command="refresh", + start_time="1325347200000", + end_time="1325433600000", + is_track_job=True, + dag=dag, +) + +merge_task = KylinCubeOperator( + task_id="kylin_merge", + kylin_conn_id="kylin_default", + project="learn_kylin", + cube="kylin_sales_cube", + command="merge", + start_time="1325347200000", + end_time="1325520000000", + is_track_job=True, + dag=dag, +) + +disable_task = KylinCubeOperator( + task_id="kylin_disable", + kylin_conn_id="kylin_default", + project="learn_kylin", + cube="kylin_sales_cube", + command="disable", + dag=dag, +) + +purge_task = KylinCubeOperator( + task_id="kylin_purge", + kylin_conn_id="kylin_default", + project="learn_kylin", + cube="kylin_sales_cube", + command="purge", + dag=dag, +) + +build_task3 = KylinCubeOperator( + task_id="kylin_build_3", + kylin_conn_id="kylin_default", + project="learn_kylin", + cube="kylin_sales_cube", + command="build", + start_time="1325433600000", + end_time="1325520000000", + dag=dag, +) + +gen_build_time_task >> build_task1 >> build_task2 >> refresh_task1 >> merge_task +merge_task >> disable_task >> purge_task >> build_task3 diff --git a/reference/providers/apache/kylin/hooks/__init__.py b/reference/providers/apache/kylin/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/kylin/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/kylin/hooks/kylin.py b/reference/providers/apache/kylin/hooks/kylin.py new file mode 100644 index 0000000..1af6f3d --- /dev/null +++ b/reference/providers/apache/kylin/hooks/kylin.py @@ -0,0 +1,85 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from kylinpy import exceptions, kylinpy + + +class KylinHook(BaseHook): + """ + :param kylin_conn_id: The connection id as configured in Airflow administration. + :type kylin_conn_id: str + :param project: project name + :type project: Optional[str] + :param dsn: dsn + :type dsn: Optional[str] + """ + + def __init__( + self, + kylin_conn_id: str = "kylin_default", + project: Optional[str] = None, + dsn: Optional[str] = None, + ): + super().__init__() + self.kylin_conn_id = kylin_conn_id + self.project = project + self.dsn = dsn + + def get_conn(self): + conn = self.get_connection(self.kylin_conn_id) + if self.dsn: + return kylinpy.create_kylin(self.dsn) + else: + self.project = self.project if self.project else conn.schema + return kylinpy.Kylin( + conn.host, + username=conn.login, + password=conn.password, + port=conn.port, + project=self.project, + **conn.extra_dejson, + ) + + def cube_run(self, datasource_name, op, **op_args): + """ + Run CubeSource command which in CubeSource.support_invoke_command + + :param datasource_name: + :param op: command + :param op_args: command args + :return: response + """ + cube_source = self.get_conn().get_datasource(datasource_name) + try: + response = cube_source.invoke_command(op, **op_args) + return response + except exceptions.KylinError as err: + raise AirflowException(f"Cube operation {op} error , Message: {err}") + + def get_job_status(self, job_id): + """ + Get job status + + :param job_id: kylin job id + :return: job status + """ + return self.get_conn().get_job(job_id).status diff --git a/reference/providers/apache/kylin/operators/__init__.py b/reference/providers/apache/kylin/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/kylin/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/kylin/operators/kylin_cube.py b/reference/providers/apache/kylin/operators/kylin_cube.py new file mode 100644 index 0000000..c2a1fc3 --- /dev/null +++ b/reference/providers/apache/kylin/operators/kylin_cube.py @@ -0,0 +1,203 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import time +from datetime import datetime +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.apache.kylin.hooks.kylin import KylinHook +from airflow.utils.decorators import apply_defaults +from kylinpy import kylinpy + + +class KylinCubeOperator(BaseOperator): + """ + This operator is used to submit request about kylin build/refresh/merge, + and can track job status . so users can easier to build kylin job + + For more detail information in + `Apache Kylin `_ + + :param kylin_conn_id: The connection id as configured in Airflow administration. + :type kylin_conn_id: str + :param project: kylin project name, this param will overwrite the project in kylin_conn_id: + :type project: str + :param cube: kylin cube name + :type cube: str + :param dsn: (dsn , dsn url of kylin connection ,which will overwrite kylin_conn_id. + for example: kylin://ADMIN:KYLIN@sandbox/learn_kylin?timeout=60&is_debug=1) + :type dsn: str + :param command: (kylin command include 'build', 'merge', 'refresh', 'delete', + 'build_streaming', 'merge_streaming', 'refresh_streaming', 'disable', 'enable', + 'purge', 'clone', 'drop'. 
+ build - use /kylin/api/cubes/{cubeName}/build rest api,and buildType is ‘BUILD’, + and you should give start_time and end_time + refresh - use build rest api,and buildType is ‘REFRESH’ + merge - use build rest api,and buildType is ‘MERGE’ + build_streaming - use /kylin/api/cubes/{cubeName}/build2 rest api,and buildType is ‘BUILD’ + and you should give offset_start and offset_end + refresh_streaming - use build2 rest api,and buildType is ‘REFRESH’ + merge_streaming - use build2 rest api,and buildType is ‘MERGE’ + delete - delete segment, and you should give segment_name value + disable - disable cube + enable - enable cube + purge - purge cube + clone - clone cube,new cube name is {cube_name}_clone + drop - drop cube) + :type command: str + :param start_time: build segment start time + :type start_time: Optional[str] + :param end_time: build segment end time + :type end_time: Optional[str] + :param offset_start: streaming build segment start time + :type offset_start: Optional[str] + :param offset_end: streaming build segment end time + :type offset_end: Optional[str] + :param segment_name: segment name + :type segment_name: str + :param is_track_job: (whether to track job status. if value is True,will track job until + job status is in("FINISHED", "ERROR", "DISCARDED", "KILLED", "SUICIDAL", + "STOPPED") or timeout) + :type is_track_job: bool + :param interval: track job status,default value is 60s + :type interval: int + :param timeout: timeout value,default value is 1 day,60 * 60 * 24 s + :type timeout: int + :param eager_error_status: (jobs error status,if job status in this list ,this task will be error. + default value is tuple(["ERROR", "DISCARDED", "KILLED", "SUICIDAL", "STOPPED"])) + :type eager_error_status: tuple + """ + + template_fields = ( + "project", + "cube", + "dsn", + "command", + "start_time", + "end_time", + "segment_name", + "offset_start", + "offset_end", + ) + ui_color = "#E79C46" + build_command = { + "fullbuild", + "build", + "merge", + "refresh", + "build_streaming", + "merge_streaming", + "refresh_streaming", + } + jobs_end_status = { + "FINISHED", + "ERROR", + "DISCARDED", + "KILLED", + "SUICIDAL", + "STOPPED", + } + + # pylint: disable=too-many-arguments,inconsistent-return-statements + @apply_defaults + def __init__( + self, + *, + kylin_conn_id: str = "kylin_default", + project: Optional[str] = None, + cube: Optional[str] = None, + dsn: Optional[str] = None, + command: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + offset_start: Optional[str] = None, + offset_end: Optional[str] = None, + segment_name: Optional[str] = None, + is_track_job: bool = False, + interval: int = 60, + timeout: int = 60 * 60 * 24, + eager_error_status=("ERROR", "DISCARDED", "KILLED", "SUICIDAL", "STOPPED"), + **kwargs, + ): + super().__init__(**kwargs) + self.kylin_conn_id = kylin_conn_id + self.project = project + self.cube = cube + self.dsn = dsn + self.command = command + self.start_time = start_time + self.end_time = end_time + self.segment_name = segment_name + self.offset_start = offset_start + self.offset_end = offset_end + self.is_track_job = is_track_job + self.interval = interval + self.timeout = timeout + self.eager_error_status = eager_error_status + self.jobs_error_status = [stat.upper() for stat in eager_error_status] + + def execute(self, context): + + _hook = KylinHook( + kylin_conn_id=self.kylin_conn_id, project=self.project, dsn=self.dsn + ) + + _support_invoke_command = kylinpy.CubeSource.support_invoke_command + if 
self.command.lower() not in _support_invoke_command: + raise AirflowException( + "Kylin:Command {} can not match kylin command list {}".format( + self.command, _support_invoke_command + ) + ) + + kylinpy_params = { + "start": datetime.fromtimestamp(int(self.start_time) / 1000) + if self.start_time + else None, + "end": datetime.fromtimestamp(int(self.end_time) / 1000) + if self.end_time + else None, + "name": self.segment_name, + "offset_start": int(self.offset_start) if self.offset_start else None, + "offset_end": int(self.offset_end) if self.offset_end else None, + } + rsp_data = _hook.cube_run(self.cube, self.command.lower(), **kylinpy_params) + if self.is_track_job and self.command.lower() in self.build_command: + started_at = time.monotonic() + job_id = rsp_data.get("uuid") + if job_id is None: + raise AirflowException("kylin job id is None") + self.log.info("kylin job id: %s", job_id) + + job_status = None + while job_status not in self.jobs_end_status: + if time.monotonic() - started_at > self.timeout: + raise AirflowException(f"kylin job {job_id} timeout") + time.sleep(self.interval) + + job_status = _hook.get_job_status(job_id) + self.log.info("Kylin job status is %s ", job_status) + if job_status in self.jobs_error_status: + raise AirflowException( + f"Kylin job {job_id} status {job_status} is error " + ) + + if self.do_xcom_push: + return rsp_data diff --git a/reference/providers/apache/kylin/provider.yaml b/reference/providers/apache/kylin/provider.yaml new file mode 100644 index 0000000..49d5cdb --- /dev/null +++ b/reference/providers/apache/kylin/provider.yaml @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-apache-kylin +name: Apache Kylin +description: | + `Apache Kylin `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Apache Kylin + external-doc-url: https://kylin.apache.org/ + logo: /integration-logos/apache/kylin.png + tags: [apache] + +operators: + - integration-name: Apache Kylin + python-modules: + - airflow.providers.apache.kylin.operators.kylin_cube + +hooks: + - integration-name: Apache Kylin + python-modules: + - airflow.providers.apache.kylin.hooks.kylin diff --git a/reference/providers/apache/livy/CHANGELOG.rst b/reference/providers/apache/livy/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/apache/livy/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/apache/livy/__init__.py b/reference/providers/apache/livy/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/livy/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/livy/example_dags/__init__.py b/reference/providers/apache/livy/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/livy/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/livy/example_dags/example_livy.py b/reference/providers/apache/livy/example_dags/example_livy.py new file mode 100644 index 0000000..d4b03c8 --- /dev/null +++ b/reference/providers/apache/livy/example_dags/example_livy.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This is an example DAG which uses the LivyOperator. +The tasks below trigger the computation of pi on the Spark instance +using the Java and Python executables provided in the example library. +""" + +from airflow import DAG +from airflow.providers.apache.livy.operators.livy import LivyOperator +from airflow.utils.dates import days_ago + +args = {"owner": "airflow", "email": ["airflow@example.com"], "depends_on_past": False} + +with DAG( + dag_id="example_livy_operator", + default_args=args, + schedule_interval="@daily", + start_date=days_ago(5), +) as dag: + + livy_java_task = LivyOperator( + task_id="pi_java_task", + dag=dag, + livy_conn_id="livy_conn_default", + file="/spark-examples.jar", + args=[10], + num_executors=1, + conf={ + "spark.shuffle.compress": "false", + }, + class_name="org.apache.spark.examples.SparkPi", + ) + + livy_python_task = LivyOperator( + task_id="pi_python_task", + dag=dag, + livy_conn_id="livy_conn_default", + file="/pi.py", + args=[10], + polling_interval=60, + ) + + livy_java_task >> livy_python_task diff --git a/reference/providers/apache/livy/hooks/__init__.py b/reference/providers/apache/livy/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/livy/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/livy/hooks/livy.py b/reference/providers/apache/livy/hooks/livy.py new file mode 100644 index 0000000..cbd65b1 --- /dev/null +++ b/reference/providers/apache/livy/hooks/livy.py @@ -0,0 +1,427 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
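+# Sketch of a typical call sequence (assumed usage, not prescribed by this module):
+#   hook = LivyHook(livy_conn_id="livy_default")
+#   batch_id = hook.post_batch(file="/spark-examples.jar",
+#                              class_name="org.apache.spark.examples.SparkPi")
+#   state = hook.get_batch_state(batch_id)  # poll until state is in LivyHook.TERMINAL_STATES
+#   hook.delete_batch(batch_id)             # clean up the session
+# LivyOperator (operators/livy.py, later in this commit) submits with post_batch,
+# polls get_batch_state, and calls delete_batch on kill.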
+ +"""This module contains the Apache Livy hook.""" +import json +import re +from enum import Enum +from typing import Any, Dict, List, Optional, Sequence, Union + +import requests +from airflow.exceptions import AirflowException +from airflow.providers.http.hooks.http import HttpHook +from airflow.utils.log.logging_mixin import LoggingMixin + + +class BatchState(Enum): + """Batch session states""" + + NOT_STARTED = "not_started" + STARTING = "starting" + RUNNING = "running" + IDLE = "idle" + BUSY = "busy" + SHUTTING_DOWN = "shutting_down" + ERROR = "error" + DEAD = "dead" + KILLED = "killed" + SUCCESS = "success" + + +class LivyHook(HttpHook, LoggingMixin): + """ + Hook for Apache Livy through the REST API. + + :param livy_conn_id: reference to a pre-defined Livy Connection. + :type livy_conn_id: str + + .. seealso:: + For more details refer to the Apache Livy API reference: + https://livy.apache.org/docs/latest/rest-api.html + """ + + TERMINAL_STATES = { + BatchState.SUCCESS, + BatchState.DEAD, + BatchState.KILLED, + BatchState.ERROR, + } + + _def_headers = {"Content-Type": "application/json", "Accept": "application/json"} + + conn_name_attr = "livy_conn_id" + default_conn_name = "livy_default" + conn_type = "livy" + hook_name = "Apache Livy" + + def __init__( + self, + livy_conn_id: str = default_conn_name, + extra_options: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(http_conn_id=livy_conn_id) + self.extra_options = extra_options or {} + + def get_conn(self, headers: Optional[Dict[str, Any]] = None) -> Any: + """ + Returns http session for use with requests + + :param headers: additional headers to be passed through as a dictionary + :type headers: dict + :return: requests session + :rtype: requests.Session + """ + tmp_headers = self._def_headers.copy() # setting default headers + if headers: + tmp_headers.update(headers) + return super().get_conn(tmp_headers) + + def run_method( + self, + endpoint: str, + method: str = "GET", + data: Optional[Any] = None, + headers: Optional[Dict[str, Any]] = None, + ) -> Any: + """ + Wrapper for HttpHook, allows to change method on the same HttpHook + + :param method: http method + :type method: str + :param endpoint: endpoint + :type endpoint: str + :param data: request payload + :type data: dict + :param headers: headers + :type headers: dict + :return: http response + :rtype: requests.Response + """ + if method not in ("GET", "POST", "PUT", "DELETE", "HEAD"): + raise ValueError(f"Invalid http method '{method}'") + if not self.extra_options: + self.extra_options = {"check_response": False} + + back_method = self.method + self.method = method + try: + result = self.run(endpoint, data, headers, self.extra_options) + finally: + self.method = back_method + return result + + def post_batch(self, *args: Any, **kwargs: Any) -> Any: + """ + Perform request to submit batch + + :return: batch session id + :rtype: int + """ + batch_submit_body = json.dumps(self.build_post_batch_body(*args, **kwargs)) + + if self.base_url is None: + # need to init self.base_url + self.get_conn() + self.log.info("Submitting job %s to %s", batch_submit_body, self.base_url) + + response = self.run_method( + method="POST", endpoint="/batches", data=batch_submit_body + ) + self.log.debug("Got response: %s", response.text) + + try: + response.raise_for_status() + except requests.exceptions.HTTPError as err: + raise AirflowException( + "Could not submit batch. Status code: {}. 
Message: '{}'".format( + err.response.status_code, err.response.text + ) + ) + + batch_id = self._parse_post_response(response.json()) + if batch_id is None: + raise AirflowException("Unable to parse the batch session id") + self.log.info("Batch submitted with session id: %d", batch_id) + + return batch_id + + def get_batch(self, session_id: Union[int, str]) -> Any: + """ + Fetch info about the specified batch + + :param session_id: identifier of the batch sessions + :type session_id: int + :return: response body + :rtype: dict + """ + self._validate_session_id(session_id) + + self.log.debug("Fetching info for batch session %d", session_id) + response = self.run_method(endpoint=f"/batches/{session_id}") + + try: + response.raise_for_status() + except requests.exceptions.HTTPError as err: + self.log.warning( + "Got status code %d for session %d", + err.response.status_code, + session_id, + ) + raise AirflowException( + f"Unable to fetch batch with id: {session_id}. Message: {err.response.text}" + ) + + return response.json() + + def get_batch_state(self, session_id: Union[int, str]) -> BatchState: + """ + Fetch the state of the specified batch + + :param session_id: identifier of the batch sessions + :type session_id: Union[int, str] + :return: batch state + :rtype: BatchState + """ + self._validate_session_id(session_id) + + self.log.debug("Fetching info for batch session %d", session_id) + response = self.run_method(endpoint=f"/batches/{session_id}/state") + + try: + response.raise_for_status() + except requests.exceptions.HTTPError as err: + self.log.warning( + "Got status code %d for session %d", + err.response.status_code, + session_id, + ) + raise AirflowException( + f"Unable to fetch batch with id: {session_id}. Message: {err.response.text}" + ) + + jresp = response.json() + if "state" not in jresp: + raise AirflowException( + f"Unable to get state for batch with id: {session_id}" + ) + return BatchState(jresp["state"]) + + def delete_batch(self, session_id: Union[int, str]) -> Any: + """ + Delete the specified batch + + :param session_id: identifier of the batch sessions + :type session_id: int + :return: response body + :rtype: dict + """ + self._validate_session_id(session_id) + + self.log.info("Deleting batch session %d", session_id) + response = self.run_method(method="DELETE", endpoint=f"/batches/{session_id}") + + try: + response.raise_for_status() + except requests.exceptions.HTTPError as err: + self.log.warning( + "Got status code %d for session %d", + err.response.status_code, + session_id, + ) + raise AirflowException( + "Could not kill the batch with session id: {}. 
Message: {}".format( + session_id, err.response.text + ) + ) + + return response.json() + + @staticmethod + def _validate_session_id(session_id: Union[int, str]) -> None: + """ + Validate session id is a int + + :param session_id: session id + :type session_id: Union[int, str] + """ + try: + int(session_id) + except (TypeError, ValueError): + raise TypeError("'session_id' must be an integer") + + @staticmethod + def _parse_post_response(response: Dict[Any, Any]) -> Any: + """ + Parse batch response for batch id + + :param response: response body + :type response: dict + :return: session id + :rtype: int + """ + return response.get("id") + + @staticmethod + def build_post_batch_body( + file: str, + args: Optional[Sequence[Union[str, int, float]]] = None, + class_name: Optional[str] = None, + jars: Optional[List[str]] = None, + py_files: Optional[List[str]] = None, + files: Optional[List[str]] = None, + archives: Optional[List[str]] = None, + name: Optional[str] = None, + driver_memory: Optional[str] = None, + driver_cores: Optional[Union[int, str]] = None, + executor_memory: Optional[str] = None, + executor_cores: Optional[int] = None, + num_executors: Optional[Union[int, str]] = None, + queue: Optional[str] = None, + proxy_user: Optional[str] = None, + conf: Optional[Dict[Any, Any]] = None, + ) -> Any: + """ + Build the post batch request body. + For more information about the format refer to + .. seealso:: https://livy.apache.org/docs/latest/rest-api.html + + :param file: Path of the file containing the application to execute (required). + :type file: str + :param proxy_user: User to impersonate when running the job. + :type proxy_user: str + :param class_name: Application Java/Spark main class string. + :type class_name: str + :param args: Command line arguments for the application s. + :type args: Sequence[Union[str, int, float]] + :param jars: jars to be used in this sessions. + :type jars: Sequence[str] + :param py_files: Python files to be used in this session. + :type py_files: Sequence[str] + :param files: files to be used in this session. + :type files: Sequence[str] + :param driver_memory: Amount of memory to use for the driver process string. + :type driver_memory: str + :param driver_cores: Number of cores to use for the driver process int. + :type driver_cores: Union[str, int] + :param executor_memory: Amount of memory to use per executor process string. + :type executor_memory: str + :param executor_cores: Number of cores to use for each executor int. + :type executor_cores: Union[int, str] + :param num_executors: Number of executors to launch for this session int. + :type num_executors: Union[str, int] + :param archives: Archives to be used in this session. + :type archives: Sequence[str] + :param queue: The name of the YARN queue to which submitted string. + :type queue: str + :param name: The name of this session string. + :type name: str + :param conf: Spark configuration properties. 
+ :type conf: dict + :return: request body + :rtype: dict + """ + # pylint: disable-msg=too-many-arguments + + body: Dict[str, Any] = {"file": file} + + if proxy_user: + body["proxyUser"] = proxy_user + if class_name: + body["className"] = class_name + if args and LivyHook._validate_list_of_stringables(args): + body["args"] = [str(val) for val in args] + if jars and LivyHook._validate_list_of_stringables(jars): + body["jars"] = jars + if py_files and LivyHook._validate_list_of_stringables(py_files): + body["pyFiles"] = py_files + if files and LivyHook._validate_list_of_stringables(files): + body["files"] = files + if driver_memory and LivyHook._validate_size_format(driver_memory): + body["driverMemory"] = driver_memory + if driver_cores: + body["driverCores"] = driver_cores + if executor_memory and LivyHook._validate_size_format(executor_memory): + body["executorMemory"] = executor_memory + if executor_cores: + body["executorCores"] = executor_cores + if num_executors: + body["numExecutors"] = num_executors + if archives and LivyHook._validate_list_of_stringables(archives): + body["archives"] = archives + if queue: + body["queue"] = queue + if name: + body["name"] = name + if conf and LivyHook._validate_extra_conf(conf): + body["conf"] = conf + + return body + + @staticmethod + def _validate_size_format(size: str) -> bool: + """ + Validate size format. + + :param size: size value + :type size: str + :return: true if valid format + :rtype: bool + """ + if size and not ( + isinstance(size, str) and re.match(r"^\d+[kmgt]b?$", size, re.IGNORECASE) + ): + raise ValueError(f"Invalid java size format for string'{size}'") + return True + + @staticmethod + def _validate_list_of_stringables(vals: Sequence[Union[str, int, float]]) -> bool: + """ + Check the values in the provided list can be converted to strings. + + :param vals: list to validate + :type vals: Sequence[Union[str, int, float]] + :return: true if valid + :rtype: bool + """ + if ( + vals is None + or not isinstance(vals, (tuple, list)) + or any(1 for val in vals if not isinstance(val, (str, int, float))) + ): + raise ValueError("List of strings expected") + return True + + @staticmethod + def _validate_extra_conf(conf: Dict[Any, Any]) -> bool: + """ + Check configuration values are either strings or ints. + + :param conf: configuration variable + :type conf: dict + :return: true if valid + :rtype: bool + """ + if conf: + if not isinstance(conf, dict): + raise ValueError("'conf' argument must be a dict") + if any( + True + for k, v in conf.items() + if not (v and isinstance(v, str) or isinstance(v, int)) + ): + raise ValueError("'conf' values must be either strings or ints") + return True diff --git a/reference/providers/apache/livy/operators/__init__.py b/reference/providers/apache/livy/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/livy/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/livy/operators/livy.py b/reference/providers/apache/livy/operators/livy.py new file mode 100644 index 0000000..b74a728 --- /dev/null +++ b/reference/providers/apache/livy/operators/livy.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Apache Livy operator.""" +from time import sleep +from typing import Any, Dict, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook +from airflow.utils.decorators import apply_defaults + + +class LivyOperator(BaseOperator): + """ + This operator wraps the Apache Livy batch REST API, allowing to submit a Spark + application to the underlying cluster. + + :param file: path of the file containing the application to execute (required). + :type file: str + :param class_name: name of the application Java/Spark main class. + :type class_name: str + :param args: application command line arguments. + :type args: list + :param jars: jars to be used in this sessions. + :type jars: list + :param py_files: python files to be used in this session. + :type py_files: list + :param files: files to be used in this session. + :type files: list + :param driver_memory: amount of memory to use for the driver process. + :type driver_memory: str + :param driver_cores: number of cores to use for the driver process. + :type driver_cores: str, int + :param executor_memory: amount of memory to use per executor process. + :type executor_memory: str + :param executor_cores: number of cores to use for each executor. + :type executor_cores: str, int + :param num_executors: number of executors to launch for this session. + :type num_executors: str, int + :param archives: archives to be used in this session. + :type archives: list + :param queue: name of the YARN queue to which the application is submitted. + :type queue: str + :param name: name of this session. + :type name: str + :param conf: Spark configuration properties. + :type conf: dict + :param proxy_user: user to impersonate when running the job. + :type proxy_user: str + :param livy_conn_id: reference to a pre-defined Livy Connection. + :type livy_conn_id: str + :param polling_interval: time in seconds between polling for job completion. 
Don't poll for values >=0 + :type polling_interval: int + :type extra_options: A dictionary of options, where key is string and value + depends on the option that's being modified. + """ + + template_fields = ("spark_params",) + + @apply_defaults + def __init__( + self, + *, + file: str, + class_name: Optional[str] = None, + args: Optional[Sequence[Union[str, int, float]]] = None, + conf: Optional[Dict[Any, Any]] = None, + jars: Optional[Sequence[str]] = None, + py_files: Optional[Sequence[str]] = None, + files: Optional[Sequence[str]] = None, + driver_memory: Optional[str] = None, + driver_cores: Optional[Union[int, str]] = None, + executor_memory: Optional[str] = None, + executor_cores: Optional[Union[int, str]] = None, + num_executors: Optional[Union[int, str]] = None, + archives: Optional[Sequence[str]] = None, + queue: Optional[str] = None, + name: Optional[str] = None, + proxy_user: Optional[str] = None, + livy_conn_id: str = "livy_default", + polling_interval: int = 0, + extra_options: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + # pylint: disable-msg=too-many-arguments + + super().__init__(**kwargs) + + self.spark_params = { + "file": file, + "class_name": class_name, + "args": args, + "jars": jars, + "py_files": py_files, + "files": files, + "driver_memory": driver_memory, + "driver_cores": driver_cores, + "executor_memory": executor_memory, + "executor_cores": executor_cores, + "num_executors": num_executors, + "archives": archives, + "queue": queue, + "name": name, + "conf": conf, + "proxy_user": proxy_user, + } + + self._livy_conn_id = livy_conn_id + self._polling_interval = polling_interval + self._extra_options = extra_options or {} + + self._livy_hook: Optional[LivyHook] = None + self._batch_id: Union[int, str] + + def get_hook(self) -> LivyHook: + """ + Get valid hook. + + :return: hook + :rtype: LivyHook + """ + if self._livy_hook is None or not isinstance(self._livy_hook, LivyHook): + self._livy_hook = LivyHook( + livy_conn_id=self._livy_conn_id, extra_options=self._extra_options + ) + return self._livy_hook + + def execute(self, context: Dict[Any, Any]) -> Any: + self._batch_id = self.get_hook().post_batch(**self.spark_params) + + if self._polling_interval > 0: + self.poll_for_termination(self._batch_id) + + return self._batch_id + + def poll_for_termination(self, batch_id: Union[int, str]) -> None: + """ + Pool Livy for batch termination. + + :param batch_id: id of the batch session to monitor. + :type batch_id: int + """ + hook = self.get_hook() + state = hook.get_batch_state(batch_id) + while state not in hook.TERMINAL_STATES: + self.log.debug("Batch with id %s is in state: %s", batch_id, state.value) + sleep(self._polling_interval) + state = hook.get_batch_state(batch_id) + self.log.info( + "Batch with id %s terminated with state: %s", batch_id, state.value + ) + if state != BatchState.SUCCESS: + raise AirflowException(f"Batch {batch_id} did not succeed") + + def on_kill(self) -> None: + self.kill() + + def kill(self) -> None: + """Delete the current batch session.""" + if self._batch_id is not None: + self.get_hook().delete_batch(self._batch_id) diff --git a/reference/providers/apache/livy/provider.yaml b/reference/providers/apache/livy/provider.yaml new file mode 100644 index 0000000..0cd1074 --- /dev/null +++ b/reference/providers/apache/livy/provider.yaml @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
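Note that execute() above only calls poll_for_termination() when polling_interval is greater than zero; with the default of 0 it returns the batch id immediately. A hedged usage sketch, not part of this commit (the DAG id and file path are examples):

from airflow import DAG
from airflow.providers.apache.livy.operators.livy import LivyOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_livy_submit",
    schedule_interval=None,
    start_date=days_ago(1),
) as dag:
    submit_pi = LivyOperator(
        task_id="submit_pi",
        file="/apps/spark/pi.py",    # path as seen by the Livy server
        num_executors=1,
        polling_interval=30,         # > 0 enables poll_for_termination()
        livy_conn_id="livy_default",
    )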
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-apache-livy +name: Apache Livy +description: | + `Apache Livy `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Apache Livy + external-doc-url: https://livy.apache.org/ + logo: /integration-logos/apache/Livy.png + tags: [apache] + +operators: + - integration-name: Apache Livy + python-modules: + - airflow.providers.apache.livy.operators.livy + +sensors: + - integration-name: Apache Livy + python-modules: + - airflow.providers.apache.livy.sensors.livy + +hooks: + - integration-name: Apache Livy + python-modules: + - airflow.providers.apache.livy.hooks.livy + +hook-class-names: + - airflow.providers.apache.livy.hooks.livy.LivyHook diff --git a/reference/providers/apache/livy/sensors/__init__.py b/reference/providers/apache/livy/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/apache/livy/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/livy/sensors/livy.py b/reference/providers/apache/livy/sensors/livy.py new file mode 100644 index 0000000..cf460e5 --- /dev/null +++ b/reference/providers/apache/livy/sensors/livy.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains the Apache Livy sensor.""" +from typing import Any, Dict, Optional, Union + +from airflow.providers.apache.livy.hooks.livy import LivyHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class LivySensor(BaseSensorOperator): + """ + Monitor a Livy sessions for termination. + + :param livy_conn_id: reference to a pre-defined Livy connection + :type livy_conn_id: str + :param batch_id: identifier of the monitored batch + :type batch_id: Union[int, str] + :type extra_options: A dictionary of options, where key is string and value + depends on the option that's being modified. + """ + + template_fields = ("batch_id",) + + @apply_defaults + def __init__( + self, + *, + batch_id: Union[int, str], + livy_conn_id: str = "livy_default", + extra_options: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.batch_id = batch_id + self._livy_conn_id = livy_conn_id + self._livy_hook: Optional[LivyHook] = None + self._extra_options = extra_options or {} + + def get_hook(self) -> LivyHook: + """ + Get valid hook. + + :return: hook + :rtype: LivyHook + """ + if self._livy_hook is None or not isinstance(self._livy_hook, LivyHook): + self._livy_hook = LivyHook( + livy_conn_id=self._livy_conn_id, extra_options=self._extra_options + ) + return self._livy_hook + + def poke(self, context: Dict[Any, Any]) -> bool: + batch_id = self.batch_id + + status = self.get_hook().get_batch_state(batch_id) + return status in self.get_hook().TERMINAL_STATES diff --git a/reference/providers/apache/pig/CHANGELOG.rst b/reference/providers/apache/pig/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/apache/pig/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/apache/pig/__init__.py b/reference/providers/apache/pig/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/pig/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/pig/example_dags/__init__.py b/reference/providers/apache/pig/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/pig/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/pig/example_dags/example_pig.py b/reference/providers/apache/pig/example_dags/example_pig.py new file mode 100644 index 0000000..e74a8d3 --- /dev/null +++ b/reference/providers/apache/pig/example_dags/example_pig.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Example DAG demonstrating the usage of the PigOperator.""" + +from airflow import DAG +from airflow.providers.apache.pig.operators.pig import PigOperator +from airflow.utils.dates import days_ago + +args = { + "owner": "airflow", +} + +dag = DAG( + dag_id="example_pig_operator", + default_args=args, + schedule_interval=None, + start_date=days_ago(2), + tags=["example"], +) + +run_this = PigOperator( + task_id="run_example_pig_script", + pig="ls /;", + pig_opts="-x local", + dag=dag, +) diff --git a/reference/providers/apache/pig/hooks/__init__.py b/reference/providers/apache/pig/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/pig/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/pig/hooks/pig.py b/reference/providers/apache/pig/hooks/pig.py new file mode 100644 index 0000000..f653c50 --- /dev/null +++ b/reference/providers/apache/pig/hooks/pig.py @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import subprocess +from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import Any, List, Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class PigCliHook(BaseHook): + """ + Simple wrapper around the pig CLI. 
+ + Note that you can also set default pig CLI properties using the + ``pig_properties`` to be used in your connection as in + ``{"pig_properties": "-Dpig.tmpfilecompression=true"}`` + + """ + + conn_name_attr = "pig_cli_conn_id" + default_conn_name = "pig_cli_default" + conn_type = "pig_cli" + hook_name = "Pig Client Wrapper" + + def __init__(self, pig_cli_conn_id: str = default_conn_name) -> None: + super().__init__() + conn = self.get_connection(pig_cli_conn_id) + self.pig_properties = conn.extra_dejson.get("pig_properties", "") + self.conn = conn + self.sub_process = None + + def run_cli( + self, pig: str, pig_opts: Optional[str] = None, verbose: bool = True + ) -> Any: + """ + Run an pig script using the pig cli + + >>> ph = PigCliHook() + >>> result = ph.run_cli("ls /;", pig_opts="-x mapreduce") + >>> ("hdfs://" in result) + True + """ + with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir: + with NamedTemporaryFile(dir=tmp_dir) as f: + f.write(pig.encode("utf-8")) + f.flush() + fname = f.name + pig_bin = "pig" + cmd_extra: List[str] = [] + + pig_cmd = [pig_bin] + + if self.pig_properties: + pig_properties_list = self.pig_properties.split() + pig_cmd.extend(pig_properties_list) + if pig_opts: + pig_opts_list = pig_opts.split() + pig_cmd.extend(pig_opts_list) + + pig_cmd.extend(["-f", fname] + cmd_extra) + + if verbose: + self.log.info("%s", " ".join(pig_cmd)) + sub_process: Any = subprocess.Popen( + pig_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd=tmp_dir, + close_fds=True, + ) + self.sub_process = sub_process + stdout = "" + for line in iter(sub_process.stdout.readline, b""): + stdout += line.decode("utf-8") + if verbose: + self.log.info(line.strip()) + sub_process.wait() + + if sub_process.returncode: + raise AirflowException(stdout) + + return stdout + + def kill(self) -> None: + """Kill Pig job""" + if self.sub_process: + if self.sub_process.poll() is None: + self.log.info("Killing the Pig job") + self.sub_process.kill() diff --git a/reference/providers/apache/pig/operators/__init__.py b/reference/providers/apache/pig/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/pig/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/pig/operators/pig.py b/reference/providers/apache/pig/operators/pig.py new file mode 100644 index 0000000..fc3ddca --- /dev/null +++ b/reference/providers/apache/pig/operators/pig.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
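run_cli above writes the Pig Latin text to a temporary file and shells out to the pig binary, so it can be smoke-tested directly once a "pig_cli_default" connection exists. A minimal sketch, not part of this commit:

from airflow.providers.apache.pig.hooks.pig import PigCliHook

hook = PigCliHook(pig_cli_conn_id="pig_cli_default")
# Roughly equivalent to: pig -x local -f <tmpfile containing "fs -ls /;">
output = hook.run_cli(pig="fs -ls /;", pig_opts="-x local")
print(output)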
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import re +from typing import Any, Optional + +from airflow.models import BaseOperator +from airflow.providers.apache.pig.hooks.pig import PigCliHook +from airflow.utils.decorators import apply_defaults + + +class PigOperator(BaseOperator): + """ + Executes pig script. + + :param pig: the pig latin script to be executed. (templated) + :type pig: str + :param pig_cli_conn_id: reference to the Hive database + :type pig_cli_conn_id: str + :param pigparams_jinja_translate: when True, pig params-type templating + ${var} gets translated into jinja-type templating {{ var }}. Note that + you may want to use this along with the + ``DAG(user_defined_macros=myargs)`` parameter. View the DAG + object documentation for more details. + :type pigparams_jinja_translate: bool + :param pig_opts: pig options, such as: -x tez, -useHCatalog, ... + :type pig_opts: str + """ + + template_fields = ("pig",) + template_ext = ( + ".pig", + ".piglatin", + ) + ui_color = "#f0e4ec" + + @apply_defaults + def __init__( + self, + *, + pig: str, + pig_cli_conn_id: str = "pig_cli_default", + pigparams_jinja_translate: bool = False, + pig_opts: Optional[str] = None, + **kwargs: Any, + ) -> None: + + super().__init__(**kwargs) + self.pigparams_jinja_translate = pigparams_jinja_translate + self.pig = pig + self.pig_cli_conn_id = pig_cli_conn_id + self.pig_opts = pig_opts + self.hook = None + + def prepare_template(self): + if self.pigparams_jinja_translate: + self.pig = re.sub(r"(\$([a-zA-Z_][a-zA-Z0-9_]*))", r"{{ \g<2> }}", self.pig) + + def execute(self, context): + self.log.info("Executing: %s", self.pig) + self.hook = PigCliHook(pig_cli_conn_id=self.pig_cli_conn_id) + self.hook.run_cli(pig=self.pig, pig_opts=self.pig_opts) + + def on_kill(self): + self.hook.kill() diff --git a/reference/providers/apache/pig/provider.yaml b/reference/providers/apache/pig/provider.yaml new file mode 100644 index 0000000..e038ba3 --- /dev/null +++ b/reference/providers/apache/pig/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
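The pigparams_jinja_translate option handled in PigOperator.prepare_template above is a plain regular-expression rewrite; the substitution can be reproduced in isolation (illustration only, not part of this commit):

import re

pig = "a = LOAD '$input' USING PigStorage(',');"
print(re.sub(r"(\$([a-zA-Z_][a-zA-Z0-9_]*))", r"{{ \g<2> }}", pig))
# a = LOAD '{{ input }}' USING PigStorage(',');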
+ +--- +package-name: apache-airflow-providers-apache-pig +name: Apache Pig +description: | + `Apache Pig `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Apache Pig + external-doc-url: https://pig.apache.org/ + logo: /integration-logos/apache/pig.png + tags: [apache] + +operators: + - integration-name: Apache Pig + python-modules: + - airflow.providers.apache.pig.operators.pig + +hooks: + - integration-name: Apache Pig + python-modules: + - airflow.providers.apache.pig.hooks.pig + +hook-class-names: + - airflow.providers.apache.pig.hooks.pig.PigCliHook diff --git a/reference/providers/apache/pinot/CHANGELOG.rst b/reference/providers/apache/pinot/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/apache/pinot/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/apache/pinot/__init__.py b/reference/providers/apache/pinot/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/pinot/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/pinot/hooks/__init__.py b/reference/providers/apache/pinot/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/pinot/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/pinot/hooks/pinot.py b/reference/providers/apache/pinot/hooks/pinot.py new file mode 100644 index 0000000..6c0a9fb --- /dev/null +++ b/reference/providers/apache/pinot/hooks/pinot.py @@ -0,0 +1,345 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import subprocess +from typing import Any, Dict, Iterable, List, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.hooks.dbapi import DbApiHook +from airflow.models import Connection +from pinotdb import connect + + +class PinotAdminHook(BaseHook): + """ + This hook is a wrapper around the pinot-admin.sh script. + For now, only small subset of its subcommands are implemented, + which are required to ingest offline data into Apache Pinot + (i.e., AddSchema, AddTable, CreateSegment, and UploadSegment). + Their command options are based on Pinot v0.1.0. + + Unfortunately, as of v0.1.0, pinot-admin.sh always exits with + status code 0. To address this behavior, users can use the + pinot_admin_system_exit flag. If its value is set to false, + this hook evaluates the result based on the output message + instead of the status code. This Pinot's behavior is supposed + to be improved in the next release, which will include the + following PR: https://github.com/apache/incubator-pinot/pull/4110 + + :param conn_id: The name of the connection to use. + :type conn_id: str + :param cmd_path: The filepath to the pinot-admin.sh executable + :type cmd_path: str + :param pinot_admin_system_exit: If true, the result is evaluated based on the status code. + Otherwise, the result is evaluated as a failure if "Error" or + "Exception" is in the output message. 
+ :type pinot_admin_system_exit: bool + """ + + def __init__( + self, + conn_id: str = "pinot_admin_default", + cmd_path: str = "pinot-admin.sh", + pinot_admin_system_exit: bool = False, + ) -> None: + super().__init__() + conn = self.get_connection(conn_id) + self.host = conn.host + self.port = str(conn.port) + self.cmd_path = conn.extra_dejson.get("cmd_path", cmd_path) + self.pinot_admin_system_exit = conn.extra_dejson.get( + "pinot_admin_system_exit", pinot_admin_system_exit + ) + self.conn = conn + + def get_conn(self) -> Any: + return self.conn + + def add_schema(self, schema_file: str, with_exec: bool = True) -> Any: + """ + Add Pinot schema by run AddSchema command + + :param schema_file: Pinot schema file + :type schema_file: str + :param with_exec: bool + :type with_exec: bool + """ + cmd = ["AddSchema"] + cmd += ["-controllerHost", self.host] + cmd += ["-controllerPort", self.port] + cmd += ["-schemaFile", schema_file] + if with_exec: + cmd += ["-exec"] + self.run_cli(cmd) + + def add_table(self, file_path: str, with_exec: bool = True) -> Any: + """ + Add Pinot table with run AddTable command + + :param file_path: Pinot table configure file + :type file_path: str + :param with_exec: bool + :type with_exec: bool + """ + cmd = ["AddTable"] + cmd += ["-controllerHost", self.host] + cmd += ["-controllerPort", self.port] + cmd += ["-filePath", file_path] + if with_exec: + cmd += ["-exec"] + self.run_cli(cmd) + + # pylint: disable=too-many-arguments + def create_segment( + self, + generator_config_file: Optional[str] = None, + data_dir: Optional[str] = None, + segment_format: Optional[str] = None, + out_dir: Optional[str] = None, + overwrite: Optional[str] = None, + table_name: Optional[str] = None, + segment_name: Optional[str] = None, + time_column_name: Optional[str] = None, + schema_file: Optional[str] = None, + reader_config_file: Optional[str] = None, + enable_star_tree_index: Optional[str] = None, + star_tree_index_spec_file: Optional[str] = None, + hll_size: Optional[str] = None, + hll_columns: Optional[str] = None, + hll_suffix: Optional[str] = None, + num_threads: Optional[str] = None, + post_creation_verification: Optional[str] = None, + retry: Optional[str] = None, + ) -> Any: + """Create Pinot segment by run CreateSegment command""" + cmd = ["CreateSegment"] + + if generator_config_file: + cmd += ["-generatorConfigFile", generator_config_file] + + if data_dir: + cmd += ["-dataDir", data_dir] + + if segment_format: + cmd += ["-format", segment_format] + + if out_dir: + cmd += ["-outDir", out_dir] + + if overwrite: + cmd += ["-overwrite", overwrite] + + if table_name: + cmd += ["-tableName", table_name] + + if segment_name: + cmd += ["-segmentName", segment_name] + + if time_column_name: + cmd += ["-timeColumnName", time_column_name] + + if schema_file: + cmd += ["-schemaFile", schema_file] + + if reader_config_file: + cmd += ["-readerConfigFile", reader_config_file] + + if enable_star_tree_index: + cmd += ["-enableStarTreeIndex", enable_star_tree_index] + + if star_tree_index_spec_file: + cmd += ["-starTreeIndexSpecFile", star_tree_index_spec_file] + + if hll_size: + cmd += ["-hllSize", hll_size] + + if hll_columns: + cmd += ["-hllColumns", hll_columns] + + if hll_suffix: + cmd += ["-hllSuffix", hll_suffix] + + if num_threads: + cmd += ["-numThreads", num_threads] + + if post_creation_verification: + cmd += ["-postCreationVerification", post_creation_verification] + + if retry: + cmd += ["-retry", retry] + + self.run_cli(cmd) + + def upload_segment(self, segment_dir: str, 
table_name: Optional[str] = None) -> Any: + """ + Upload Segment with run UploadSegment command + + :param segment_dir: + :param table_name: + :return: + """ + cmd = ["UploadSegment"] + cmd += ["-controllerHost", self.host] + cmd += ["-controllerPort", self.port] + cmd += ["-segmentDir", segment_dir] + if table_name: + cmd += ["-tableName", table_name] + self.run_cli(cmd) + + def run_cli(self, cmd: List[str], verbose: bool = True) -> str: + """ + Run command with pinot-admin.sh + + :param cmd: List of command going to be run by pinot-admin.sh script + :type cmd: list + :param verbose: + :type verbose: bool + """ + command = [self.cmd_path] + command.extend(cmd) + + env = None + if self.pinot_admin_system_exit: + env = os.environ.copy() + java_opts = "-Dpinot.admin.system.exit=true " + os.environ.get( + "JAVA_OPTS", "" + ) + env.update({"JAVA_OPTS": java_opts}) + + if verbose: + self.log.info(" ".join(command)) + + sub_process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + close_fds=True, + env=env, + ) + + stdout = "" + if sub_process.stdout: + for line in iter(sub_process.stdout.readline, b""): + stdout += line.decode("utf-8") + if verbose: + self.log.info(line.decode("utf-8").strip()) + + sub_process.wait() + + # As of Pinot v0.1.0, either of "Error: ..." or "Exception caught: ..." + # is expected to be in the output messages. See: + # https://github.com/apache/incubator-pinot/blob/release-0.1.0/pinot-tools/src/main/java/org/apache/pinot/tools/admin/PinotAdministrator.java#L98-L101 + if (self.pinot_admin_system_exit and sub_process.returncode) or ( + "Error" in stdout or "Exception" in stdout + ): + raise AirflowException(stdout) + + return stdout + + +class PinotDbApiHook(DbApiHook): + """ + Interact with Pinot Broker Query API + + This hook uses standard-SQL endpoint since PQL endpoint is soon to be deprecated. + https://docs.pinot.apache.org/users/api/querying-pinot-using-standard-sql + """ + + conn_name_attr = "pinot_broker_conn_id" + default_conn_name = "pinot_broker_default" + supports_autocommit = False + + def get_conn(self) -> Any: + """Establish a connection to pinot broker through pinot dbapi.""" + # pylint: disable=no-member + conn = self.get_connection(self.pinot_broker_conn_id) # type: ignore + # pylint: enable=no-member + pinot_broker_conn = connect( + host=conn.host, + port=conn.port, + path=conn.extra_dejson.get("endpoint", "/query/sql"), + scheme=conn.extra_dejson.get("schema", "http"), + ) + self.log.info("Get the connection to pinot broker on %s", conn.host) + return pinot_broker_conn + + def get_uri(self) -> str: + """ + Get the connection uri for pinot broker. + + e.g: http://localhost:9000/query/sql + """ + conn = self.get_connection(getattr(self, self.conn_name_attr)) + host = conn.host + if conn.port is not None: + host += f":{conn.port}" + conn_type = "http" if not conn.conn_type else conn.conn_type + endpoint = conn.extra_dejson.get("endpoint", "query/sql") + return f"{conn_type}://{host}/{endpoint}" + + def get_records( + self, + sql: str, + parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None, + ) -> Any: + """ + Executes the sql and returns a set of records. + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :type sql: str + :param parameters: The parameters to render the SQL query with. 
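The admin hook above simply builds argument lists for pinot-admin.sh and funnels them through run_cli. A hedged sketch of the offline-ingestion flow it supports, not part of this commit; the file paths, table name, and the existence of the "pinot_admin_default" connection are assumptions:

from airflow.providers.apache.pinot.hooks.pinot import PinotAdminHook

admin = PinotAdminHook(conn_id="pinot_admin_default")
admin.add_schema("/tmp/my_schema.json")                      # pinot-admin.sh AddSchema ... -exec
admin.add_table("/tmp/my_table_config.json")                 # pinot-admin.sh AddTable ... -exec
admin.upload_segment("/tmp/segments", table_name="myTable")  # pinot-admin.sh UploadSegment ...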
+ :type parameters: dict or iterable + """ + with self.get_conn() as cur: + cur.execute(sql) + return cur.fetchall() + + def get_first( + self, + sql: str, + parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None, + ) -> Any: + """ + Executes the sql and returns the first resulting row. + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :type sql: str or list + :param parameters: The parameters to render the SQL query with. + :type parameters: dict or iterable + """ + with self.get_conn() as cur: + cur.execute(sql) + return cur.fetchone() + + def set_autocommit(self, conn: Connection, autocommit: Any) -> Any: + raise NotImplementedError() + + def insert_rows( + self, + table: str, + rows: str, + target_fields: Optional[str] = None, + commit_every: int = 1000, + replace: bool = False, + **kwargs: Any, + ) -> Any: + raise NotImplementedError() diff --git a/reference/providers/apache/pinot/provider.yaml b/reference/providers/apache/pinot/provider.yaml new file mode 100644 index 0000000..16e0529 --- /dev/null +++ b/reference/providers/apache/pinot/provider.yaml @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-apache-pinot +name: Apache Pinot +description: | + `Apache Pinot `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Apache Pinot + external-doc-url: https://pinot.apache.org/ + logo: /integration-logos/apache/pinot.png + tags: [apache] + +hooks: + - integration-name: Apache Pinot + python-modules: + - airflow.providers.apache.pinot.hooks.pinot diff --git a/reference/providers/apache/spark/CHANGELOG.rst b/reference/providers/apache/spark/CHANGELOG.rst new file mode 100644 index 0000000..4e3356f --- /dev/null +++ b/reference/providers/apache/spark/CHANGELOG.rst @@ -0,0 +1,39 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Use apache.spark provider without kubernetes (#14187)`` + + +1.0.1 +..... 
+ +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/apache/spark/__init__.py b/reference/providers/apache/spark/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/spark/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/spark/example_dags/__init__.py b/reference/providers/apache/spark/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/spark/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/spark/example_dags/example_spark_dag.py b/reference/providers/apache/spark/example_dags/example_spark_dag.py new file mode 100644 index 0000000..f20d376 --- /dev/null +++ b/reference/providers/apache/spark/example_dags/example_spark_dag.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG to submit Apache Spark applications using +`SparkSubmitOperator`, `SparkJDBCOperator` and `SparkSqlOperator`. 
+""" +from airflow.models import DAG +from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator +from airflow.providers.apache.spark.operators.spark_sql import SparkSqlOperator +from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator +from airflow.utils.dates import days_ago + +args = { + "owner": "Airflow", +} + +with DAG( + dag_id="example_spark_operator", + default_args=args, + schedule_interval=None, + start_date=days_ago(2), + tags=["example"], +) as dag: + # [START howto_operator_spark_submit] + submit_job = SparkSubmitOperator( + application="${SPARK_HOME}/examples/src/main/python/pi.py", task_id="submit_job" + ) + # [END howto_operator_spark_submit] + + # [START howto_operator_spark_jdbc] + jdbc_to_spark_job = SparkJDBCOperator( + cmd_type="jdbc_to_spark", + jdbc_table="foo", + spark_jars="${SPARK_HOME}/jars/postgresql-42.2.12.jar", + jdbc_driver="org.postgresql.Driver", + metastore_table="bar", + save_mode="overwrite", + save_format="JSON", + task_id="jdbc_to_spark_job", + ) + + spark_to_jdbc_job = SparkJDBCOperator( + cmd_type="spark_to_jdbc", + jdbc_table="foo", + spark_jars="${SPARK_HOME}/jars/postgresql-42.2.12.jar", + jdbc_driver="org.postgresql.Driver", + metastore_table="bar", + save_mode="append", + task_id="spark_to_jdbc_job", + ) + # [END howto_operator_spark_jdbc] + + # [START howto_operator_spark_sql] + sql_job = SparkSqlOperator( + sql="SELECT * FROM bar", master="local", task_id="sql_job" + ) + # [END howto_operator_spark_sql] diff --git a/reference/providers/apache/spark/hooks/__init__.py b/reference/providers/apache/spark/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/spark/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/spark/hooks/spark_jdbc.py b/reference/providers/apache/spark/hooks/spark_jdbc.py new file mode 100644 index 0000000..14b9077 --- /dev/null +++ b/reference/providers/apache/spark/hooks/spark_jdbc.py @@ -0,0 +1,272 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +import os +from typing import Any, Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook + + +# pylint: disable=too-many-instance-attributes +class SparkJDBCHook(SparkSubmitHook): + """ + This hook extends the SparkSubmitHook specifically for performing data + transfers to/from JDBC-based databases with Apache Spark. + + :param spark_app_name: Name of the job (default airflow-spark-jdbc) + :type spark_app_name: str + :param spark_conn_id: Connection id as configured in Airflow administration + :type spark_conn_id: str + :param spark_conf: Any additional Spark configuration properties + :type spark_conf: dict + :param spark_py_files: Additional python files used (.zip, .egg, or .py) + :type spark_py_files: str + :param spark_files: Additional files to upload to the container running the job + :type spark_files: str + :param spark_jars: Additional jars to upload and add to the driver and + executor classpath + :type spark_jars: str + :param num_executors: number of executor to run. This should be set so as to manage + the number of connections made with the JDBC database + :type num_executors: int + :param executor_cores: Number of cores per executor + :type executor_cores: int + :param executor_memory: Memory per executor (e.g. 1000M, 2G) + :type executor_memory: str + :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) + :type driver_memory: str + :param verbose: Whether to pass the verbose flag to spark-submit for debugging + :type verbose: bool + :param keytab: Full path to the file that contains the keytab + :type keytab: str + :param principal: The name of the kerberos principal used for keytab + :type principal: str + :param cmd_type: Which way the data should flow. 2 possible values: + spark_to_jdbc: data written by spark from metastore to jdbc + jdbc_to_spark: data written by spark from jdbc to metastore + :type cmd_type: str + :param jdbc_table: The name of the JDBC table + :type jdbc_table: str + :param jdbc_conn_id: Connection id used for connection to JDBC database + :type jdbc_conn_id: str + :param jdbc_driver: Name of the JDBC driver to use for the JDBC connection. This + driver (usually a jar) should be passed in the 'jars' parameter + :type jdbc_driver: str + :param metastore_table: The name of the metastore table, + :type metastore_table: str + :param jdbc_truncate: (spark_to_jdbc only) Whether or not Spark should truncate or + drop and recreate the JDBC table. This only takes effect if + 'save_mode' is set to Overwrite. Also, if the schema is + different, Spark cannot truncate, and will drop and recreate + :type jdbc_truncate: bool + :param save_mode: The Spark save-mode to use (e.g. overwrite, append, etc.) + :type save_mode: str + :param save_format: (jdbc_to_spark-only) The Spark save-format to use (e.g. parquet) + :type save_format: str + :param batch_size: (spark_to_jdbc only) The size of the batch to insert per round + trip to the JDBC database. Defaults to 1000 + :type batch_size: int + :param fetch_size: (jdbc_to_spark only) The size of the batch to fetch per round trip + from the JDBC database. Default depends on the JDBC driver + :type fetch_size: int + :param num_partitions: The maximum number of partitions that can be used by Spark + simultaneously, both for spark_to_jdbc and jdbc_to_spark + operations. 
This will also cap the number of JDBC connections + that can be opened + :type num_partitions: int + :param partition_column: (jdbc_to_spark-only) A numeric column to be used to + partition the metastore table by. If specified, you must + also specify: + num_partitions, lower_bound, upper_bound + :type partition_column: str + :param lower_bound: (jdbc_to_spark-only) Lower bound of the range of the numeric + partition column to fetch. If specified, you must also specify: + num_partitions, partition_column, upper_bound + :type lower_bound: int + :param upper_bound: (jdbc_to_spark-only) Upper bound of the range of the numeric + partition column to fetch. If specified, you must also specify: + num_partitions, partition_column, lower_bound + :type upper_bound: int + :param create_table_column_types: (spark_to_jdbc-only) The database column data types + to use instead of the defaults, when creating the + table. Data type information should be specified in + the same format as CREATE TABLE columns syntax + (e.g: "name CHAR(64), comments VARCHAR(1024)"). + The specified types should be valid spark sql data + types. + """ + + conn_name_attr = "spark_conn_id" + default_conn_name = "spark_default" + conn_type = "spark_jdbc" + hook_name = "Spark JDBC" + + # pylint: disable=too-many-arguments,too-many-locals + def __init__( + self, + spark_app_name: str = "airflow-spark-jdbc", + spark_conn_id: str = default_conn_name, + spark_conf: Optional[Dict[str, Any]] = None, + spark_py_files: Optional[str] = None, + spark_files: Optional[str] = None, + spark_jars: Optional[str] = None, + num_executors: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + verbose: bool = False, + principal: Optional[str] = None, + keytab: Optional[str] = None, + cmd_type: str = "spark_to_jdbc", + jdbc_table: Optional[str] = None, + jdbc_conn_id: str = "jdbc-default", + jdbc_driver: Optional[str] = None, + metastore_table: Optional[str] = None, + jdbc_truncate: bool = False, + save_mode: Optional[str] = None, + save_format: Optional[str] = None, + batch_size: Optional[int] = None, + fetch_size: Optional[int] = None, + num_partitions: Optional[int] = None, + partition_column: Optional[str] = None, + lower_bound: Optional[str] = None, + upper_bound: Optional[str] = None, + create_table_column_types: Optional[str] = None, + *args: Any, + **kwargs: Any, + ): + super().__init__(*args, **kwargs) + self._name = spark_app_name + self._conn_id = spark_conn_id + self._conf = spark_conf or {} + self._py_files = spark_py_files + self._files = spark_files + self._jars = spark_jars + self._num_executors = num_executors + self._executor_cores = executor_cores + self._executor_memory = executor_memory + self._driver_memory = driver_memory + self._verbose = verbose + self._keytab = keytab + self._principal = principal + self._cmd_type = cmd_type + self._jdbc_table = jdbc_table + self._jdbc_conn_id = jdbc_conn_id + self._jdbc_driver = jdbc_driver + self._metastore_table = metastore_table + self._jdbc_truncate = jdbc_truncate + self._save_mode = save_mode + self._save_format = save_format + self._batch_size = batch_size + self._fetch_size = fetch_size + self._num_partitions = num_partitions + self._partition_column = partition_column + self._lower_bound = lower_bound + self._upper_bound = upper_bound + self._create_table_column_types = create_table_column_types + self._jdbc_connection = self._resolve_jdbc_connection() + + def _resolve_jdbc_connection(self) -> 
Dict[str, Any]: + conn_data = { + "url": "", + "schema": "", + "conn_prefix": "", + "user": "", + "password": "", + } + try: + conn = self.get_connection(self._jdbc_conn_id) + if conn.port: + conn_data["url"] = f"{conn.host}:{conn.port}" + else: + conn_data["url"] = conn.host + conn_data["schema"] = conn.schema + conn_data["user"] = conn.login + conn_data["password"] = conn.password + extra = conn.extra_dejson + conn_data["conn_prefix"] = extra.get("conn_prefix", "") + except AirflowException: + self.log.debug( + "Could not load jdbc connection string %s, defaulting to %s", + self._jdbc_conn_id, + "", + ) + return conn_data + + def _build_jdbc_application_arguments(self, jdbc_conn: Dict[str, Any]) -> Any: + arguments = [] + arguments += ["-cmdType", self._cmd_type] + if self._jdbc_connection["url"]: + arguments += [ + "-url", + f"{jdbc_conn['conn_prefix']}{jdbc_conn['url']}/{jdbc_conn['schema']}", + ] + if self._jdbc_connection["user"]: + arguments += ["-user", self._jdbc_connection["user"]] + if self._jdbc_connection["password"]: + arguments += ["-password", self._jdbc_connection["password"]] + if self._metastore_table: + arguments += ["-metastoreTable", self._metastore_table] + if self._jdbc_table: + arguments += ["-jdbcTable", self._jdbc_table] + if self._jdbc_truncate: + arguments += ["-jdbcTruncate", str(self._jdbc_truncate)] + if self._jdbc_driver: + arguments += ["-jdbcDriver", self._jdbc_driver] + if self._batch_size: + arguments += ["-batchsize", str(self._batch_size)] + if self._fetch_size: + arguments += ["-fetchsize", str(self._fetch_size)] + if self._num_partitions: + arguments += ["-numPartitions", str(self._num_partitions)] + if ( + self._partition_column + and self._lower_bound + and self._upper_bound + and self._num_partitions + ): + # these 3 parameters need to be used all together to take effect. + arguments += [ + "-partitionColumn", + self._partition_column, + "-lowerBound", + self._lower_bound, + "-upperBound", + self._upper_bound, + ] + if self._save_mode: + arguments += ["-saveMode", self._save_mode] + if self._save_format: + arguments += ["-saveFormat", self._save_format] + if self._create_table_column_types: + arguments += ["-createTableColumnTypes", self._create_table_column_types] + return arguments + + def submit_jdbc_job(self) -> None: + """Submit Spark JDBC job""" + self._application_args = self._build_jdbc_application_arguments( + self._jdbc_connection + ) + self.submit( + application=os.path.dirname(os.path.abspath(__file__)) + + "/spark_jdbc_script.py" + ) + + def get_conn(self) -> Any: + pass diff --git a/reference/providers/apache/spark/hooks/spark_jdbc_script.py b/reference/providers/apache/spark/hooks/spark_jdbc_script.py new file mode 100644 index 0000000..8259d4d --- /dev/null +++ b/reference/providers/apache/spark/hooks/spark_jdbc_script.py @@ -0,0 +1,198 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import argparse
+from typing import Any, List, Optional
+
+from pyspark.sql import SparkSession
+
+SPARK_WRITE_TO_JDBC: str = "spark_to_jdbc"
+SPARK_READ_FROM_JDBC: str = "jdbc_to_spark"
+
+
+def set_common_options(
+    spark_source: Any,
+    url: str = "localhost:5432",
+    jdbc_table: str = "default.default",
+    user: str = "root",
+    password: str = "root",
+    driver: str = "driver",
+) -> Any:
+    """
+    Get Spark source from JDBC connection
+
+    :param spark_source: Spark source, here is Spark reader or writer
+    :param url: JDBC resource url
+    :param jdbc_table: JDBC resource table name
+    :param user: JDBC resource user name
+    :param password: JDBC resource password
+    :param driver: JDBC resource driver
+    """
+    spark_source = (
+        spark_source.format("jdbc")
+        .option("url", url)
+        .option("dbtable", jdbc_table)
+        .option("user", user)
+        .option("password", password)
+        .option("driver", driver)
+    )
+    return spark_source
+
+
+# pylint: disable=too-many-arguments
+def spark_write_to_jdbc(
+    spark_session: SparkSession,
+    url: str,
+    user: str,
+    password: str,
+    metastore_table: str,
+    jdbc_table: str,
+    driver: Any,
+    truncate: bool,
+    save_mode: str,
+    batch_size: int,
+    num_partitions: int,
+    create_table_column_types: str,
+) -> None:
+    """Transfer data from Spark to JDBC source"""
+    writer = spark_session.table(metastore_table).write
+    # first set common options
+    writer = set_common_options(writer, url, jdbc_table, user, password, driver)
+
+    # now set write-specific options
+    if truncate:
+        writer = writer.option("truncate", truncate)
+    if batch_size:
+        writer = writer.option("batchsize", batch_size)
+    if num_partitions:
+        writer = writer.option("numPartitions", num_partitions)
+    if create_table_column_types:
+        writer = writer.option("createTableColumnTypes", create_table_column_types)
+
+    writer.save(mode=save_mode)
+
+
+# pylint: disable=too-many-arguments
+def spark_read_from_jdbc(
+    spark_session: SparkSession,
+    url: str,
+    user: str,
+    password: str,
+    metastore_table: str,
+    jdbc_table: str,
+    driver: Any,
+    save_mode: str,
+    save_format: str,
+    fetch_size: int,
+    num_partitions: int,
+    partition_column: str,
+    lower_bound: str,
+    upper_bound: str,
+) -> None:
+    """Transfer data from JDBC source to Spark"""
+    # first set common options
+    reader = set_common_options(
+        spark_session.read, url, jdbc_table, user, password, driver
+    )
+
+    # now set specific read options
+    if fetch_size:
+        reader = reader.option("fetchsize", fetch_size)
+    if num_partitions:
+        reader = reader.option("numPartitions", num_partitions)
+    if partition_column and lower_bound and upper_bound:
+        reader = (
+            reader.option("partitionColumn", partition_column)
+            .option("lowerBound", lower_bound)
+            .option("upperBound", upper_bound)
+        )
+
+    reader.load().write.saveAsTable(metastore_table, format=save_format, mode=save_mode)
+
+
+def _parse_arguments(args: Optional[List[str]] = None) -> Any:
+    parser = argparse.ArgumentParser(description="Spark-JDBC")
+    parser.add_argument("-cmdType", dest="cmd_type", action="store")
+    parser.add_argument("-url", dest="url", action="store")
+    parser.add_argument("-user", dest="user", action="store")
+    parser.add_argument("-password", dest="password", action="store")
+    parser.add_argument("-metastoreTable", dest="metastore_table", action="store")
+    parser.add_argument("-jdbcTable", dest="jdbc_table", action="store")
+    parser.add_argument("-jdbcDriver", dest="jdbc_driver", action="store")
+    
parser.add_argument("-jdbcTruncate", dest="truncate", action="store") + parser.add_argument("-saveMode", dest="save_mode", action="store") + parser.add_argument("-saveFormat", dest="save_format", action="store") + parser.add_argument("-batchsize", dest="batch_size", action="store") + parser.add_argument("-fetchsize", dest="fetch_size", action="store") + parser.add_argument("-name", dest="name", action="store") + parser.add_argument("-numPartitions", dest="num_partitions", action="store") + parser.add_argument("-partitionColumn", dest="partition_column", action="store") + parser.add_argument("-lowerBound", dest="lower_bound", action="store") + parser.add_argument("-upperBound", dest="upper_bound", action="store") + parser.add_argument( + "-createTableColumnTypes", dest="create_table_column_types", action="store" + ) + return parser.parse_args(args=args) + + +def _create_spark_session(arguments: Any) -> SparkSession: + return ( + SparkSession.builder.appName(arguments.name).enableHiveSupport().getOrCreate() + ) + + +def _run_spark(arguments: Any) -> None: + # Disable dynamic allocation by default to allow num_executors to take effect. + spark = _create_spark_session(arguments) + + if arguments.cmd_type == SPARK_WRITE_TO_JDBC: + spark_write_to_jdbc( + spark, + arguments.url, + arguments.user, + arguments.password, + arguments.metastore_table, + arguments.jdbc_table, + arguments.jdbc_driver, + arguments.truncate, + arguments.save_mode, + arguments.batch_size, + arguments.num_partitions, + arguments.create_table_column_types, + ) + elif arguments.cmd_type == SPARK_READ_FROM_JDBC: + spark_read_from_jdbc( + spark, + arguments.url, + arguments.user, + arguments.password, + arguments.metastore_table, + arguments.jdbc_table, + arguments.jdbc_driver, + arguments.save_mode, + arguments.save_format, + arguments.fetch_size, + arguments.num_partitions, + arguments.partition_column, + arguments.lower_bound, + arguments.upper_bound, + ) + + +if __name__ == "__main__": # pragma: no cover + _run_spark(arguments=_parse_arguments()) diff --git a/reference/providers/apache/spark/hooks/spark_sql.py b/reference/providers/apache/spark/hooks/spark_sql.py new file mode 100644 index 0000000..a703c6e --- /dev/null +++ b/reference/providers/apache/spark/hooks/spark_sql.py @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import subprocess +from typing import Any, List, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class SparkSqlHook(BaseHook): + """ + This hook is a wrapper around the spark-sql binary. It requires that the + "spark-sql" binary is in the PATH. 
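As an illustration of the hand-off between SparkJDBCHook and spark_jdbc_script.py above (a sketch, not part of the committed code; connection values and table names are made up), the hook builds a flat argument vector and the script recovers it with _parse_arguments:

from airflow.providers.apache.spark.hooks.spark_jdbc_script import _parse_arguments

# Roughly what SparkJDBCHook._build_jdbc_application_arguments emits for a
# jdbc_to_spark transfer; the -url value is conn_prefix + host:port + "/" + schema.
sample_args = [
    "-cmdType", "jdbc_to_spark",
    "-url", "jdbc:postgresql://reporting-db:5432/analytics",
    "-user", "reader",
    "-password", "s3cret",
    "-metastoreTable", "staging.events",
    "-jdbcTable", "public.events",
    "-jdbcDriver", "org.postgresql.Driver",
    "-numPartitions", "4",
]

parsed = _parse_arguments(sample_args)
assert parsed.cmd_type == "jdbc_to_spark"
assert parsed.jdbc_table == "public.events"
assert parsed.num_partitions == "4"  # argparse keeps the raw string here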
+ + :param sql: The SQL query to execute + :type sql: str + :param conf: arbitrary Spark configuration property + :type conf: str (format: PROP=VALUE) + :param conn_id: connection_id string + :type conn_id: str + :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors + (Default: all the available cores on the worker) + :type total_executor_cores: int + :param executor_cores: (Standalone & YARN only) Number of cores per + executor (Default: 2) + :type executor_cores: int + :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G) + :type executor_memory: str + :param keytab: Full path to the file that contains the keytab + :type keytab: str + :param master: spark://host:port, mesos://host:port, yarn, or local + :type master: str + :param name: Name of the job. + :type name: str + :param num_executors: Number of executors to launch + :type num_executors: int + :param verbose: Whether to pass the verbose flag to spark-sql + :type verbose: bool + :param yarn_queue: The YARN queue to submit to (Default: "default") + :type yarn_queue: str + """ + + conn_name_attr = "conn_id" + default_conn_name = "spark_sql_default" + conn_type = "spark_sql" + hook_name = "Spark SQL" + + # pylint: disable=too-many-arguments + def __init__( + self, + sql: str, + conf: Optional[str] = None, + conn_id: str = default_conn_name, + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + master: str = "yarn", + name: str = "default-name", + num_executors: Optional[int] = None, + verbose: bool = True, + yarn_queue: str = "default", + ) -> None: + super().__init__() + self._sql = sql + self._conf = conf + self._conn = self.get_connection(conn_id) + self._total_executor_cores = total_executor_cores + self._executor_cores = executor_cores + self._executor_memory = executor_memory + self._keytab = keytab + self._principal = principal + self._master = master + self._name = name + self._num_executors = num_executors + self._verbose = verbose + self._yarn_queue = yarn_queue + self._sp: Any = None + + def get_conn(self) -> Any: + pass + + def _prepare_command(self, cmd: Union[str, List[str]]) -> List[str]: + """ + Construct the spark-sql command to execute. Verbose output is enabled + as default. 
+ + :param cmd: command to append to the spark-sql command + :type cmd: str or list[str] + :return: full command to be executed + """ + connection_cmd = ["spark-sql"] + if self._conf: + for conf_el in self._conf.split(","): + connection_cmd += ["--conf", conf_el] + if self._total_executor_cores: + connection_cmd += [ + "--total-executor-cores", + str(self._total_executor_cores), + ] + if self._executor_cores: + connection_cmd += ["--executor-cores", str(self._executor_cores)] + if self._executor_memory: + connection_cmd += ["--executor-memory", self._executor_memory] + if self._keytab: + connection_cmd += ["--keytab", self._keytab] + if self._principal: + connection_cmd += ["--principal", self._principal] + if self._num_executors: + connection_cmd += ["--num-executors", str(self._num_executors)] + if self._sql: + sql = self._sql.strip() + if sql.endswith(".sql") or sql.endswith(".hql"): + connection_cmd += ["-f", sql] + else: + connection_cmd += ["-e", sql] + if self._master: + connection_cmd += ["--master", self._master] + if self._name: + connection_cmd += ["--name", self._name] + if self._verbose: + connection_cmd += ["--verbose"] + if self._yarn_queue: + connection_cmd += ["--queue", self._yarn_queue] + + if isinstance(cmd, str): + connection_cmd += cmd.split() + elif isinstance(cmd, list): + connection_cmd += cmd + else: + raise AirflowException(f"Invalid additional command: {cmd}") + + self.log.debug("Spark-Sql cmd: %s", connection_cmd) + + return connection_cmd + + def run_query(self, cmd: str = "", **kwargs: Any) -> None: + """ + Remote Popen (actually execute the Spark-sql query) + + :param cmd: command to append to the spark-sql command + :type cmd: str or list[str] + :param kwargs: extra arguments to Popen (see subprocess.Popen) + :type kwargs: dict + """ + spark_sql_cmd = self._prepare_command(cmd) + self._sp = subprocess.Popen( + spark_sql_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs + ) + + for line in iter(self._sp.stdout): # type: ignore + self.log.info(line) + + returncode = self._sp.wait() + + if returncode: + raise AirflowException( + "Cannot execute '{}' on {} (additional parameters: '{}'). Process exit code: {}.".format( + self._sql, self._master, cmd, returncode + ) + ) + + def kill(self) -> None: + """Kill Spark job""" + if self._sp and self._sp.poll() is None: + self.log.info("Killing the Spark-Sql job") + self._sp.kill() diff --git a/reference/providers/apache/spark/hooks/spark_submit.py b/reference/providers/apache/spark/hooks/spark_submit.py new file mode 100644 index 0000000..eda5d55 --- /dev/null +++ b/reference/providers/apache/spark/hooks/spark_submit.py @@ -0,0 +1,734 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
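To make the SparkSqlHook command construction above concrete, here is a sketch (not provider code) of the list _prepare_command would assemble for a hypothetical hook created with sql="SELECT 1", executor_memory="2G", verbose=True and the remaining defaults; the ordering mirrors the branches in the method:

# Illustrative expected output of SparkSqlHook._prepare_command("") under the
# assumptions stated above (master="yarn", name="default-name", yarn_queue="default").
expected_cmd = [
    "spark-sql",
    "--executor-memory", "2G",
    "-e", "SELECT 1",
    "--master", "yarn",
    "--name", "default-name",
    "--verbose",
    "--queue", "default",
]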
+# +import os +import re +import subprocess +import time +from typing import Any, Dict, Iterator, List, Optional, Union + +from airflow.configuration import conf as airflow_conf +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.security.kerberos import renew_from_kt +from airflow.utils.log.logging_mixin import LoggingMixin + +try: + from airflow.kubernetes import kube_client +except (ImportError, NameError): + pass + + +# pylint: disable=too-many-instance-attributes +class SparkSubmitHook(BaseHook, LoggingMixin): + """ + This hook is a wrapper around the spark-submit binary to kick off a spark-submit job. + It requires that the "spark-submit" binary is in the PATH or the spark_home to be + supplied. + + :param conf: Arbitrary Spark configuration properties + :type conf: dict + :param conn_id: The connection id as configured in Airflow administration. When an + invalid connection_id is supplied, it will default to yarn. + :type conn_id: str + :param files: Upload additional files to the executor running the job, separated by a + comma. Files will be placed in the working directory of each executor. + For example, serialized objects. + :type files: str + :param py_files: Additional python files used by the job, can be .zip, .egg or .py. + :type py_files: str + :param: archives: Archives that spark should unzip (and possibly tag with #ALIAS) into + the application working directory. + :param driver_class_path: Additional, driver-specific, classpath settings. + :type driver_class_path: str + :param jars: Submit additional jars to upload and place them in executor classpath. + :type jars: str + :param java_class: the main class of the Java application + :type java_class: str + :param packages: Comma-separated list of maven coordinates of jars to include on the + driver and executor classpaths + :type packages: str + :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude + while resolving the dependencies provided in 'packages' + :type exclude_packages: str + :param repositories: Comma-separated list of additional remote repositories to search + for the maven coordinates given with 'packages' + :type repositories: str + :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors + (Default: all the available cores on the worker) + :type total_executor_cores: int + :param executor_cores: (Standalone, YARN and Kubernetes only) Number of cores per + executor (Default: 2) + :type executor_cores: int + :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G) + :type executor_memory: str + :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) (Default: 1G) + :type driver_memory: str + :param keytab: Full path to the file that contains the keytab + :type keytab: str + :param principal: The name of the kerberos principal used for keytab + :type principal: str + :param proxy_user: User to impersonate when submitting the application + :type proxy_user: str + :param name: Name of the job (default airflow-spark) + :type name: str + :param num_executors: Number of executors to launch + :type num_executors: int + :param status_poll_interval: Seconds to wait between polls of driver status in cluster + mode (Default: 1) + :type status_poll_interval: int + :param application_args: Arguments for the application being submitted + :type application_args: list + :param env_vars: Environment variables for spark-submit. It + supports yarn and k8s mode too. 
+ :type env_vars: dict + :param verbose: Whether to pass the verbose flag to spark-submit process for debugging + :type verbose: bool + :param spark_binary: The command to use for spark submit. + Some distros may use spark2-submit. + :type spark_binary: str + """ + + conn_name_attr = "conn_id" + default_conn_name = "spark_default" + conn_type = "spark" + hook_name = "Spark" + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["schema", "login", "password"], + "relabeling": {}, + } + + # pylint: disable=too-many-arguments,too-many-locals,too-many-branches + def __init__( + self, + conf: Optional[Dict[str, Any]] = None, + conn_id: str = "spark_default", + files: Optional[str] = None, + py_files: Optional[str] = None, + archives: Optional[str] = None, + driver_class_path: Optional[str] = None, + jars: Optional[str] = None, + java_class: Optional[str] = None, + packages: Optional[str] = None, + exclude_packages: Optional[str] = None, + repositories: Optional[str] = None, + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + proxy_user: Optional[str] = None, + name: str = "default-name", + num_executors: Optional[int] = None, + status_poll_interval: int = 1, + application_args: Optional[List[Any]] = None, + env_vars: Optional[Dict[str, Any]] = None, + verbose: bool = False, + spark_binary: Optional[str] = None, + ) -> None: + super().__init__() + self._conf = conf or {} + self._conn_id = conn_id + self._files = files + self._py_files = py_files + self._archives = archives + self._driver_class_path = driver_class_path + self._jars = jars + self._java_class = java_class + self._packages = packages + self._exclude_packages = exclude_packages + self._repositories = repositories + self._total_executor_cores = total_executor_cores + self._executor_cores = executor_cores + self._executor_memory = executor_memory + self._driver_memory = driver_memory + self._keytab = keytab + self._principal = principal + self._proxy_user = proxy_user + self._name = name + self._num_executors = num_executors + self._status_poll_interval = status_poll_interval + self._application_args = application_args + self._env_vars = env_vars + self._verbose = verbose + self._submit_sp: Optional[Any] = None + self._yarn_application_id: Optional[str] = None + self._kubernetes_driver_pod: Optional[str] = None + self._spark_binary = spark_binary + + self._connection = self._resolve_connection() + self._is_yarn = "yarn" in self._connection["master"] + self._is_kubernetes = "k8s" in self._connection["master"] + if self._is_kubernetes and kube_client is None: + raise RuntimeError( + "{} specified by kubernetes dependencies are not installed!".format( + self._connection["master"] + ) + ) + + self._should_track_driver_status = self._resolve_should_track_driver_status() + self._driver_id: Optional[str] = None + self._driver_status: Optional[str] = None + self._spark_exit_code: Optional[int] = None + self._env: Optional[Dict[str, Any]] = None + + def _resolve_should_track_driver_status(self) -> bool: + """ + Determines whether or not this hook should poll the spark driver status through + subsequent spark-submit status requests after the initial spark-submit request + :return: if the driver status should be tracked + """ + return ( + "spark://" in self._connection["master"] + and 
self._connection["deploy_mode"] == "cluster" + ) + + def _resolve_connection(self) -> Dict[str, Any]: + # Build from connection master or default to yarn if not available + conn_data = { + "master": "yarn", + "queue": None, + "deploy_mode": None, + "spark_home": None, + "spark_binary": self._spark_binary or "spark-submit", + "namespace": None, + } + + try: + # Master can be local, yarn, spark://HOST:PORT, mesos://HOST:PORT and + # k8s://https://: + conn = self.get_connection(self._conn_id) + if conn.port: + conn_data["master"] = f"{conn.host}:{conn.port}" + else: + conn_data["master"] = conn.host + + # Determine optional yarn queue from the extra field + extra = conn.extra_dejson + conn_data["queue"] = extra.get("queue") + conn_data["deploy_mode"] = extra.get("deploy-mode") + conn_data["spark_home"] = extra.get("spark-home") + conn_data["spark_binary"] = self._spark_binary or extra.get( + "spark-binary", "spark-submit" + ) + conn_data["namespace"] = extra.get("namespace") + except AirflowException: + self.log.info( + "Could not load connection string %s, defaulting to %s", + self._conn_id, + conn_data["master"], + ) + + if "spark.kubernetes.namespace" in self._conf: + conn_data["namespace"] = self._conf["spark.kubernetes.namespace"] + + return conn_data + + def get_conn(self) -> Any: + pass + + def _get_spark_binary_path(self) -> List[str]: + # If the spark_home is passed then build the spark-submit executable path using + # the spark_home; otherwise assume that spark-submit is present in the path to + # the executing user + if self._connection["spark_home"]: + connection_cmd = [ + os.path.join( + self._connection["spark_home"], + "bin", + self._connection["spark_binary"], + ) + ] + else: + connection_cmd = [self._connection["spark_binary"]] + + return connection_cmd + + def _mask_cmd(self, connection_cmd: Union[str, List[str]]) -> str: + # Mask any password related fields in application args with key value pair + # where key contains password (case insensitive), e.g. HivePassword='abc' + connection_cmd_masked = re.sub( + r"(" + r"\S*?" # Match all non-whitespace characters before... + r"(?:secret|password)" # ...literally a "secret" or "password" + # word (not capturing them). + r"\S*?" # All non-whitespace characters before either... + r"(?:=|\s+)" # ...an equal sign or whitespace characters + # (not capturing them). + r"(['\"]?)" # An optional single or double quote. + r")" # This is the end of the first capturing group. + r"(?:(?!\2\s).)*" # All characters between optional quotes + # (matched above); if the value is quoted, + # it may contain whitespace. + r"(\2)", # Optional matching quote. + r"\1******\3", + " ".join(connection_cmd), + flags=re.I, + ) + + return connection_cmd_masked + + def _build_spark_submit_command(self, application: str) -> List[str]: + """ + Construct the spark-submit command to execute. 
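For context on _resolve_connection above, a hypothetical Airflow connection (ids and values are illustrative) whose extra JSON carries the optional keys the method reads -- queue, deploy-mode, spark-home, spark-binary and namespace:

from airflow.models import Connection

# host "yarn" becomes the master; a port, if set, would be appended as host:port.
spark_yarn_cluster = Connection(
    conn_id="spark_default",
    conn_type="spark",
    host="yarn",
    extra='{"queue": "root.etl", "deploy-mode": "cluster", "spark-binary": "spark-submit"}',
)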
+ + :param application: command to append to the spark-submit command + :type application: str + :return: full command to be executed + """ + connection_cmd = self._get_spark_binary_path() + + # The url of the spark master + connection_cmd += ["--master", self._connection["master"]] + + for key in self._conf: + connection_cmd += ["--conf", f"{key}={str(self._conf[key])}"] + if self._env_vars and (self._is_kubernetes or self._is_yarn): + if self._is_yarn: + tmpl = "spark.yarn.appMasterEnv.{}={}" + # Allow dynamic setting of hadoop/yarn configuration environments + self._env = self._env_vars + else: + tmpl = "spark.kubernetes.driverEnv.{}={}" + for key in self._env_vars: + connection_cmd += ["--conf", tmpl.format(key, str(self._env_vars[key]))] + elif self._env_vars and self._connection["deploy_mode"] != "cluster": + self._env = self._env_vars # Do it on Popen of the process + elif self._env_vars and self._connection["deploy_mode"] == "cluster": + raise AirflowException( + "SparkSubmitHook env_vars is not supported in standalone-cluster mode." + ) + if self._is_kubernetes and self._connection["namespace"]: + connection_cmd += [ + "--conf", + f"spark.kubernetes.namespace={self._connection['namespace']}", + ] + if self._files: + connection_cmd += ["--files", self._files] + if self._py_files: + connection_cmd += ["--py-files", self._py_files] + if self._archives: + connection_cmd += ["--archives", self._archives] + if self._driver_class_path: + connection_cmd += ["--driver-class-path", self._driver_class_path] + if self._jars: + connection_cmd += ["--jars", self._jars] + if self._packages: + connection_cmd += ["--packages", self._packages] + if self._exclude_packages: + connection_cmd += ["--exclude-packages", self._exclude_packages] + if self._repositories: + connection_cmd += ["--repositories", self._repositories] + if self._num_executors: + connection_cmd += ["--num-executors", str(self._num_executors)] + if self._total_executor_cores: + connection_cmd += [ + "--total-executor-cores", + str(self._total_executor_cores), + ] + if self._executor_cores: + connection_cmd += ["--executor-cores", str(self._executor_cores)] + if self._executor_memory: + connection_cmd += ["--executor-memory", self._executor_memory] + if self._driver_memory: + connection_cmd += ["--driver-memory", self._driver_memory] + if self._keytab: + connection_cmd += ["--keytab", self._keytab] + if self._principal: + connection_cmd += ["--principal", self._principal] + if self._proxy_user: + connection_cmd += ["--proxy-user", self._proxy_user] + if self._name: + connection_cmd += ["--name", self._name] + if self._java_class: + connection_cmd += ["--class", self._java_class] + if self._verbose: + connection_cmd += ["--verbose"] + if self._connection["queue"]: + connection_cmd += ["--queue", self._connection["queue"]] + if self._connection["deploy_mode"]: + connection_cmd += ["--deploy-mode", self._connection["deploy_mode"]] + + # The actual script to execute + connection_cmd += [application] + + # Append any application arguments + if self._application_args: + connection_cmd += self._application_args + + self.log.info("Spark-Submit cmd: %s", self._mask_cmd(connection_cmd)) + + return connection_cmd + + def _build_track_driver_status_command(self) -> List[str]: + """ + Construct the command to poll the driver status. 
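As a side note on the env_vars handling in _build_spark_submit_command above: for a YARN-backed connection each entry is forwarded as a spark.yarn.appMasterEnv.* configuration flag. A minimal sketch of that expansion, with made-up values:

env_vars = {"HADOOP_USER_NAME": "etl", "TZ": "UTC"}  # illustrative only
flags = []
for key, value in env_vars.items():
    flags += ["--conf", f"spark.yarn.appMasterEnv.{key}={value}"]
# flags == ["--conf", "spark.yarn.appMasterEnv.HADOOP_USER_NAME=etl",
#           "--conf", "spark.yarn.appMasterEnv.TZ=UTC"]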
+ 
+        :return: full command to be executed
+        """
+        curl_max_wait_time = 30
+        spark_host = self._connection["master"]
+        if spark_host.endswith(":6066"):
+            spark_host = spark_host.replace("spark://", "http://")
+            connection_cmd = [
+                "/usr/bin/curl",
+                "--max-time",
+                str(curl_max_wait_time),
+                f"{spark_host}/v1/submissions/status/{self._driver_id}",
+            ]
+            self.log.info(connection_cmd)
+
+            # The driver id so we can poll for its status
+            if self._driver_id:
+                pass
+            else:
+                raise AirflowException(
+                    "Invalid status: attempted to poll driver "
+                    + "status but no driver id is known. Giving up."
+                )
+
+        else:
+            connection_cmd = self._get_spark_binary_path()
+
+            # The url to the spark master
+            connection_cmd += ["--master", self._connection["master"]]
+
+            # The driver id so we can poll for its status
+            if self._driver_id:
+                connection_cmd += ["--status", self._driver_id]
+            else:
+                raise AirflowException(
+                    "Invalid status: attempted to poll driver "
+                    + "status but no driver id is known. Giving up."
+                )
+
+        self.log.debug("Poll driver status cmd: %s", connection_cmd)
+
+        return connection_cmd
+
+    def submit(self, application: str = "", **kwargs: Any) -> None:
+        """
+        Remote Popen to execute the spark-submit job
+
+        :param application: Submitted application, jar or py file
+        :type application: str
+        :param kwargs: extra arguments to Popen (see subprocess.Popen)
+        """
+        spark_submit_cmd = self._build_spark_submit_command(application)
+
+        if self._env:
+            env = os.environ.copy()
+            env.update(self._env)
+            kwargs["env"] = env
+
+        self._submit_sp = subprocess.Popen(
+            spark_submit_cmd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            bufsize=-1,
+            universal_newlines=True,
+            **kwargs,
+        )
+
+        self._process_spark_submit_log(iter(self._submit_sp.stdout))  # type: ignore
+        returncode = self._submit_sp.wait()
+
+        # Check spark-submit return code. In Kubernetes mode, also check the value
+        # of exit code in the log, as it may differ.
+        if returncode or (self._is_kubernetes and self._spark_exit_code != 0):
+            if self._is_kubernetes:
+                raise AirflowException(
+                    "Cannot execute: {}. Error code is: {}. Kubernetes spark exit code is: {}".format(
+                        self._mask_cmd(spark_submit_cmd),
+                        returncode,
+                        self._spark_exit_code,
+                    )
+                )
+            else:
+                raise AirflowException(
+                    "Cannot execute: {}. Error code is: {}.".format(
+                        self._mask_cmd(spark_submit_cmd), returncode
+                    )
+                )
+
+        self.log.debug("Should track driver: %s", self._should_track_driver_status)
+
+        # We want the Airflow job to wait until the Spark driver is finished
+        if self._should_track_driver_status:
+            if self._driver_id is None:
+                raise AirflowException(
+                    "No driver id is known: something went wrong when executing "
+                    + "the spark submit command"
+                )
+
+            # We start with the SUBMITTED status as initial status
+            self._driver_status = "SUBMITTED"
+
+            # Start tracking the driver status (blocking function)
+            self._start_driver_status_tracking()
+
+            if self._driver_status != "FINISHED":
+                raise AirflowException(
+                    "ERROR : Driver {} badly exited with status {}".format(
+                        self._driver_id, self._driver_status
+                    )
+                )
+
+    def _process_spark_submit_log(self, itr: Iterator[Any]) -> None:
+        """
+        Processes the log files and extracts useful information out of it.
+
+        If the deploy-mode is 'client', log the output of the submit command as those
+        are the output logs of the Spark worker directly.
+ + Remark: If the driver needs to be tracked for its status, the log-level of the + spark deploy needs to be at least INFO (log4j.logger.org.apache.spark.deploy=INFO) + + :param itr: An iterator which iterates over the input of the subprocess + """ + # Consume the iterator + for line in itr: + line = line.strip() + # If we run yarn cluster mode, we want to extract the application id from + # the logs so we can kill the application when we stop it unexpectedly + if self._is_yarn and self._connection["deploy_mode"] == "cluster": + match = re.search("(application[0-9_]+)", line) + if match: + self._yarn_application_id = match.groups()[0] + self.log.info( + "Identified spark driver id: %s", self._yarn_application_id + ) + + # If we run Kubernetes cluster mode, we want to extract the driver pod id + # from the logs so we can kill the application when we stop it unexpectedly + elif self._is_kubernetes: + match = re.search(r"\s*pod name: ((.+?)-([a-z0-9]+)-driver)", line) + if match: + self._kubernetes_driver_pod = match.groups()[0] + self.log.info( + "Identified spark driver pod: %s", self._kubernetes_driver_pod + ) + + # Store the Spark Exit code + match_exit_code = re.search(r"\s*[eE]xit code: (\d+)", line) + if match_exit_code: + self._spark_exit_code = int(match_exit_code.groups()[0]) + + # if we run in standalone cluster mode and we want to track the driver status + # we need to extract the driver id from the logs. This allows us to poll for + # the status using the driver id. Also, we can kill the driver when needed. + elif self._should_track_driver_status and not self._driver_id: + match_driver_id = re.search(r"(driver-[0-9\-]+)", line) + if match_driver_id: + self._driver_id = match_driver_id.groups()[0] + self.log.info("identified spark driver id: %s", self._driver_id) + + self.log.info(line) + + def _process_spark_status_log(self, itr: Iterator[Any]) -> None: + """ + Parses the logs of the spark driver status query process + + :param itr: An iterator which iterates over the input of the subprocess + """ + driver_found = False + # Consume the iterator + for line in itr: + line = line.strip() + + # Check if the log line is about the driver status and extract the status. + if "driverState" in line: + self._driver_status = ( + line.split(" : ")[1].replace(",", "").replace('"', "").strip() + ) + driver_found = True + + self.log.debug("spark driver status log: %s", line) + + if not driver_found: + self._driver_status = "UNKNOWN" + + def _start_driver_status_tracking(self) -> None: + """ + Polls the driver based on self._driver_id to get the status. + Finish successfully when the status is FINISHED. + Finish failed when the status is ERROR/UNKNOWN/KILLED/FAILED. + + Possible status: + + SUBMITTED + Submitted but not yet scheduled on a worker + RUNNING + Has been allocated to a worker to run + FINISHED + Previously ran and exited cleanly + RELAUNCHING + Exited non-zero or due to worker failure, but has not yet + started running again + UNKNOWN + The status of the driver is temporarily not known due to + master failure recovery + KILLED + A user manually killed this driver + FAILED + The driver exited non-zero and was not supervised + ERROR + Unable to run or restart due to an unrecoverable error + (e.g. missing jar file) + """ + # When your Spark Standalone cluster is not performing well + # due to misconfiguration or heavy loads. + # it is possible that the polling request will timeout. + # Therefore we use a simple retry mechanism. 
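+        # Illustration (not in the original source): with the counters below and the
+        # default status_poll_interval=1, up to ten consecutive failed polls -- roughly
+        # ten seconds of retries -- are tolerated before AirflowException is raised.
+        # Each successful poll feeds _process_spark_status_log a line shaped like
+        #     "driverState" : "RUNNING",
+        # from which only the driverState value is extracted.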
+ missed_job_status_reports = 0 + max_missed_job_status_reports = 10 + + # Keep polling as long as the driver is processing + while self._driver_status not in [ + "FINISHED", + "UNKNOWN", + "KILLED", + "FAILED", + "ERROR", + ]: + + # Sleep for n seconds as we do not want to spam the cluster + time.sleep(self._status_poll_interval) + + self.log.debug("polling status of spark driver with id %s", self._driver_id) + + poll_drive_status_cmd = self._build_track_driver_status_command() + status_process: Any = subprocess.Popen( + poll_drive_status_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=-1, + universal_newlines=True, + ) + + self._process_spark_status_log(iter(status_process.stdout)) + returncode = status_process.wait() + + if returncode: + if missed_job_status_reports < max_missed_job_status_reports: + missed_job_status_reports += 1 + else: + raise AirflowException( + "Failed to poll for the driver status {} times: returncode = {}".format( + max_missed_job_status_reports, returncode + ) + ) + + def _build_spark_driver_kill_command(self) -> List[str]: + """ + Construct the spark-submit command to kill a driver. + :return: full command to kill a driver + """ + # If the spark_home is passed then build the spark-submit executable path using + # the spark_home; otherwise assume that spark-submit is present in the path to + # the executing user + if self._connection["spark_home"]: + connection_cmd = [ + os.path.join( + self._connection["spark_home"], + "bin", + self._connection["spark_binary"], + ) + ] + else: + connection_cmd = [self._connection["spark_binary"]] + + # The url to the spark master + connection_cmd += ["--master", self._connection["master"]] + + # The actual kill command + if self._driver_id: + connection_cmd += ["--kill", self._driver_id] + + self.log.debug("Spark-Kill cmd: %s", connection_cmd) + + return connection_cmd + + def on_kill(self) -> None: + """Kill Spark submit command""" + self.log.debug("Kill Command is being called") + + if self._should_track_driver_status: + if self._driver_id: + self.log.info("Killing driver %s on cluster", self._driver_id) + + kill_cmd = self._build_spark_driver_kill_command() + driver_kill = subprocess.Popen( + kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + self.log.info( + "Spark driver %s killed with return code: %s", + self._driver_id, + driver_kill.wait(), + ) + + if self._submit_sp and self._submit_sp.poll() is None: + self.log.info("Sending kill signal to %s", self._connection["spark_binary"]) + self._submit_sp.kill() + + if self._yarn_application_id: + kill_cmd = f"yarn application -kill {self._yarn_application_id}".split() + env = None + if self._keytab is not None and self._principal is not None: + # we are ignoring renewal failures from renew_from_kt + # here as the failure could just be due to a non-renewable ticket, + # we still attempt to kill the yarn application + renew_from_kt(self._principal, self._keytab, exit_on_fail=False) + env = os.environ.copy() + env["KRB5CCNAME"] = airflow_conf.get("kerberos", "ccache") + + yarn_kill = subprocess.Popen( + kill_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + self.log.info("YARN app killed with return code: %s", yarn_kill.wait()) + + if self._kubernetes_driver_pod: + self.log.info( + "Killing pod %s on Kubernetes", self._kubernetes_driver_pod + ) + + # Currently only instantiate Kubernetes client for killing a spark pod. 
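+            # Added note (illustrative): the pod name comes from the "pod name: ..."
+            # log line parsed earlier, e.g. a driver pod such as "my-app-8a4bd2-driver",
+            # and it is deleted in the namespace resolved from the connection extra
+            # or the spark.kubernetes.namespace configuration.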
+ try: + import kubernetes + + client = kube_client.get_kube_client() + api_response = client.delete_namespaced_pod( + self._kubernetes_driver_pod, + self._connection["namespace"], + body=kubernetes.client.V1DeleteOptions(), + pretty=True, + ) + + self.log.info("Spark on K8s killed with response: %s", api_response) + + except kube_client.ApiException as e: + self.log.error("Exception when attempting to kill Spark on K8s:") + self.log.exception(e) diff --git a/reference/providers/apache/spark/operators/__init__.py b/reference/providers/apache/spark/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/spark/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/spark/operators/spark_jdbc.py b/reference/providers/apache/spark/operators/spark_jdbc.py new file mode 100644 index 0000000..6ddf5e8 --- /dev/null +++ b/reference/providers/apache/spark/operators/spark_jdbc.py @@ -0,0 +1,228 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Any, Dict, Optional + +from airflow.providers.apache.spark.hooks.spark_jdbc import SparkJDBCHook +from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator +from airflow.utils.decorators import apply_defaults + + +# pylint: disable=too-many-instance-attributes +class SparkJDBCOperator(SparkSubmitOperator): + """ + This operator extends the SparkSubmitOperator specifically for performing data + transfers to/from JDBC-based databases with Apache Spark. As with the + SparkSubmitOperator, it assumes that the "spark-submit" binary is available on the + PATH. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SparkJDBCOperator` + + :param spark_app_name: Name of the job (default airflow-spark-jdbc) + :type spark_app_name: str + :param spark_conn_id: Connection id as configured in Airflow administration + :type spark_conn_id: str + :param spark_conf: Any additional Spark configuration properties + :type spark_conf: dict + :param spark_py_files: Additional python files used (.zip, .egg, or .py) + :type spark_py_files: str + :param spark_files: Additional files to upload to the container running the job + :type spark_files: str + :param spark_jars: Additional jars to upload and add to the driver and + executor classpath + :type spark_jars: str + :param num_executors: number of executor to run. This should be set so as to manage + the number of connections made with the JDBC database + :type num_executors: int + :param executor_cores: Number of cores per executor + :type executor_cores: int + :param executor_memory: Memory per executor (e.g. 1000M, 2G) + :type executor_memory: str + :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) + :type driver_memory: str + :param verbose: Whether to pass the verbose flag to spark-submit for debugging + :type verbose: bool + :param keytab: Full path to the file that contains the keytab + :type keytab: str + :param principal: The name of the kerberos principal used for keytab + :type principal: str + :param cmd_type: Which way the data should flow. 2 possible values: + spark_to_jdbc: data written by spark from metastore to jdbc + jdbc_to_spark: data written by spark from jdbc to metastore + :type cmd_type: str + :param jdbc_table: The name of the JDBC table + :type jdbc_table: str + :param jdbc_conn_id: Connection id used for connection to JDBC database + :type jdbc_conn_id: str + :param jdbc_driver: Name of the JDBC driver to use for the JDBC connection. This + driver (usually a jar) should be passed in the 'jars' parameter + :type jdbc_driver: str + :param metastore_table: The name of the metastore table, + :type metastore_table: str + :param jdbc_truncate: (spark_to_jdbc only) Whether or not Spark should truncate or + drop and recreate the JDBC table. This only takes effect if + 'save_mode' is set to Overwrite. Also, if the schema is + different, Spark cannot truncate, and will drop and recreate + :type jdbc_truncate: bool + :param save_mode: The Spark save-mode to use (e.g. overwrite, append, etc.) + :type save_mode: str + :param save_format: (jdbc_to_spark-only) The Spark save-format to use (e.g. parquet) + :type save_format: str + :param batch_size: (spark_to_jdbc only) The size of the batch to insert per round + trip to the JDBC database. Defaults to 1000 + :type batch_size: int + :param fetch_size: (jdbc_to_spark only) The size of the batch to fetch per round trip + from the JDBC database. Default depends on the JDBC driver + :type fetch_size: int + :param num_partitions: The maximum number of partitions that can be used by Spark + simultaneously, both for spark_to_jdbc and jdbc_to_spark + operations. This will also cap the number of JDBC connections + that can be opened + :type num_partitions: int + :param partition_column: (jdbc_to_spark-only) A numeric column to be used to + partition the metastore table by. 
If specified, you must + also specify: + num_partitions, lower_bound, upper_bound + :type partition_column: str + :param lower_bound: (jdbc_to_spark-only) Lower bound of the range of the numeric + partition column to fetch. If specified, you must also specify: + num_partitions, partition_column, upper_bound + :type lower_bound: int + :param upper_bound: (jdbc_to_spark-only) Upper bound of the range of the numeric + partition column to fetch. If specified, you must also specify: + num_partitions, partition_column, lower_bound + :type upper_bound: int + :param create_table_column_types: (spark_to_jdbc-only) The database column data types + to use instead of the defaults, when creating the + table. Data type information should be specified in + the same format as CREATE TABLE columns syntax + (e.g: "name CHAR(64), comments VARCHAR(1024)"). + The specified types should be valid spark sql data + types. + """ + + # pylint: disable=too-many-arguments,too-many-locals + @apply_defaults + def __init__( + self, + *, + spark_app_name: str = "airflow-spark-jdbc", + spark_conn_id: str = "spark-default", + spark_conf: Optional[Dict[str, Any]] = None, + spark_py_files: Optional[str] = None, + spark_files: Optional[str] = None, + spark_jars: Optional[str] = None, + num_executors: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + verbose: bool = False, + principal: Optional[str] = None, + keytab: Optional[str] = None, + cmd_type: str = "spark_to_jdbc", + jdbc_table: Optional[str] = None, + jdbc_conn_id: str = "jdbc-default", + jdbc_driver: Optional[str] = None, + metastore_table: Optional[str] = None, + jdbc_truncate: bool = False, + save_mode: Optional[str] = None, + save_format: Optional[str] = None, + batch_size: Optional[int] = None, + fetch_size: Optional[int] = None, + num_partitions: Optional[int] = None, + partition_column: Optional[str] = None, + lower_bound: Optional[str] = None, + upper_bound: Optional[str] = None, + create_table_column_types: Optional[str] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self._spark_app_name = spark_app_name + self._spark_conn_id = spark_conn_id + self._spark_conf = spark_conf + self._spark_py_files = spark_py_files + self._spark_files = spark_files + self._spark_jars = spark_jars + self._num_executors = num_executors + self._executor_cores = executor_cores + self._executor_memory = executor_memory + self._driver_memory = driver_memory + self._verbose = verbose + self._keytab = keytab + self._principal = principal + self._cmd_type = cmd_type + self._jdbc_table = jdbc_table + self._jdbc_conn_id = jdbc_conn_id + self._jdbc_driver = jdbc_driver + self._metastore_table = metastore_table + self._jdbc_truncate = jdbc_truncate + self._save_mode = save_mode + self._save_format = save_format + self._batch_size = batch_size + self._fetch_size = fetch_size + self._num_partitions = num_partitions + self._partition_column = partition_column + self._lower_bound = lower_bound + self._upper_bound = upper_bound + self._create_table_column_types = create_table_column_types + self._hook: Optional[SparkJDBCHook] = None + + def execute(self, context: Dict[str, Any]) -> None: + """Call the SparkSubmitHook to run the provided spark job""" + if self._hook is None: + self._hook = self._get_hook() + self._hook.submit_jdbc_job() + + def on_kill(self) -> None: + if self._hook is None: + self._hook = self._get_hook() + self._hook.on_kill() + + def _get_hook(self) -> 
SparkJDBCHook: + return SparkJDBCHook( + spark_app_name=self._spark_app_name, + spark_conn_id=self._spark_conn_id, + spark_conf=self._spark_conf, + spark_py_files=self._spark_py_files, + spark_files=self._spark_files, + spark_jars=self._spark_jars, + num_executors=self._num_executors, + executor_cores=self._executor_cores, + executor_memory=self._executor_memory, + driver_memory=self._driver_memory, + verbose=self._verbose, + keytab=self._keytab, + principal=self._principal, + cmd_type=self._cmd_type, + jdbc_table=self._jdbc_table, + jdbc_conn_id=self._jdbc_conn_id, + jdbc_driver=self._jdbc_driver, + metastore_table=self._metastore_table, + jdbc_truncate=self._jdbc_truncate, + save_mode=self._save_mode, + save_format=self._save_format, + batch_size=self._batch_size, + fetch_size=self._fetch_size, + num_partitions=self._num_partitions, + partition_column=self._partition_column, + lower_bound=self._lower_bound, + upper_bound=self._upper_bound, + create_table_column_types=self._create_table_column_types, + ) diff --git a/reference/providers/apache/spark/operators/spark_sql.py b/reference/providers/apache/spark/operators/spark_sql.py new file mode 100644 index 0000000..deba50b --- /dev/null +++ b/reference/providers/apache/spark/operators/spark_sql.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Any, Dict, Optional + +from airflow.models import BaseOperator +from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook +from airflow.utils.decorators import apply_defaults + + +class SparkSqlOperator(BaseOperator): + """ + Execute Spark SQL query + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SparkSqlOperator` + + :param sql: The SQL query to execute. (templated) + :type sql: str + :param conf: arbitrary Spark configuration property + :type conf: str (format: PROP=VALUE) + :param conn_id: connection_id string + :type conn_id: str + :param total_executor_cores: (Standalone & Mesos only) Total cores for all + executors (Default: all the available cores on the worker) + :type total_executor_cores: int + :param executor_cores: (Standalone & YARN only) Number of cores per + executor (Default: 2) + :type executor_cores: int + :param executor_memory: Memory per executor (e.g. 
1000M, 2G) (Default: 1G) + :type executor_memory: str + :param keytab: Full path to the file that contains the keytab + :type keytab: str + :param master: spark://host:port, mesos://host:port, yarn, or local + :type master: str + :param name: Name of the job + :type name: str + :param num_executors: Number of executors to launch + :type num_executors: int + :param verbose: Whether to pass the verbose flag to spark-sql + :type verbose: bool + :param yarn_queue: The YARN queue to submit to (Default: "default") + :type yarn_queue: str + """ + + template_fields = ["_sql"] + template_ext = [".sql", ".hql"] + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + sql: str, + conf: Optional[str] = None, + conn_id: str = "spark_sql_default", + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + master: str = "yarn", + name: str = "default-name", + num_executors: Optional[int] = None, + verbose: bool = True, + yarn_queue: str = "default", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self._sql = sql + self._conf = conf + self._conn_id = conn_id + self._total_executor_cores = total_executor_cores + self._executor_cores = executor_cores + self._executor_memory = executor_memory + self._keytab = keytab + self._principal = principal + self._master = master + self._name = name + self._num_executors = num_executors + self._verbose = verbose + self._yarn_queue = yarn_queue + self._hook: Optional[SparkSqlHook] = None + + def execute(self, context: Dict[str, Any]) -> None: + """Call the SparkSqlHook to run the provided sql query""" + if self._hook is None: + self._hook = self._get_hook() + self._hook.run_query() + + def on_kill(self) -> None: + if self._hook is None: + self._hook = self._get_hook() + self._hook.kill() + + def _get_hook(self) -> SparkSqlHook: + """Get SparkSqlHook""" + return SparkSqlHook( + sql=self._sql, + conf=self._conf, + conn_id=self._conn_id, + total_executor_cores=self._total_executor_cores, + executor_cores=self._executor_cores, + executor_memory=self._executor_memory, + keytab=self._keytab, + principal=self._principal, + name=self._name, + num_executors=self._num_executors, + master=self._master, + verbose=self._verbose, + yarn_queue=self._yarn_queue, + ) diff --git a/reference/providers/apache/spark/operators/spark_submit.py b/reference/providers/apache/spark/operators/spark_submit.py new file mode 100644 index 0000000..10330a3 --- /dev/null +++ b/reference/providers/apache/spark/operators/spark_submit.py @@ -0,0 +1,217 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
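A minimal DAG sketch (illustrative only; the dag id, connection ids, jar path and table names are assumptions) showing how the SparkJDBCOperator and SparkSqlOperator above, plus the SparkSubmitOperator defined below, might be wired together:

from datetime import datetime

from airflow import DAG
from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator
from airflow.providers.apache.spark.operators.spark_sql import SparkSqlOperator
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

with DAG(
    dag_id="example_spark_provider",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Pull a JDBC table into the metastore (assumes a "postgres_reporting" connection
    # and the PostgreSQL driver jar being available to Spark).
    load_events = SparkJDBCOperator(
        task_id="load_events",
        cmd_type="jdbc_to_spark",
        jdbc_conn_id="postgres_reporting",
        jdbc_table="public.events",
        jdbc_driver="org.postgresql.Driver",
        spark_jars="/opt/jars/postgresql.jar",
        metastore_table="staging.events",
        save_mode="overwrite",
        save_format="parquet",
    )

    # Run a follow-up query through the spark-sql binary.
    count_events = SparkSqlOperator(
        task_id="count_events",
        sql="SELECT COUNT(*) FROM staging.events",
        master="yarn",
    )

    # Submit an arbitrary PySpark application via spark-submit.
    transform = SparkSubmitOperator(
        task_id="transform",
        application="/opt/airflow/dags/scripts/transform.py",
        conn_id="spark_default",
    )

    load_events >> count_events >> transform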
+# +from typing import Any, Dict, List, Optional + +from airflow.models import BaseOperator +from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook +from airflow.settings import WEB_COLORS +from airflow.utils.decorators import apply_defaults + + +# pylint: disable=too-many-instance-attributes +class SparkSubmitOperator(BaseOperator): + """ + This hook is a wrapper around the spark-submit binary to kick off a spark-submit job. + It requires that the "spark-submit" binary is in the PATH or the spark-home is set + in the extra on the connection. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SparkSubmitOperator` + + :param application: The application that submitted as a job, either jar or py file. (templated) + :type application: str + :param conf: Arbitrary Spark configuration properties (templated) + :type conf: dict + :param conn_id: The connection id as configured in Airflow administration. When an + invalid connection_id is supplied, it will default to yarn. + :type conn_id: str + :param files: Upload additional files to the executor running the job, separated by a + comma. Files will be placed in the working directory of each executor. + For example, serialized objects. (templated) + :type files: str + :param py_files: Additional python files used by the job, can be .zip, .egg or .py. (templated) + :type py_files: str + :param jars: Submit additional jars to upload and place them in executor classpath. (templated) + :type jars: str + :param driver_class_path: Additional, driver-specific, classpath settings. (templated) + :type driver_class_path: str + :param java_class: the main class of the Java application + :type java_class: str + :param packages: Comma-separated list of maven coordinates of jars to include on the + driver and executor classpaths. (templated) + :type packages: str + :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude + while resolving the dependencies provided in 'packages' (templated) + :type exclude_packages: str + :param repositories: Comma-separated list of additional remote repositories to search + for the maven coordinates given with 'packages' + :type repositories: str + :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors + (Default: all the available cores on the worker) + :type total_executor_cores: int + :param executor_cores: (Standalone & YARN only) Number of cores per executor (Default: 2) + :type executor_cores: int + :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G) + :type executor_memory: str + :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) (Default: 1G) + :type driver_memory: str + :param keytab: Full path to the file that contains the keytab (templated) + :type keytab: str + :param principal: The name of the kerberos principal used for keytab (templated) + :type principal: str + :param proxy_user: User to impersonate when submitting the application (templated) + :type proxy_user: str + :param name: Name of the job (default airflow-spark). 
(templated) + :type name: str + :param num_executors: Number of executors to launch + :type num_executors: int + :param status_poll_interval: Seconds to wait between polls of driver status in cluster + mode (Default: 1) + :type status_poll_interval: int + :param application_args: Arguments for the application being submitted (templated) + :type application_args: list + :param env_vars: Environment variables for spark-submit. It supports yarn and k8s mode too. (templated) + :type env_vars: dict + :param verbose: Whether to pass the verbose flag to spark-submit process for debugging + :type verbose: bool + :param spark_binary: The command to use for spark submit. + Some distros may use spark2-submit. + :type spark_binary: str + """ + + template_fields = ( + "_application", + "_conf", + "_files", + "_py_files", + "_jars", + "_driver_class_path", + "_packages", + "_exclude_packages", + "_keytab", + "_principal", + "_proxy_user", + "_name", + "_application_args", + "_env_vars", + ) + ui_color = WEB_COLORS["LIGHTORANGE"] + + # pylint: disable=too-many-arguments,too-many-locals + @apply_defaults + def __init__( + self, + *, + application: str = "", + conf: Optional[Dict[str, Any]] = None, + conn_id: str = "spark_default", + files: Optional[str] = None, + py_files: Optional[str] = None, + archives: Optional[str] = None, + driver_class_path: Optional[str] = None, + jars: Optional[str] = None, + java_class: Optional[str] = None, + packages: Optional[str] = None, + exclude_packages: Optional[str] = None, + repositories: Optional[str] = None, + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + proxy_user: Optional[str] = None, + name: str = "arrow-spark", + num_executors: Optional[int] = None, + status_poll_interval: int = 1, + application_args: Optional[List[Any]] = None, + env_vars: Optional[Dict[str, Any]] = None, + verbose: bool = False, + spark_binary: Optional[str] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self._application = application + self._conf = conf + self._files = files + self._py_files = py_files + self._archives = archives + self._driver_class_path = driver_class_path + self._jars = jars + self._java_class = java_class + self._packages = packages + self._exclude_packages = exclude_packages + self._repositories = repositories + self._total_executor_cores = total_executor_cores + self._executor_cores = executor_cores + self._executor_memory = executor_memory + self._driver_memory = driver_memory + self._keytab = keytab + self._principal = principal + self._proxy_user = proxy_user + self._name = name + self._num_executors = num_executors + self._status_poll_interval = status_poll_interval + self._application_args = application_args + self._env_vars = env_vars + self._verbose = verbose + self._spark_binary = spark_binary + self._hook: Optional[SparkSubmitHook] = None + self._conn_id = conn_id + + def execute(self, context: Dict[str, Any]) -> None: + """Call the SparkSubmitHook to run the provided spark job""" + if self._hook is None: + self._hook = self._get_hook() + self._hook.submit(self._application) + + def on_kill(self) -> None: + if self._hook is None: + self._hook = self._get_hook() + self._hook.on_kill() + + def _get_hook(self) -> SparkSubmitHook: + return SparkSubmitHook( + conf=self._conf, + conn_id=self._conn_id, + files=self._files, + py_files=self._py_files, + 
archives=self._archives, + driver_class_path=self._driver_class_path, + jars=self._jars, + java_class=self._java_class, + packages=self._packages, + exclude_packages=self._exclude_packages, + repositories=self._repositories, + total_executor_cores=self._total_executor_cores, + executor_cores=self._executor_cores, + executor_memory=self._executor_memory, + driver_memory=self._driver_memory, + keytab=self._keytab, + principal=self._principal, + proxy_user=self._proxy_user, + name=self._name, + num_executors=self._num_executors, + status_poll_interval=self._status_poll_interval, + application_args=self._application_args, + env_vars=self._env_vars, + verbose=self._verbose, + spark_binary=self._spark_binary, + ) diff --git a/reference/providers/apache/spark/provider.yaml b/reference/providers/apache/spark/provider.yaml new file mode 100644 index 0000000..3b212fb --- /dev/null +++ b/reference/providers/apache/spark/provider.yaml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-apache-spark +name: Apache Spark +description: | + `Apache Spark `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Apache Spark + external-doc-url: https://spark.apache.org/ + how-to-guide: + - /docs/apache-airflow-providers-apache-spark/operators.rst + logo: /integration-logos/apache/spark.png + tags: [apache] + +operators: + - integration-name: Apache Spark + python-modules: + - airflow.providers.apache.spark.operators.spark_jdbc + - airflow.providers.apache.spark.operators.spark_sql + - airflow.providers.apache.spark.operators.spark_submit + +hooks: + - integration-name: Apache Spark + python-modules: + - airflow.providers.apache.spark.hooks.spark_jdbc + - airflow.providers.apache.spark.hooks.spark_jdbc_script + - airflow.providers.apache.spark.hooks.spark_sql + - airflow.providers.apache.spark.hooks.spark_submit + +hook-class-names: + - airflow.providers.apache.spark.hooks.spark_jdbc.SparkJDBCHook + - airflow.providers.apache.spark.hooks.spark_sql.SparkSqlHook + - airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook diff --git a/reference/providers/apache/sqoop/CHANGELOG.rst b/reference/providers/apache/sqoop/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/apache/sqoop/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. 
http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/apache/sqoop/__init__.py b/reference/providers/apache/sqoop/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/sqoop/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/sqoop/hooks/__init__.py b/reference/providers/apache/sqoop/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/sqoop/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/apache/sqoop/hooks/sqoop.py b/reference/providers/apache/sqoop/hooks/sqoop.py new file mode 100644 index 0000000..0f3f764 --- /dev/null +++ b/reference/providers/apache/sqoop/hooks/sqoop.py @@ -0,0 +1,442 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# + +"""This module contains a sqoop 1.x hook""" +import subprocess +from copy import deepcopy +from typing import Any, Dict, List, Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class SqoopHook(BaseHook): + """ + This hook is a wrapper around the sqoop 1 binary. To be able to use the hook + it is required that "sqoop" is in the PATH. + + Additional arguments that can be passed via the 'extra' JSON field of the + sqoop connection: + + * ``job_tracker``: Job tracker local|jobtracker:port. + * ``namenode``: Namenode. + * ``lib_jars``: Comma separated jar files to include in the classpath. + * ``files``: Comma separated files to be copied to the map reduce cluster. + * ``archives``: Comma separated archives to be unarchived on the compute + machines. + * ``password_file``: Path to file containing the password. + + :param conn_id: Reference to the sqoop connection. + :type conn_id: str + :param verbose: Set sqoop to verbose. + :type verbose: bool + :param num_mappers: Number of map tasks to import in parallel. + :type num_mappers: int + :param properties: Properties to set via the -D argument + :type properties: dict + """ + + conn_name_attr = "conn_id" + default_conn_name = "sqoop_default" + conn_type = "sqoop" + hook_name = "Sqoop" + + def __init__( + self, + conn_id: str = default_conn_name, + verbose: bool = False, + num_mappers: Optional[int] = None, + hcatalog_database: Optional[str] = None, + hcatalog_table: Optional[str] = None, + properties: Optional[Dict[str, Any]] = None, + ) -> None: + # No mutable types in the default parameters + super().__init__() + self.conn = self.get_connection(conn_id) + connection_parameters = self.conn.extra_dejson + self.job_tracker = connection_parameters.get("job_tracker", None) + self.namenode = connection_parameters.get("namenode", None) + self.libjars = connection_parameters.get("libjars", None) + self.files = connection_parameters.get("files", None) + self.archives = connection_parameters.get("archives", None) + self.password_file = connection_parameters.get("password_file", None) + self.hcatalog_database = hcatalog_database + self.hcatalog_table = hcatalog_table + self.verbose = verbose + self.num_mappers = num_mappers + self.properties = properties or {} + self.log.info( + "Using connection to: %s:%s/%s", + self.conn.host, + self.conn.port, + self.conn.schema, + ) + self.sub_process: Any = None + + def get_conn(self) -> Any: + return self.conn + + def cmd_mask_password(self, cmd_orig: List[str]) -> List[str]: + """Mask command password for safety""" + cmd = deepcopy(cmd_orig) + try: + password_index = cmd.index("--password") + cmd[password_index + 1] = "MASKED" + except ValueError: + self.log.debug("No password in sqoop cmd") + return cmd + + def popen(self, cmd: List[str], **kwargs: Any) -> None: + """ + Remote Popen + + :param cmd: command to remotely execute + :param kwargs: extra arguments to Popen (see subprocess.Popen) + :return: handle to subprocess + """ + masked_cmd = " ".join(self.cmd_mask_password(cmd)) + self.log.info("Executing command: %s", masked_cmd) + self.sub_process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs + ) + + for line in iter(self.sub_process.stdout): # type: ignore + self.log.info(line.strip()) + + self.sub_process.wait() + + self.log.info("Command exited with return code %s", self.sub_process.returncode) + + if 
self.sub_process.returncode: + raise AirflowException(f"Sqoop command failed: {masked_cmd}") + + def _prepare_command(self, export: bool = False) -> List[str]: + sqoop_cmd_type = "export" if export else "import" + connection_cmd = ["sqoop", sqoop_cmd_type] + + for key, value in self.properties.items(): + connection_cmd += ["-D", f"{key}={value}"] + + if self.namenode: + connection_cmd += ["-fs", self.namenode] + if self.job_tracker: + connection_cmd += ["-jt", self.job_tracker] + if self.libjars: + connection_cmd += ["-libjars", self.libjars] + if self.files: + connection_cmd += ["-files", self.files] + if self.archives: + connection_cmd += ["-archives", self.archives] + if self.conn.login: + connection_cmd += ["--username", self.conn.login] + if self.conn.password: + connection_cmd += ["--password", self.conn.password] + if self.password_file: + connection_cmd += ["--password-file", self.password_file] + if self.verbose: + connection_cmd += ["--verbose"] + if self.num_mappers: + connection_cmd += ["--num-mappers", str(self.num_mappers)] + if self.hcatalog_database: + connection_cmd += ["--hcatalog-database", self.hcatalog_database] + if self.hcatalog_table: + connection_cmd += ["--hcatalog-table", self.hcatalog_table] + + connect_str = self.conn.host + if self.conn.port: + connect_str += f":{self.conn.port}" + if self.conn.schema: + connect_str += f"/{self.conn.schema}" + connection_cmd += ["--connect", connect_str] + + return connection_cmd + + @staticmethod + def _get_export_format_argument(file_type: str = "text") -> List[str]: + if file_type == "avro": + return ["--as-avrodatafile"] + elif file_type == "sequence": + return ["--as-sequencefile"] + elif file_type == "parquet": + return ["--as-parquetfile"] + elif file_type == "text": + return ["--as-textfile"] + else: + raise AirflowException( + "Argument file_type should be 'avro', 'sequence', 'parquet' or 'text'." + ) + + def _import_cmd( + self, + target_dir: Optional[str], + append: bool, + file_type: str, + split_by: Optional[str], + direct: Optional[bool], + driver: Any, + extra_import_options: Any, + ) -> List[str]: + + cmd = self._prepare_command(export=False) + + if target_dir: + cmd += ["--target-dir", target_dir] + + if append: + cmd += ["--append"] + + cmd += self._get_export_format_argument(file_type) + + if split_by: + cmd += ["--split-by", split_by] + + if direct: + cmd += ["--direct"] + + if driver: + cmd += ["--driver", driver] + + if extra_import_options: + for key, value in extra_import_options.items(): + cmd += [f"--{key}"] + if value: + cmd += [str(value)] + + return cmd + + # pylint: disable=too-many-arguments + def import_table( + self, + table: str, + target_dir: Optional[str] = None, + append: bool = False, + file_type: str = "text", + columns: Optional[str] = None, + split_by: Optional[str] = None, + where: Optional[str] = None, + direct: bool = False, + driver: Any = None, + extra_import_options: Optional[Dict[str, Any]] = None, + ) -> Any: + """ + Imports table from remote location to target dir. Arguments are + copies of direct sqoop command line arguments + + :param table: Table to read + :param target_dir: HDFS destination dir + :param append: Append data to an existing dataset in HDFS + :param file_type: "avro", "sequence", "text" or "parquet". + Imports data to into the specified format. Defaults to text. 
+ :param columns: Columns to import from table + :param split_by: Column of the table used to split work units + :param where: WHERE clause to use during import + :param direct: Use direct connector if exists for the database + :param driver: Manually specify JDBC driver class to use + :param extra_import_options: Extra import options to pass as dict. + If a key doesn't have a value, just pass an empty string to it. + Don't include prefix of -- for sqoop options. + """ + cmd = self._import_cmd( + target_dir, + append, + file_type, + split_by, + direct, + driver, + extra_import_options, + ) + + cmd += ["--table", table] + + if columns: + cmd += ["--columns", columns] + if where: + cmd += ["--where", where] + + self.popen(cmd) + + def import_query( + self, + query: str, + target_dir: Optional[str] = None, + append: bool = False, + file_type: str = "text", + split_by: Optional[str] = None, + direct: Optional[bool] = None, + driver: Optional[Any] = None, + extra_import_options: Optional[Dict[str, Any]] = None, + ) -> Any: + """ + Imports a specific query from the rdbms to hdfs + + :param query: Free format query to run + :param target_dir: HDFS destination dir + :param append: Append data to an existing dataset in HDFS + :param file_type: "avro", "sequence", "text" or "parquet" + Imports data to hdfs into the specified format. Defaults to text. + :param split_by: Column of the table used to split work units + :param direct: Use direct import fast path + :param driver: Manually specify JDBC driver class to use + :param extra_import_options: Extra import options to pass as dict. + If a key doesn't have a value, just pass an empty string to it. + Don't include prefix of -- for sqoop options. + """ + cmd = self._import_cmd( + target_dir, + append, + file_type, + split_by, + direct, + driver, + extra_import_options, + ) + cmd += ["--query", query] + + self.popen(cmd) + + # pylint: disable=too-many-arguments + def _export_cmd( + self, + table: str, + export_dir: Optional[str] = None, + input_null_string: Optional[str] = None, + input_null_non_string: Optional[str] = None, + staging_table: Optional[str] = None, + clear_staging_table: bool = False, + enclosed_by: Optional[str] = None, + escaped_by: Optional[str] = None, + input_fields_terminated_by: Optional[str] = None, + input_lines_terminated_by: Optional[str] = None, + input_optionally_enclosed_by: Optional[str] = None, + batch: bool = False, + relaxed_isolation: bool = False, + extra_export_options: Optional[Dict[str, Any]] = None, + ) -> List[str]: + + cmd = self._prepare_command(export=True) + + if input_null_string: + cmd += ["--input-null-string", input_null_string] + + if input_null_non_string: + cmd += ["--input-null-non-string", input_null_non_string] + + if staging_table: + cmd += ["--staging-table", staging_table] + + if clear_staging_table: + cmd += ["--clear-staging-table"] + + if enclosed_by: + cmd += ["--enclosed-by", enclosed_by] + + if escaped_by: + cmd += ["--escaped-by", escaped_by] + + if input_fields_terminated_by: + cmd += ["--input-fields-terminated-by", input_fields_terminated_by] + + if input_lines_terminated_by: + cmd += ["--input-lines-terminated-by", input_lines_terminated_by] + + if input_optionally_enclosed_by: + cmd += ["--input-optionally-enclosed-by", input_optionally_enclosed_by] + + if batch: + cmd += ["--batch"] + + if relaxed_isolation: + cmd += ["--relaxed-isolation"] + + if export_dir: + cmd += ["--export-dir", export_dir] + + if extra_export_options: + for key, value in extra_export_options.items(): + cmd += 
[f"--{key}"] + if value: + cmd += [str(value)] + + # The required option + cmd += ["--table", table] + + return cmd + + # pylint: disable=too-many-arguments + def export_table( + self, + table: str, + export_dir: Optional[str] = None, + input_null_string: Optional[str] = None, + input_null_non_string: Optional[str] = None, + staging_table: Optional[str] = None, + clear_staging_table: bool = False, + enclosed_by: Optional[str] = None, + escaped_by: Optional[str] = None, + input_fields_terminated_by: Optional[str] = None, + input_lines_terminated_by: Optional[str] = None, + input_optionally_enclosed_by: Optional[str] = None, + batch: bool = False, + relaxed_isolation: bool = False, + extra_export_options: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Exports Hive table to remote location. Arguments are copies of direct + sqoop command line Arguments + + :param table: Table remote destination + :param export_dir: Hive table to export + :param input_null_string: The string to be interpreted as null for + string columns + :param input_null_non_string: The string to be interpreted as null + for non-string columns + :param staging_table: The table in which data will be staged before + being inserted into the destination table + :param clear_staging_table: Indicate that any data present in the + staging table can be deleted + :param enclosed_by: Sets a required field enclosing character + :param escaped_by: Sets the escape character + :param input_fields_terminated_by: Sets the field separator character + :param input_lines_terminated_by: Sets the end-of-line character + :param input_optionally_enclosed_by: Sets a field enclosing character + :param batch: Use batch mode for underlying statement execution + :param relaxed_isolation: Transaction isolation to read uncommitted + for the mappers + :param extra_export_options: Extra export options to pass as dict. + If a key doesn't have a value, just pass an empty string to it. + Don't include prefix of -- for sqoop options. + """ + cmd = self._export_cmd( + table, + export_dir, + input_null_string, + input_null_non_string, + staging_table, + clear_staging_table, + enclosed_by, + escaped_by, + input_fields_terminated_by, + input_lines_terminated_by, + input_optionally_enclosed_by, + batch, + relaxed_isolation, + extra_export_options, + ) + + self.popen(cmd) diff --git a/reference/providers/apache/sqoop/operators/__init__.py b/reference/providers/apache/sqoop/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/apache/sqoop/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/reference/providers/apache/sqoop/operators/sqoop.py b/reference/providers/apache/sqoop/operators/sqoop.py new file mode 100644 index 0000000..06a55b6 --- /dev/null +++ b/reference/providers/apache/sqoop/operators/sqoop.py @@ -0,0 +1,265 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains a sqoop 1 operator""" +import os +import signal +from typing import Any, Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.apache.sqoop.hooks.sqoop import SqoopHook +from airflow.utils.decorators import apply_defaults + + +# pylint: disable=too-many-instance-attributes +class SqoopOperator(BaseOperator): + """ + Execute a Sqoop job. + Documentation for Apache Sqoop can be found here: + https://sqoop.apache.org/docs/1.4.2/SqoopUserGuide.html + + :param conn_id: The Airflow connection id for the sqoop connection + :param cmd_type: The sqoop command to execute, either "export" or "import" + :param table: Table to read + :param query: Import result of arbitrary SQL query. Instead of using the table, + columns and where arguments, you can specify a SQL statement with the query + argument. Must also specify a destination directory with target_dir. + :param target_dir: HDFS destination directory where the data + from the rdbms will be written + :param append: Append data to an existing dataset in HDFS + :param file_type: "avro", "sequence", "parquet" or "text". Imports data + into the specified format. Defaults to text. 
+ :param columns: Columns to import from table + :param num_mappers: Use n mapper tasks to import/export in parallel + :param split_by: Column of the table used to split work units + :param where: WHERE clause to use during import + :param export_dir: HDFS Hive database directory to export to the rdbms + :param input_null_string: The string to be interpreted as null + for string columns + :param input_null_non_string: The string to be interpreted as null + for non-string columns + :param staging_table: The table in which data will be staged before + being inserted into the destination table + :param clear_staging_table: Indicate that any data present in the + staging table can be deleted + :param enclosed_by: Sets a required field enclosing character + :param escaped_by: Sets the escape character + :param input_fields_terminated_by: Sets the input field separator + :param input_lines_terminated_by: Sets the input end-of-line character + :param input_optionally_enclosed_by: Sets a field enclosing character + :param batch: Use batch mode for underlying statement execution + :param direct: Use direct export fast path + :param driver: Manually specify JDBC driver class to use + :param verbose: Switch to more verbose logging for debug purposes + :param relaxed_isolation: use read uncommitted isolation level + :param hcatalog_database: Specifies the database name for the HCatalog table + :param hcatalog_table: The argument value for this option is the HCatalog table + :param create_hcatalog_table: Have sqoop create the hcatalog table passed + in or not + :param properties: additional JVM properties passed to sqoop + :param extra_import_options: Extra import options to pass as dict. + If a key doesn't have a value, just pass an empty string to it. + Don't include prefix of -- for sqoop options. + :param extra_export_options: Extra export options to pass as dict. + If a key doesn't have a value, just pass an empty string to it. + Don't include prefix of -- for sqoop options. 
+ """ + + template_fields = ( + "conn_id", + "cmd_type", + "table", + "query", + "target_dir", + "file_type", + "columns", + "split_by", + "where", + "export_dir", + "input_null_string", + "input_null_non_string", + "staging_table", + "enclosed_by", + "escaped_by", + "input_fields_terminated_by", + "input_lines_terminated_by", + "input_optionally_enclosed_by", + "properties", + "extra_import_options", + "driver", + "extra_export_options", + "hcatalog_database", + "hcatalog_table", + ) + ui_color = "#7D8CA4" + + # pylint: disable=too-many-arguments,too-many-locals + @apply_defaults + def __init__( + self, + *, + conn_id: str = "sqoop_default", + cmd_type: str = "import", + table: Optional[str] = None, + query: Optional[str] = None, + target_dir: Optional[str] = None, + append: bool = False, + file_type: str = "text", + columns: Optional[str] = None, + num_mappers: Optional[int] = None, + split_by: Optional[str] = None, + where: Optional[str] = None, + export_dir: Optional[str] = None, + input_null_string: Optional[str] = None, + input_null_non_string: Optional[str] = None, + staging_table: Optional[str] = None, + clear_staging_table: bool = False, + enclosed_by: Optional[str] = None, + escaped_by: Optional[str] = None, + input_fields_terminated_by: Optional[str] = None, + input_lines_terminated_by: Optional[str] = None, + input_optionally_enclosed_by: Optional[str] = None, + batch: bool = False, + direct: bool = False, + driver: Optional[Any] = None, + verbose: bool = False, + relaxed_isolation: bool = False, + properties: Optional[Dict[str, Any]] = None, + hcatalog_database: Optional[str] = None, + hcatalog_table: Optional[str] = None, + create_hcatalog_table: bool = False, + extra_import_options: Optional[Dict[str, Any]] = None, + extra_export_options: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.conn_id = conn_id + self.cmd_type = cmd_type + self.table = table + self.query = query + self.target_dir = target_dir + self.append = append + self.file_type = file_type + self.columns = columns + self.num_mappers = num_mappers + self.split_by = split_by + self.where = where + self.export_dir = export_dir + self.input_null_string = input_null_string + self.input_null_non_string = input_null_non_string + self.staging_table = staging_table + self.clear_staging_table = clear_staging_table + self.enclosed_by = enclosed_by + self.escaped_by = escaped_by + self.input_fields_terminated_by = input_fields_terminated_by + self.input_lines_terminated_by = input_lines_terminated_by + self.input_optionally_enclosed_by = input_optionally_enclosed_by + self.batch = batch + self.direct = direct + self.driver = driver + self.verbose = verbose + self.relaxed_isolation = relaxed_isolation + self.hcatalog_database = hcatalog_database + self.hcatalog_table = hcatalog_table + self.create_hcatalog_table = create_hcatalog_table + self.properties = properties + self.extra_import_options = extra_import_options or {} + self.extra_export_options = extra_export_options or {} + self.hook: Optional[SqoopHook] = None + + def execute(self, context: Dict[str, Any]) -> None: + """Execute sqoop job""" + if self.hook is None: + self.hook = self._get_hook() + + if self.cmd_type == "export": + self.hook.export_table( + table=self.table, # type: ignore + export_dir=self.export_dir, + input_null_string=self.input_null_string, + input_null_non_string=self.input_null_non_string, + staging_table=self.staging_table, + clear_staging_table=self.clear_staging_table, + 
enclosed_by=self.enclosed_by, + escaped_by=self.escaped_by, + input_fields_terminated_by=self.input_fields_terminated_by, + input_lines_terminated_by=self.input_lines_terminated_by, + input_optionally_enclosed_by=self.input_optionally_enclosed_by, + batch=self.batch, + relaxed_isolation=self.relaxed_isolation, + extra_export_options=self.extra_export_options, + ) + elif self.cmd_type == "import": + # add create hcatalog table to extra import options if option passed + # if new params are added to constructor can pass them in here + # so don't modify sqoop_hook for each param + if self.create_hcatalog_table: + self.extra_import_options["create-hcatalog-table"] = "" + + if self.table and self.query: + raise AirflowException( + "Cannot specify query and table together. Need to specify either or." + ) + + if self.table: + self.hook.import_table( + table=self.table, + target_dir=self.target_dir, + append=self.append, + file_type=self.file_type, + columns=self.columns, + split_by=self.split_by, + where=self.where, + direct=self.direct, + driver=self.driver, + extra_import_options=self.extra_import_options, + ) + elif self.query: + self.hook.import_query( + query=self.query, + target_dir=self.target_dir, + append=self.append, + file_type=self.file_type, + split_by=self.split_by, + direct=self.direct, + driver=self.driver, + extra_import_options=self.extra_import_options, + ) + else: + raise AirflowException( + "Provide query or table parameter to import using Sqoop" + ) + else: + raise AirflowException("cmd_type should be 'import' or 'export'") + + def on_kill(self) -> None: + if self.hook is None: + self.hook = self._get_hook() + self.log.info("Sending SIGTERM signal to bash process group") + os.killpg(os.getpgid(self.hook.sub_process.pid), signal.SIGTERM) + + def _get_hook(self) -> SqoopHook: + return SqoopHook( + conn_id=self.conn_id, + verbose=self.verbose, + num_mappers=self.num_mappers, + hcatalog_database=self.hcatalog_database, + hcatalog_table=self.hcatalog_table, + properties=self.properties, + ) diff --git a/reference/providers/apache/sqoop/provider.yaml b/reference/providers/apache/sqoop/provider.yaml new file mode 100644 index 0000000..f4975eb --- /dev/null +++ b/reference/providers/apache/sqoop/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
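A sketch of the SqoopOperator above inside a DAG; the DAG id, table, and target directory are assumptions. Note that extra_import_options takes valueless flags as empty strings, without the leading "--".

from datetime import datetime

from airflow import DAG
from airflow.providers.apache.sqoop.operators.sqoop import SqoopOperator

with DAG(
    dag_id="example_sqoop_import",   # hypothetical DAG id
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    import_orders = SqoopOperator(
        task_id="import_orders",
        cmd_type="import",
        table="orders",               # hypothetical table
        target_dir="/data/orders",    # hypothetical HDFS path
        file_type="avro",
        num_mappers=2,
        extra_import_options={"compress": ""},  # emits a bare --compress flag
    )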
+ +--- +package-name: apache-airflow-providers-apache-sqoop +name: Apache Sqoop +description: | + `Apache Sqoop `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Apache Sqoop + external-doc-url: https://sqoop.apache.org/ + logo: /integration-logos/apache/sqoop.png + tags: [apache] + +operators: + - integration-name: Apache Sqoop + python-modules: + - airflow.providers.apache.sqoop.operators.sqoop + +hooks: + - integration-name: Apache Sqoop + python-modules: + - airflow.providers.apache.sqoop.hooks.sqoop + +hook-class-names: + - airflow.providers.apache.sqoop.hooks.sqoop.SqoopHook diff --git a/reference/providers/celery/CHANGELOG.rst b/reference/providers/celery/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/celery/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/celery/__init__.py b/reference/providers/celery/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/celery/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/celery/provider.yaml b/reference/providers/celery/provider.yaml new file mode 100644 index 0000000..0ac6859 --- /dev/null +++ b/reference/providers/celery/provider.yaml @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-celery +name: Celery +description: | + `Celery `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Celery + external-doc-url: http://www.celeryproject.org/ + logo: /integration-logos/celery/Celery.png + tags: [software] + +sensors: + - integration-name: Celery + python-modules: + - airflow.providers.celery.sensors.celery_queue diff --git a/reference/providers/celery/sensors/__init__.py b/reference/providers/celery/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/celery/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/celery/sensors/celery_queue.py b/reference/providers/celery/sensors/celery_queue.py new file mode 100644 index 0000000..aa6953d --- /dev/null +++ b/reference/providers/celery/sensors/celery_queue.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Optional + +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from celery.app import control + + +class CeleryQueueSensor(BaseSensorOperator): + """ + Waits for a Celery queue to be empty. By default, in order to be considered + empty, the queue must not have any tasks in the ``reserved``, ``scheduled`` + or ``active`` states. + + :param celery_queue: The name of the Celery queue to wait for. 
+ :type celery_queue: str + :param target_task_id: Task id for checking + :type target_task_id: str + """ + + @apply_defaults + def __init__( + self, *, celery_queue: str, target_task_id: Optional[str] = None, **kwargs + ) -> None: + + super().__init__(**kwargs) + self.celery_queue = celery_queue + self.target_task_id = target_task_id + + def _check_task_id(self, context: Dict[str, Any]) -> bool: + """ + Gets the returned Celery result from the Airflow task + ID provided to the sensor, and returns True if the + celery result has been finished execution. + + :param context: Airflow's execution context + :type context: dict + :return: True if task has been executed, otherwise False + :rtype: bool + """ + ti = context["ti"] + celery_result = ti.xcom_pull(task_ids=self.target_task_id) + return celery_result.ready() + + def poke(self, context: Dict[str, Any]) -> bool: + + if self.target_task_id: + return self._check_task_id(context) + + inspect_result = control.Inspect() + reserved = inspect_result.reserved() + scheduled = inspect_result.scheduled() + active = inspect_result.active() + + try: + reserved = len(reserved[self.celery_queue]) + scheduled = len(scheduled[self.celery_queue]) + active = len(active[self.celery_queue]) + + self.log.info("Checking if celery queue %s is empty.", self.celery_queue) + + return reserved == 0 and scheduled == 0 and active == 0 + except KeyError: + raise KeyError(f"Could not locate Celery queue {self.celery_queue}") diff --git a/reference/providers/cloudant/CHANGELOG.rst b/reference/providers/cloudant/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/cloudant/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/cloudant/__init__.py b/reference/providers/cloudant/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/cloudant/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/cloudant/hooks/__init__.py b/reference/providers/cloudant/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/cloudant/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/cloudant/hooks/cloudant.py b/reference/providers/cloudant/hooks/cloudant.py new file mode 100644 index 0000000..c1564f6 --- /dev/null +++ b/reference/providers/cloudant/hooks/cloudant.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Cloudant""" +from typing import Dict + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from cloudant import cloudant + + +class CloudantHook(BaseHook): + """ + Interact with Cloudant. This class is a thin wrapper around the cloudant python library. + + .. seealso:: the latest documentation `here `_. + + :param cloudant_conn_id: The connection id to authenticate and get a session object from cloudant. + :type cloudant_conn_id: str + """ + + conn_name_attr = "cloudant_conn_id" + default_conn_name = "cloudant_default" + conn_type = "cloudant" + hook_name = "Cloudant" + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["port", "extra"], + "relabeling": { + "host": "Account", + "login": "Username (or API Key)", + "schema": "Database", + }, + } + + def __init__(self, cloudant_conn_id: str = default_conn_name) -> None: + super().__init__() + self.cloudant_conn_id = cloudant_conn_id + + def get_conn(self) -> cloudant: + """ + Opens a connection to the cloudant service and closes it automatically if used as context manager. + + .. 
note:: + In the connection form: + - 'host' equals the 'Account' (optional) + - 'login' equals the 'Username (or API Key)' (required) + - 'password' equals the 'Password' (required) + + :return: an authorized cloudant session context manager object. + :rtype: cloudant + """ + conn = self.get_connection(self.cloudant_conn_id) + + self._validate_connection(conn) + + cloudant_session = cloudant( + user=conn.login, passwd=conn.password, account=conn.host + ) + + return cloudant_session + + def _validate_connection(self, conn: cloudant) -> None: + for conn_param in ["login", "password"]: + if not getattr(conn, conn_param): + raise AirflowException(f"missing connection parameter {conn_param}") diff --git a/reference/providers/cloudant/provider.yaml b/reference/providers/cloudant/provider.yaml new file mode 100644 index 0000000..25f07c9 --- /dev/null +++ b/reference/providers/cloudant/provider.yaml @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-cloudant +name: IBM Cloudant +description: | + `IBM Cloudant `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: IBM Cloudant + external-doc-url: https://www.ibm.com/cloud/cloudant + logo: /integration-logos/cloudant/Cloudant.png + tags: [service] + +hooks: + - integration-name: IBM Cloudant + python-modules: + - airflow.providers.cloudant.hooks.cloudant + +hook-class-names: + - airflow.providers.cloudant.hooks.cloudant.CloudantHook diff --git a/reference/providers/cncf/__init__.py b/reference/providers/cncf/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/cncf/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/cncf/kubernetes/CHANGELOG.rst b/reference/providers/cncf/kubernetes/CHANGELOG.rst new file mode 100644 index 0000000..378a0c5 --- /dev/null +++ b/reference/providers/cncf/kubernetes/CHANGELOG.rst @@ -0,0 +1,44 @@ + .. 
Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Allow pod name override in KubernetesPodOperator if pod_template is used. (#14186)`` +* ``Allow users of the KPO to *actually* template environment variables (#14083)`` + +1.0.1 +..... + +Updated documentation and readme files. + +Bug fixes +~~~~~~~~~ + +* ``Pass image_pull_policy in KubernetesPodOperator correctly (#13289)`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/cncf/kubernetes/__init__.py b/reference/providers/cncf/kubernetes/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/cncf/kubernetes/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/cncf/kubernetes/backcompat/__init__.py b/reference/providers/cncf/kubernetes/backcompat/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/cncf/kubernetes/backcompat/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
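A sketch of reading from Cloudant through the CloudantHook defined above; the connection id and database name are assumptions, and the session methods come from the python-cloudant client that get_conn() returns.

from airflow.providers.cloudant.hooks.cloudant import CloudantHook

hook = CloudantHook(cloudant_conn_id="cloudant_default")

# get_conn() yields a python-cloudant session usable as a context manager.
with hook.get_conn() as client:
    print(client.all_dbs())            # names of all databases in the account
    database = client["my_database"]   # hypothetical database name
    print(database.exists())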
diff --git a/reference/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py b/reference/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py new file mode 100644 index 0000000..d8f8331 --- /dev/null +++ b/reference/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Executes task in a Kubernetes POD""" + +from typing import List + +from airflow.exceptions import AirflowException +from airflow.providers.cncf.kubernetes.backcompat.pod import Port, Resources +from airflow.providers.cncf.kubernetes.backcompat.pod_runtime_info_env import ( + PodRuntimeInfoEnv, +) +from airflow.providers.cncf.kubernetes.backcompat.volume import Volume +from airflow.providers.cncf.kubernetes.backcompat.volume_mount import VolumeMount +from kubernetes.client import ApiClient +from kubernetes.client import models as k8s + + +def _convert_kube_model_object(obj, old_class, new_class): + convert_op = getattr(obj, "to_k8s_client_obj", None) + if callable(convert_op): + return obj.to_k8s_client_obj() + elif isinstance(obj, new_class): + return obj + else: + raise AirflowException(f"Expected {old_class} or {new_class}, got {type(obj)}") + + +def _convert_from_dict(obj, new_class): + if isinstance(obj, new_class): + return obj + elif isinstance(obj, dict): + api_client = ApiClient() + return api_client._ApiClient__deserialize_model( + obj, new_class + ) # pylint: disable=W0212 + else: + raise AirflowException(f"Expected dict or {new_class}, got {type(obj)}") + + +def convert_volume(volume) -> k8s.V1Volume: + """ + Converts an airflow Volume object into a k8s.V1Volume + + :param volume: + :return: k8s.V1Volume + """ + return _convert_kube_model_object(volume, Volume, k8s.V1Volume) + + +def convert_volume_mount(volume_mount) -> k8s.V1VolumeMount: + """ + Converts an airflow VolumeMount object into a k8s.V1VolumeMount + + :param volume_mount: + :return: k8s.V1VolumeMount + """ + return _convert_kube_model_object(volume_mount, VolumeMount, k8s.V1VolumeMount) + + +def convert_resources(resources) -> k8s.V1ResourceRequirements: + """ + Converts an airflow Resources object into a k8s.V1ResourceRequirements + + :param resources: + :return: k8s.V1ResourceRequirements + """ + if isinstance(resources, dict): + resources = Resources(**resources) + return _convert_kube_model_object(resources, Resources, k8s.V1ResourceRequirements) + + +def convert_port(port) -> k8s.V1ContainerPort: + """ + Converts an airflow Port object into a k8s.V1ContainerPort + + :param port: + :return: k8s.V1ContainerPort + """ + return _convert_kube_model_object(port, Port, k8s.V1ContainerPort) + + +def convert_env_vars(env_vars) -> List[k8s.V1EnvVar]: + """ + Converts a dictionary into a list of env_vars + + 
:param env_vars: + :return: + """ + if isinstance(env_vars, dict): + res = [] + for k, v in env_vars.items(): + res.append(k8s.V1EnvVar(name=k, value=v)) + return res + elif isinstance(env_vars, list): + return env_vars + else: + raise AirflowException(f"Expected dict or list, got {type(env_vars)}") + + +def convert_pod_runtime_info_env(pod_runtime_info_envs) -> k8s.V1EnvVar: + """ + Converts a PodRuntimeInfoEnv into a k8s.V1EnvVar + + :param pod_runtime_info_envs: + :return: + """ + return _convert_kube_model_object( + pod_runtime_info_envs, PodRuntimeInfoEnv, k8s.V1EnvVar + ) + + +def convert_image_pull_secrets(image_pull_secrets) -> List[k8s.V1LocalObjectReference]: + """ + Converts a comma-separated string of secret names into a list of k8s.V1LocalObjectReference + + :param image_pull_secrets: + :return: + """ + if isinstance(image_pull_secrets, str): + secrets = image_pull_secrets.split(",") + return [k8s.V1LocalObjectReference(name=secret) for secret in secrets] + else: + return image_pull_secrets + + +def convert_configmap(configmaps) -> k8s.V1EnvFromSource: + """ + Converts a str into a k8s.V1EnvFromSource + + :param configmaps: + :return: + """ + return k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmaps)) + + +def convert_affinity(affinity) -> k8s.V1Affinity: + """Converts a dict into a k8s.V1Affinity""" + return _convert_from_dict(affinity, k8s.V1Affinity) + + +def convert_toleration(toleration) -> k8s.V1Toleration: + """Converts a dict into a k8s.V1Toleration""" + return _convert_from_dict(toleration, k8s.V1Toleration) diff --git a/reference/providers/cncf/kubernetes/backcompat/pod.py b/reference/providers/cncf/kubernetes/backcompat/pod.py new file mode 100644 index 0000000..ceb8276 --- /dev/null +++ b/reference/providers/cncf/kubernetes/backcompat/pod.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
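A quick sketch of what the converters in backwards_compat_converters.py above produce for the legacy call styles that KubernetesPodOperator still accepts (dict env vars, comma-separated pull secrets, plain configmap names). Import paths assume the upstream provider package; the literal values are made up:

from kubernetes.client import models as k8s
from airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters import (
    convert_configmap,
    convert_env_vars,
    convert_image_pull_secrets,
)

# dict -> list of V1EnvVar
assert convert_env_vars({"SQL_CONN": "sqlite://"}) == [k8s.V1EnvVar(name="SQL_CONN", value="sqlite://")]

# "a,b" -> list of V1LocalObjectReference
assert convert_image_pull_secrets("regcred,quaycred") == [
    k8s.V1LocalObjectReference(name="regcred"),
    k8s.V1LocalObjectReference(name="quaycred"),
]

# configmap name -> V1EnvFromSource wrapping a V1ConfigMapEnvSource
assert convert_configmap("test-configmap-1") == k8s.V1EnvFromSource(
    config_map_ref=k8s.V1ConfigMapEnvSource(name="test-configmap-1")
)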
+"""Classes for interacting with Kubernetes API""" + +from kubernetes.client import models as k8s + + +class Resources: + """backwards compat for Resources""" + + __slots__ = ( + "request_memory", + "request_cpu", + "limit_memory", + "limit_cpu", + "limit_gpu", + "request_ephemeral_storage", + "limit_ephemeral_storage", + ) + + """ + :param request_memory: requested memory + :type request_memory: str + :param request_cpu: requested CPU number + :type request_cpu: float | str + :param request_ephemeral_storage: requested ephemeral storage + :type request_ephemeral_storage: str + :param limit_memory: limit for memory usage + :type limit_memory: str + :param limit_cpu: Limit for CPU used + :type limit_cpu: float | str + :param limit_gpu: Limits for GPU used + :type limit_gpu: int + :param limit_ephemeral_storage: Limit for ephemeral storage + :type limit_ephemeral_storage: float | str + """ + + def __init__( + self, + request_memory=None, + request_cpu=None, + request_ephemeral_storage=None, + limit_memory=None, + limit_cpu=None, + limit_gpu=None, + limit_ephemeral_storage=None, + ): + self.request_memory = request_memory + self.request_cpu = request_cpu + self.request_ephemeral_storage = request_ephemeral_storage + self.limit_memory = limit_memory + self.limit_cpu = limit_cpu + self.limit_gpu = limit_gpu + self.limit_ephemeral_storage = limit_ephemeral_storage + + def to_k8s_client_obj(self): + """ + Converts to k8s object. + + @rtype: object + """ + limits_raw = { + "cpu": self.limit_cpu, + "memory": self.limit_memory, + "nvidia.com/gpu": self.limit_gpu, + "ephemeral-storage": self.limit_ephemeral_storage, + } + requests_raw = { + "cpu": self.request_cpu, + "memory": self.request_memory, + "ephemeral-storage": self.request_ephemeral_storage, + } + + limits = {k: v for k, v in limits_raw.items() if v} + requests = {k: v for k, v in requests_raw.items() if v} + resource_req = k8s.V1ResourceRequirements(limits=limits, requests=requests) + return resource_req + + +class Port: + """POD port""" + + __slots__ = ("name", "container_port") + + def __init__(self, name=None, container_port=None): + """Creates port""" + self.name = name + self.container_port = container_port + + def to_k8s_client_obj(self): + """ + Converts to k8s object. + + :rtype: object + """ + return k8s.V1ContainerPort(name=self.name, container_port=self.container_port) diff --git a/reference/providers/cncf/kubernetes/backcompat/pod_runtime_info_env.py b/reference/providers/cncf/kubernetes/backcompat/pod_runtime_info_env.py new file mode 100644 index 0000000..d62f15a --- /dev/null +++ b/reference/providers/cncf/kubernetes/backcompat/pod_runtime_info_env.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
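The backcompat Resources and Port classes above exist only to emit plain kubernetes client models, so old-style and new-style arguments end up identical. A short sketch of the mapping, assuming the upstream import path; values are illustrative:

from kubernetes.client import models as k8s
from airflow.providers.cncf.kubernetes.backcompat.pod import Port, Resources

old_style = Resources(request_cpu="100m", request_memory="256Mi", limit_memory="512Mi")
new_style = old_style.to_k8s_client_obj()

assert isinstance(new_style, k8s.V1ResourceRequirements)
assert new_style.requests == {"cpu": "100m", "memory": "256Mi"}  # unset keys are dropped
assert new_style.limits == {"memory": "512Mi"}

assert Port("http", 80).to_k8s_client_obj() == k8s.V1ContainerPort(name="http", container_port=80)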
+"""Classes for interacting with Kubernetes API""" + +import kubernetes.client.models as k8s + + +class PodRuntimeInfoEnv: + """Defines Pod runtime information as environment variable""" + + def __init__(self, name, field_path): + """ + Adds Kubernetes pod runtime information as environment variables such as namespace, pod IP, pod name. + Full list of options can be found in kubernetes documentation. + + :param name: the name of the environment variable + :type: name: str + :param field_path: path to pod runtime info. Ex: metadata.namespace | status.podIP + :type: field_path: str + """ + self.name = name + self.field_path = field_path + + def to_k8s_client_obj(self): + """Converts to k8s object. + + :return: kubernetes.client.models.V1EnvVar + """ + return k8s.V1EnvVar( + name=self.name, + value_from=k8s.V1EnvVarSource( + field_ref=k8s.V1ObjectFieldSelector(field_path=self.field_path) + ), + ) diff --git a/reference/providers/cncf/kubernetes/backcompat/volume.py b/reference/providers/cncf/kubernetes/backcompat/volume.py new file mode 100644 index 0000000..95cc00b --- /dev/null +++ b/reference/providers/cncf/kubernetes/backcompat/volume.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module is deprecated. Please use `kubernetes.client.models.V1Volume`.""" + +import warnings + +from kubernetes.client import models as k8s + +warnings.warn( + "This module is deprecated. Please use `kubernetes.client.models.V1Volume`.", + DeprecationWarning, + stacklevel=2, +) + + +class Volume: + """Backward compatible Volume""" + + def __init__(self, name, configs): + """Adds Kubernetes Volume to pod. allows pod to access features like ConfigMaps + and Persistent Volumes + + :param name: the name of the volume mount + :type name: str + :param configs: dictionary of any features needed for volume. We purposely keep this + vague since there are multiple volume types with changing configs. + :type configs: dict + """ + self.name = name + self.configs = configs + + def to_k8s_client_obj(self) -> k8s.V1Volume: + """ + Converts to k8s object. 
+ + :return: Volume Mount k8s object + """ + resp = k8s.V1Volume(name=self.name) + for k, v in self.configs.items(): + snake_key = Volume._convert_to_snake_case(k) + if hasattr(resp, snake_key): + setattr(resp, snake_key, v) + else: + raise AttributeError(f"V1Volume does not have attribute {k}") + return resp + + # # https://www.geeksforgeeks.org/python-program-to-convert-camel-case-string-to-snake-case/ + @staticmethod + def _convert_to_snake_case(input_string): + return "".join( + ["_" + i.lower() if i.isupper() else i for i in input_string] + ).lstrip("_") diff --git a/reference/providers/cncf/kubernetes/backcompat/volume_mount.py b/reference/providers/cncf/kubernetes/backcompat/volume_mount.py new file mode 100644 index 0000000..0327faa --- /dev/null +++ b/reference/providers/cncf/kubernetes/backcompat/volume_mount.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Classes for interacting with Kubernetes API""" + +import warnings + +from kubernetes.client import models as k8s + +warnings.warn( + "This module is deprecated. Please use `kubernetes.client.models.V1VolumeMount`.", + DeprecationWarning, + stacklevel=2, +) + + +class VolumeMount: + """Backward compatible VolumeMount""" + + __slots__ = ("name", "mount_path", "sub_path", "read_only") + + def __init__(self, name, mount_path, sub_path, read_only): + """ + Initialize a Kubernetes Volume Mount. Used to mount pod level volumes to + running container. + + :param name: the name of the volume mount + :type name: str + :param mount_path: + :type mount_path: str + :param sub_path: subpath within the volume mount + :type sub_path: Optional[str] + :param read_only: whether to access pod with read-only mode + :type read_only: bool + """ + self.name = name + self.mount_path = mount_path + self.sub_path = sub_path + self.read_only = read_only + + def to_k8s_client_obj(self) -> k8s.V1VolumeMount: + """ + Converts to k8s object. + + :return: Volume Mount k8s object + """ + return k8s.V1VolumeMount( + name=self.name, + mount_path=self.mount_path, + sub_path=self.sub_path, + read_only=self.read_only, + ) diff --git a/reference/providers/cncf/kubernetes/example_dags/__init__.py b/reference/providers/cncf/kubernetes/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/cncf/kubernetes/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/cncf/kubernetes/example_dags/example_kubernetes.py b/reference/providers/cncf/kubernetes/example_dags/example_kubernetes.py new file mode 100644 index 0000000..e409b65 --- /dev/null +++ b/reference/providers/cncf/kubernetes/example_dags/example_kubernetes.py @@ -0,0 +1,185 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example dag for using the KubernetesPodOperator. +""" + +from airflow import DAG +from airflow.kubernetes.secret import Secret +from airflow.operators.bash import BashOperator +from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import ( + KubernetesPodOperator, +) +from airflow.utils.dates import days_ago +from kubernetes.client import models as k8s + +# [START howto_operator_k8s_cluster_resources] +secret_file = Secret("volume", "/etc/sql_conn", "airflow-secrets", "sql_alchemy_conn") +secret_env = Secret("env", "SQL_CONN", "airflow-secrets", "sql_alchemy_conn") +secret_all_keys = Secret("env", None, "airflow-secrets-2") +volume_mount = k8s.V1VolumeMount( + name="test-volume", mount_path="/root/mount_file", sub_path=None, read_only=True +) + +configmaps = [ + k8s.V1EnvFromSource( + config_map_ref=k8s.V1ConfigMapEnvSource(name="test-configmap-1") + ), + k8s.V1EnvFromSource( + config_map_ref=k8s.V1ConfigMapEnvSource(name="test-configmap-2") + ), +] + +volume = k8s.V1Volume( + name="test-volume", + persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource( + claim_name="test-volume" + ), +) + +port = k8s.V1ContainerPort(name="http", container_port=80) + +init_container_volume_mounts = [ + k8s.V1VolumeMount( + mount_path="/etc/foo", name="test-volume", sub_path=None, read_only=True + ) +] + +init_environments = [ + k8s.V1EnvVar(name="key1", value="value1"), + k8s.V1EnvVar(name="key2", value="value2"), +] + +init_container = k8s.V1Container( + name="init-container", + image="ubuntu:16.04", + env=init_environments, + volume_mounts=init_container_volume_mounts, + command=["bash", "-cx"], + args=["echo 10"], +) + +affinity = k8s.V1Affinity( + node_affinity=k8s.V1NodeAffinity( + preferred_during_scheduling_ignored_during_execution=[ + k8s.V1PreferredSchedulingTerm( + weight=1, + preference=k8s.V1NodeSelectorTerm( + match_expressions=[ + k8s.V1NodeSelectorRequirement( + key="disktype", operator="in", values=["ssd"] + ) + ] + ), + ) + ] + ), + 
pod_affinity=k8s.V1PodAffinity( + required_during_scheduling_ignored_during_execution=[ + k8s.V1WeightedPodAffinityTerm( + weight=1, + pod_affinity_term=k8s.V1PodAffinityTerm( + label_selector=k8s.V1LabelSelector( + match_expressions=[ + k8s.V1LabelSelectorRequirement( + key="security", operator="In", values="S1" + ) + ] + ), + topology_key="failure-domain.beta.kubernetes.io/zone", + ), + ) + ] + ), +) + +tolerations = [k8s.V1Toleration(key="key", operator="Equal", value="value")] + +# [END howto_operator_k8s_cluster_resources] + + +default_args = { + "owner": "airflow", +} + +with DAG( + dag_id="example_kubernetes_operator", + default_args=default_args, + schedule_interval=None, + start_date=days_ago(2), + tags=["example"], +) as dag: + k = KubernetesPodOperator( + namespace="default", + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo", "10"], + labels={"foo": "bar"}, + secrets=[secret_file, secret_env, secret_all_keys], + ports=[port], + volumes=[volume], + volume_mounts=[volume_mount], + env_from=configmaps, + name="airflow-test-pod", + task_id="task", + affinity=affinity, + is_delete_operator_pod=True, + hostnetwork=False, + tolerations=tolerations, + init_containers=[init_container], + priority_class_name="medium", + ) + + # [START howto_operator_k8s_private_image] + quay_k8s = KubernetesPodOperator( + namespace="default", + image="quay.io/apache/bash", + image_pull_secrets=[k8s.V1LocalObjectReference("testquay")], + cmds=["bash", "-cx"], + arguments=["echo", "10", "echo pwd"], + labels={"foo": "bar"}, + name="airflow-private-image-pod", + is_delete_operator_pod=True, + in_cluster=True, + task_id="task-two", + get_logs=True, + ) + # [END howto_operator_k8s_private_image] + + # [START howto_operator_k8s_write_xcom] + write_xcom = KubernetesPodOperator( + namespace="default", + image="alpine", + cmds=[ + "sh", + "-c", + "mkdir -p /airflow/xcom/;echo '[1,2,3,4]' > /airflow/xcom/return.json", + ], + name="write-xcom", + do_xcom_push=True, + is_delete_operator_pod=True, + in_cluster=True, + task_id="write-xcom", + get_logs=True, + ) + + pod_task_xcom_result = BashOperator( + bash_command="echo \"{{ task_instance.xcom_pull('write-xcom')[0] }}\"", + task_id="pod_task_xcom_result", + ) + # [END howto_operator_k8s_write_xcom] diff --git a/reference/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py b/reference/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py new file mode 100644 index 0000000..1faeff1 --- /dev/null +++ b/reference/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example DAG which uses SparkKubernetesOperator and SparkKubernetesSensor. 
+In this example, we create two tasks which execute sequentially. +The first task is to submit sparkApplication on Kubernetes cluster(the example uses spark-pi application). +and the second task is to check the final state of the sparkApplication that submitted in the first state. + +Spark-on-k8s operator is required to be already installed on Kubernetes +https://github.com/GoogleCloudPlatform/spark-on-k8s-operator +""" + +from datetime import timedelta + +# [START import_module] +# The DAG object; we'll need this to instantiate a DAG +from airflow import DAG + +# Operators; we need this to operate! +from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import ( + SparkKubernetesOperator, +) +from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import ( + SparkKubernetesSensor, +) +from airflow.utils.dates import days_ago + +# [END import_module] + +# [START default_args] +# These args will get passed on to each operator +# You can override them on a per-task basis during operator initialization +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "max_active_runs": 1, +} +# [END default_args] + +# [START instantiate_dag] + +dag = DAG( + "spark_pi", + default_args=default_args, + description="submit spark-pi as sparkApplication on kubernetes", + schedule_interval=timedelta(days=1), + start_date=days_ago(1), +) + +t1 = SparkKubernetesOperator( + task_id="spark_pi_submit", + namespace="default", + application_file="example_spark_kubernetes_spark_pi.yaml", + kubernetes_conn_id="kubernetes_default", + do_xcom_push=True, + dag=dag, +) + +t2 = SparkKubernetesSensor( + task_id="spark_pi_monitor", + namespace="default", + application_name="{{ task_instance.xcom_pull(task_ids='spark_pi_submit')['metadata']['name'] }}", + kubernetes_conn_id="kubernetes_default", + dag=dag, +) +t1 >> t2 diff --git a/reference/providers/cncf/kubernetes/example_dags/example_spark_kubernetes_spark_pi.yaml b/reference/providers/cncf/kubernetes/example_dags/example_spark_kubernetes_spark_pi.yaml new file mode 100644 index 0000000..52521a9 --- /dev/null +++ b/reference/providers/cncf/kubernetes/example_dags/example_spark_kubernetes_spark_pi.yaml @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
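The hand-off between the two tasks above goes through XCom: because the submit task sets do_xcom_push=True, the API response returned by SparkKubernetesOperator.execute(), i.e. the created SparkApplication object, is stored as that task's XCom, and the sensor's templated application_name resolves to the generated name. A sketch with an illustrative, made-up response:

submit_xcom = {
    "apiVersion": "sparkoperator.k8s.io/v1beta2",
    "kind": "SparkApplication",
    "metadata": {"name": "spark-pi-2022-07-28-1", "namespace": "default"},
}
# what "{{ task_instance.xcom_pull(task_ids='spark_pi_submit')['metadata']['name'] }}" evaluates to
application_name = submit_xcom["metadata"]["name"]  # -> "spark-pi-2022-07-28-1"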
+--- +apiVersion: "sparkoperator.k8s.io/v1beta2" +kind: SparkApplication +metadata: + name: "spark-pi-{{ ds }}-{{ task_instance.try_number }}" + namespace: default +spec: + type: Scala + mode: cluster + image: "gcr.io/spark-operator/spark:v2.4.4" + imagePullPolicy: Always + mainClass: org.apache.spark.examples.SparkPi + mainApplicationFile: "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar" + sparkVersion: "2.4.4" + restartPolicy: + type: Never + volumes: + - name: "test-volume" + hostPath: + path: "/tmp" + type: Directory + driver: + cores: 1 + coreLimit: "1200m" + memory: "512m" + labels: + version: 2.4.4 + serviceAccount: spark + volumeMounts: + - name: "test-volume" + mountPath: "/tmp" + executor: + cores: 1 + instances: 1 + memory: "512m" + labels: + version: 2.4.4 + volumeMounts: + - name: "test-volume" + mountPath: "/tmp" diff --git a/reference/providers/cncf/kubernetes/hooks/__init__.py b/reference/providers/cncf/kubernetes/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/cncf/kubernetes/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/cncf/kubernetes/hooks/kubernetes.py b/reference/providers/cncf/kubernetes/hooks/kubernetes.py new file mode 100644 index 0000000..3522439 --- /dev/null +++ b/reference/providers/cncf/kubernetes/hooks/kubernetes.py @@ -0,0 +1,294 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
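The SparkApplication manifest above is plain YAML (the Jinja in metadata.name is rendered by Airflow before submission), so the dict that ultimately reaches the Kubernetes CustomObjectsApi can be previewed with yaml.safe_load, which is also what the hook below does for string bodies. A sketch, assuming the manifest has been saved locally under the file name shown:

import yaml

with open("example_spark_kubernetes_spark_pi.yaml") as f:
    spec = yaml.safe_load(f)

print(spec["kind"])                           # SparkApplication
print(spec["spec"]["driver"]["coreLimit"])    # 1200m
print(spec["spec"]["executor"]["instances"])  # 1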
+import tempfile +from typing import Any, Dict, Generator, Optional, Tuple, Union + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property +from kubernetes import client, config, watch + +try: + import airflow.utils.yaml as yaml +except ImportError: + import yaml + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +def _load_body_to_dict(body): + try: + body_dict = yaml.safe_load(body) + except yaml.YAMLError as e: + raise AirflowException(f"Exception when loading resource definition: {e}\n") + return body_dict + + +class KubernetesHook(BaseHook): + """ + Creates Kubernetes API connection. + + - use in cluster configuration by using ``extra__kubernetes__in_cluster`` in connection + - use custom config by providing path to the file using ``extra__kubernetes__kube_config_path`` + - use custom configuration by providing content of kubeconfig file via + ``extra__kubernetes__kube_config`` in connection + - use default config by providing no extras + + This hook check for configuration option in the above order. Once an option is present it will + use this configuration. + + .. seealso:: + For more information about Kubernetes connection: + :doc:`/connections/kubernetes` + + :param conn_id: the connection to Kubernetes cluster + :type conn_id: str + """ + + conn_name_attr = "kubernetes_conn_id" + default_conn_name = "kubernetes_default" + conn_type = "kubernetes" + hook_name = "Kubernetes Cluster Connection" + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import BooleanField, StringField + + return { + "extra__kubernetes__in_cluster": BooleanField( + lazy_gettext("In cluster configuration") + ), + "extra__kubernetes__kube_config_path": StringField( + lazy_gettext("Kube config path"), widget=BS3TextFieldWidget() + ), + "extra__kubernetes__kube_config": StringField( + lazy_gettext("Kube config (JSON format)"), widget=BS3TextFieldWidget() + ), + "extra__kubernetes__namespace": StringField( + lazy_gettext("Namespace"), widget=BS3TextFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], + "relabeling": {}, + } + + def __init__( + self, + conn_id: str = default_conn_name, + client_configuration: Optional[client.Configuration] = None, + ) -> None: + super().__init__() + self.conn_id = conn_id + self.client_configuration = client_configuration + + def get_conn(self) -> Any: + """Returns kubernetes api session for use with requests""" + connection = self.get_connection(self.conn_id) + extras = connection.extra_dejson + in_cluster = extras.get("extra__kubernetes__in_cluster") + kubeconfig_path = extras.get("extra__kubernetes__kube_config_path") + kubeconfig = extras.get("extra__kubernetes__kube_config") + num_selected_configuration = len( + [o for o in [in_cluster, kubeconfig, kubeconfig_path] if o] + ) + + if num_selected_configuration > 1: + raise AirflowException( + "Invalid connection configuration. Options extra__kubernetes__kube_config_path, " + "extra__kubernetes__kube_config, extra__kubernetes__in_cluster are mutually exclusive. " + "You can only use one option at a time." 
+ ) + if in_cluster: + self.log.debug("loading kube_config from: in_cluster configuration") + config.load_incluster_config() + return client.ApiClient() + + if kubeconfig_path is not None: + self.log.debug("loading kube_config from: %s", kubeconfig_path) + config.load_kube_config( + config_file=kubeconfig_path, + client_configuration=self.client_configuration, + ) + return client.ApiClient() + + if kubeconfig is not None: + with tempfile.NamedTemporaryFile() as temp_config: + self.log.debug("loading kube_config from: connection kube_config") + temp_config.write(kubeconfig.encode()) + temp_config.flush() + config.load_kube_config( + config_file=temp_config.name, + client_configuration=self.client_configuration, + ) + return client.ApiClient() + + self.log.debug("loading kube_config from: default file") + config.load_kube_config(client_configuration=self.client_configuration) + return client.ApiClient() + + @cached_property + def api_client(self) -> Any: + """Cached Kubernetes API client""" + return self.get_conn() + + def create_custom_object( + self, + group: str, + version: str, + plural: str, + body: Union[str, dict], + namespace: Optional[str] = None, + ): + """ + Creates custom resource definition object in Kubernetes + + :param group: api group + :type group: str + :param version: api version + :type version: str + :param plural: api plural + :type plural: str + :param body: crd object definition + :type body: Union[str, dict] + :param namespace: kubernetes namespace + :type namespace: str + """ + api = client.CustomObjectsApi(self.api_client) + if namespace is None: + namespace = self.get_namespace() + if isinstance(body, str): + body = _load_body_to_dict(body) + try: + response = api.create_namespaced_custom_object( + group=group, + version=version, + namespace=namespace, + plural=plural, + body=body, + ) + self.log.debug("Response: %s", response) + return response + except client.rest.ApiException as e: + raise AirflowException( + f"Exception when calling -> create_custom_object: {e}\n" + ) + + def get_custom_object( + self, + group: str, + version: str, + plural: str, + name: str, + namespace: Optional[str] = None, + ): + """ + Get custom resource definition object from Kubernetes + + :param group: api group + :type group: str + :param version: api version + :type version: str + :param plural: api plural + :type plural: str + :param name: crd object name + :type name: str + :param namespace: kubernetes namespace + :type namespace: str + """ + api = client.CustomObjectsApi(self.api_client) + if namespace is None: + namespace = self.get_namespace() + try: + response = api.get_namespaced_custom_object( + group=group, + version=version, + namespace=namespace, + plural=plural, + name=name, + ) + return response + except client.rest.ApiException as e: + raise AirflowException( + f"Exception when calling -> get_custom_object: {e}\n" + ) + + def get_namespace(self) -> str: + """Returns the namespace that defined in the connection""" + connection = self.get_connection(self.conn_id) + extras = connection.extra_dejson + namespace = extras.get("extra__kubernetes__namespace", "default") + return namespace + + def get_pod_log_stream( + self, + pod_name: str, + container: Optional[str] = "", + namespace: Optional[str] = None, + ) -> Tuple[watch.Watch, Generator[str, None, None]]: + """ + Retrieves a log stream for a container in a kubernetes pod. 
+ + :param pod_name: pod name + :type pod_name: str + :param container: container name + :param namespace: kubernetes namespace + :type namespace: str + """ + api = client.CoreV1Api(self.api_client) + watcher = watch.Watch() + return ( + watcher, + watcher.stream( + api.read_namespaced_pod_log, + name=pod_name, + container=container, + namespace=namespace if namespace else self.get_namespace(), + ), + ) + + def get_pod_logs( + self, + pod_name: str, + container: Optional[str] = "", + namespace: Optional[str] = None, + ): + """ + Retrieves a container's log from the specified pod. + + :param pod_name: pod name + :type pod_name: str + :param container: container name + :param namespace: kubernetes namespace + :type namespace: str + """ + api = client.CoreV1Api(self.api_client) + return api.read_namespaced_pod_log( + name=pod_name, + container=container, + _preload_content=False, + namespace=namespace if namespace else self.get_namespace(), + ) diff --git a/reference/providers/cncf/kubernetes/operators/__init__.py b/reference/providers/cncf/kubernetes/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/cncf/kubernetes/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/cncf/kubernetes/operators/kubernetes_pod.py b/reference/providers/cncf/kubernetes/operators/kubernetes_pod.py new file mode 100644 index 0000000..08afbb5 --- /dev/null +++ b/reference/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -0,0 +1,564 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
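A minimal sketch of driving the hook above directly, assuming an Airflow environment with a kubernetes_default connection whose extras select one of the three configuration sources described in the docstring; the pod name is hypothetical:

from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook

hook = KubernetesHook(conn_id="kubernetes_default")
api_client = hook.get_conn()   # resolves in_cluster / kube_config_path / kube_config extras in that order
print(hook.get_namespace())    # "default" unless extra__kubernetes__namespace is set

# One-shot log fetch; the hook passes _preload_content=False, so this is a raw
# response object rather than a decoded string.
resp = hook.get_pod_logs(pod_name="airflow-test-pod")

# Or follow the log as it is written; the loop runs until watcher.stop() is called.
watcher, stream = hook.get_pod_log_stream(pod_name="airflow-test-pod")
for line in stream:
    print(line)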
+"""Executes task in a Kubernetes POD""" +import re +import warnings +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple + +from kubernetes.client import CoreV1Api +from kubernetes.client import models as k8s + +try: + import airflow.utils.yaml as yaml +except ImportError: + import yaml + +from airflow.exceptions import AirflowException +from airflow.kubernetes import kube_client, pod_generator, pod_launcher +from airflow.kubernetes.pod_generator import PodGenerator +from airflow.kubernetes.secret import Secret +from airflow.models import BaseOperator +from airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters import ( + convert_affinity, + convert_configmap, + convert_env_vars, + convert_image_pull_secrets, + convert_pod_runtime_info_env, + convert_port, + convert_resources, + convert_toleration, + convert_volume, + convert_volume_mount, +) +from airflow.providers.cncf.kubernetes.backcompat.pod_runtime_info_env import ( + PodRuntimeInfoEnv, +) +from airflow.utils.decorators import apply_defaults +from airflow.utils.helpers import validate_key +from airflow.utils.state import State +from airflow.version import version as airflow_version + +if TYPE_CHECKING: + import jinja2 + + +class KubernetesPodOperator( + BaseOperator +): # pylint: disable=too-many-instance-attributes + """ + Execute a task in a Kubernetes Pod + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:KubernetesPodOperator` + + .. note:: + If you use `Google Kubernetes Engine `__ + and Airflow is not running in the same cluster, consider using + :class:`~airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator`, which + simplifies the authorization process. + + :param namespace: the namespace to run within kubernetes. + :type namespace: str + :param image: Docker image you wish to launch. Defaults to hub.docker.com, + but fully qualified URLS will point to custom repositories. (templated) + :type image: str + :param name: name of the pod in which the task will run, will be used (plus a random + suffix) to generate a pod id (DNS-1123 subdomain, containing only [a-z0-9.-]). + :type name: str + :param cmds: entrypoint of the container. (templated) + The docker images's entrypoint is used if this is not provided. + :type cmds: list[str] + :param arguments: arguments of the entrypoint. (templated) + The docker image's CMD is used if this is not provided. + :type arguments: list[str] + :param ports: ports for launched pod. + :type ports: list[k8s.V1ContainerPort] + :param volume_mounts: volumeMounts for launched pod. + :type volume_mounts: list[k8s.V1VolumeMount] + :param volumes: volumes for launched pod. Includes ConfigMaps and PersistentVolumes. + :type volumes: list[k8s.V1Volume] + :param env_vars: Environment variables initialized in the container. (templated) + :type env_vars: list[k8s.V1EnvVar] + :param secrets: Kubernetes secrets to inject in the container. + They can be exposed as environment vars or files in a volume. + :type secrets: list[airflow.kubernetes.secret.Secret] + :param in_cluster: run kubernetes client with in_cluster configuration. + :type in_cluster: bool + :param cluster_context: context that points to kubernetes cluster. + Ignored when in_cluster is True. If None, current-context is used. 
+ :type cluster_context: str + :param reattach_on_restart: if the scheduler dies while the pod is running, reattach and monitor + :type reattach_on_restart: bool + :param labels: labels to apply to the Pod. (templated) + :type labels: dict + :param startup_timeout_seconds: timeout in seconds to startup the pod. + :type startup_timeout_seconds: int + :param get_logs: get the stdout of the container as logs of the tasks. + :type get_logs: bool + :param image_pull_policy: Specify a policy to cache or always pull an image. + :type image_pull_policy: str + :param annotations: non-identifying metadata you can attach to the Pod. + Can be a large range of data, and can include characters + that are not permitted by labels. + :type annotations: dict + :param resources: A dict containing resources requests and limits. + Possible keys are request_memory, request_cpu, limit_memory, limit_cpu, + and limit_gpu, which will be used to generate airflow.kubernetes.pod.Resources. + See also kubernetes.io/docs/concepts/configuration/manage-compute-resources-container + :type resources: k8s.V1ResourceRequirements + :param affinity: A dict containing a group of affinity scheduling rules. + :type affinity: k8s.V1Affinity + :param config_file: The path to the Kubernetes config file. (templated) + If not specified, default value is ``~/.kube/config`` + :type config_file: str + :param node_selectors: A dict containing a group of scheduling rules. + :type node_selectors: dict + :param image_pull_secrets: Any image pull secrets to be given to the pod. + If more than one secret is required, provide a + comma separated list: secret_a,secret_b + :type image_pull_secrets: List[k8s.V1LocalObjectReference] + :param service_account_name: Name of the service account + :type service_account_name: str + :param is_delete_operator_pod: What to do when the pod reaches its final + state, or the execution is interrupted. + If False (default): do nothing, If True: delete the pod + :type is_delete_operator_pod: bool + :param hostnetwork: If True enable host networking on the pod. + :type hostnetwork: bool + :param tolerations: A list of kubernetes tolerations. + :type tolerations: List[k8s.V1Toleration] + :param security_context: security options the pod should run with (PodSecurityContext). + :type security_context: dict + :param dnspolicy: dnspolicy for the pod. + :type dnspolicy: str + :param schedulername: Specify a schedulername for the pod + :type schedulername: str + :param full_pod_spec: The complete podSpec + :type full_pod_spec: kubernetes.client.models.V1Pod + :param init_containers: init container for the launched Pod + :type init_containers: list[kubernetes.client.models.V1Container] + :param log_events_on_failure: Log the pod's events if a failure occurs + :type log_events_on_failure: bool + :param do_xcom_push: If True, the content of the file + /airflow/xcom/return.json in the container will also be pushed to an + XCom when the container completes. 
+ :type do_xcom_push: bool + :param pod_template_file: path to pod template file (templated) + :type pod_template_file: str + :param priority_class_name: priority class name for the launched Pod + :type priority_class_name: str + :param termination_grace_period: Termination grace period if task killed in UI, + defaults to kubernetes default + :type termination_grace_period: int + """ + + template_fields: Iterable[str] = ( + "image", + "cmds", + "arguments", + "env_vars", + "labels", + "config_file", + "pod_template_file", + ) + + # fmt: off + @apply_defaults + def __init__( # pylint: disable=too-many-arguments,too-many-locals + # fmt: on + self, + *, + namespace: Optional[str] = None, + image: Optional[str] = None, + name: Optional[str] = None, + cmds: Optional[List[str]] = None, + arguments: Optional[List[str]] = None, + ports: Optional[List[k8s.V1ContainerPort]] = None, + volume_mounts: Optional[List[k8s.V1VolumeMount]] = None, + volumes: Optional[List[k8s.V1Volume]] = None, + env_vars: Optional[List[k8s.V1EnvVar]] = None, + env_from: Optional[List[k8s.V1EnvFromSource]] = None, + secrets: Optional[List[Secret]] = None, + in_cluster: Optional[bool] = None, + cluster_context: Optional[str] = None, + labels: Optional[Dict] = None, + reattach_on_restart: bool = True, + startup_timeout_seconds: int = 120, + get_logs: bool = True, + image_pull_policy: str = 'IfNotPresent', + annotations: Optional[Dict] = None, + resources: Optional[k8s.V1ResourceRequirements] = None, + affinity: Optional[k8s.V1Affinity] = None, + config_file: Optional[str] = None, + node_selectors: Optional[dict] = None, + node_selector: Optional[dict] = None, + image_pull_secrets: Optional[List[k8s.V1LocalObjectReference]] = None, + service_account_name: str = 'default', + is_delete_operator_pod: bool = False, + hostnetwork: bool = False, + tolerations: Optional[List[k8s.V1Toleration]] = None, + security_context: Optional[Dict] = None, + dnspolicy: Optional[str] = None, + schedulername: Optional[str] = None, + full_pod_spec: Optional[k8s.V1Pod] = None, + init_containers: Optional[List[k8s.V1Container]] = None, + log_events_on_failure: bool = False, + do_xcom_push: bool = False, + pod_template_file: Optional[str] = None, + priority_class_name: Optional[str] = None, + pod_runtime_info_envs: List[PodRuntimeInfoEnv] = None, + termination_grace_period: Optional[int] = None, + configmaps: Optional[str] = None, + **kwargs, + ) -> None: + if kwargs.get('xcom_push') is not None: + raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") + super().__init__(resources=None, **kwargs) + + self.do_xcom_push = do_xcom_push + self.image = image + self.namespace = namespace + self.cmds = cmds or [] + self.arguments = arguments or [] + self.labels = labels or {} + self.startup_timeout_seconds = startup_timeout_seconds + self.env_vars = convert_env_vars(env_vars) if env_vars else [] + if pod_runtime_info_envs: + self.env_vars.extend([convert_pod_runtime_info_env(p) for p in pod_runtime_info_envs]) + self.env_from = env_from or [] + if configmaps: + self.env_from.extend([convert_configmap(c) for c in configmaps]) + self.ports = [convert_port(p) for p in ports] if ports else [] + self.volume_mounts = [convert_volume_mount(v) for v in volume_mounts] if volume_mounts else [] + self.volumes = [convert_volume(volume) for volume in volumes] if volumes else [] + self.secrets = secrets or [] + self.in_cluster = in_cluster + self.cluster_context = cluster_context + self.reattach_on_restart = reattach_on_restart + self.get_logs = 
get_logs + self.image_pull_policy = image_pull_policy + if node_selectors: + # Node selectors is incorrect based on k8s API + warnings.warn("node_selectors is deprecated. Please use node_selector instead.") + self.node_selector = node_selectors or {} + elif node_selector: + self.node_selector = node_selector or {} + else: + self.node_selector = None + self.annotations = annotations or {} + self.affinity = convert_affinity(affinity) if affinity else k8s.V1Affinity() + self.k8s_resources = convert_resources(resources) if resources else {} + self.config_file = config_file + self.image_pull_secrets = convert_image_pull_secrets(image_pull_secrets) if image_pull_secrets else [] + self.service_account_name = service_account_name + self.is_delete_operator_pod = is_delete_operator_pod + self.hostnetwork = hostnetwork + self.tolerations = [convert_toleration(toleration) for toleration in tolerations] \ + if tolerations else [] + self.security_context = security_context or {} + self.dnspolicy = dnspolicy + self.schedulername = schedulername + self.full_pod_spec = full_pod_spec + self.init_containers = init_containers or [] + self.log_events_on_failure = log_events_on_failure + self.priority_class_name = priority_class_name + self.pod_template_file = pod_template_file + self.name = self._set_name(name) + self.termination_grace_period = termination_grace_period + self.client: CoreV1Api = None + self.pod: k8s.V1Pod = None + + def _render_nested_template_fields( + self, + content: Any, + context: Dict, + jinja_env: "jinja2.Environment", + seen_oids: set, + ) -> None: + if id(content) not in seen_oids and isinstance(content, k8s.V1EnvVar): + seen_oids.add(id(content)) + self._do_render_template_fields(content, ('value', 'name'), context, jinja_env, seen_oids) + return + + super()._render_nested_template_fields( + content, + context, + jinja_env, + seen_oids + ) + + @staticmethod + def create_labels_for_pod(context) -> dict: + """ + Generate labels for the pod to track the pod in case of Operator crash + + :param context: task context provided by airflow DAG + :return: dict + """ + labels = { + 'dag_id': context['dag'].dag_id, + 'task_id': context['task'].task_id, + 'execution_date': context['ts'], + 'try_number': context['ti'].try_number, + } + # In the case of sub dags this is just useful + if context['dag'].is_subdag: + labels['parent_dag_id'] = context['dag'].parent_dag.dag_id + # Ensure that label is valid for Kube, + # and if not truncate/remove invalid chars and replace with short hash. 
+ for label_id, label in labels.items(): + safe_label = pod_generator.make_safe_label_value(str(label)) + labels[label_id] = safe_label + return labels + + def execute(self, context) -> Optional[str]: + try: + if self.in_cluster is not None: + client = kube_client.get_kube_client( + in_cluster=self.in_cluster, + cluster_context=self.cluster_context, + config_file=self.config_file, + ) + else: + client = kube_client.get_kube_client( + cluster_context=self.cluster_context, config_file=self.config_file + ) + + self.pod = self.create_pod_request_obj() + self.namespace = self.pod.metadata.namespace + + self.client = client + + # Add combination of labels to uniquely identify a running pod + labels = self.create_labels_for_pod(context) + + label_selector = self._get_pod_identifying_label_string(labels) + + self.namespace = self.pod.metadata.namespace + + pod_list = client.list_namespaced_pod(self.namespace, label_selector=label_selector) + + if len(pod_list.items) > 1 and self.reattach_on_restart: + raise AirflowException( + f'More than one pod running with labels: {label_selector}' + ) + + launcher = pod_launcher.PodLauncher(kube_client=client, extract_xcom=self.do_xcom_push) + + if len(pod_list.items) == 1: + try_numbers_match = self._try_numbers_match(context, pod_list.items[0]) + final_state, result = self.handle_pod_overlap( + labels, try_numbers_match, launcher, pod_list.items[0] + ) + else: + self.log.info("creating pod with labels %s and launcher %s", labels, launcher) + final_state, _, result = self.create_new_pod_for_operator(labels, launcher) + if final_state != State.SUCCESS: + status = self.client.read_namespaced_pod(self.pod.metadata.name, self.namespace) + raise AirflowException(f'Pod {self.pod.metadata.name} returned a failure: {status}') + return result + except AirflowException as ex: + raise AirflowException(f'Pod Launching failed: {ex}') + + def handle_pod_overlap( + self, labels: dict, try_numbers_match: bool, launcher: Any, pod: k8s.V1Pod + ) -> Tuple[State, Optional[str]]: + """ + + In cases where the Scheduler restarts while a KubernetesPodOperator task is running, + this function will either continue to monitor the existing pod or launch a new pod + based on the `reattach_on_restart` parameter. + + :param labels: labels used to determine if a pod is repeated + :type labels: dict + :param try_numbers_match: do the try numbers match? Only needed for logging purposes + :type try_numbers_match: bool + :param launcher: PodLauncher + :param pod_list: list of pods found + """ + if try_numbers_match: + log_line = f"found a running pod with labels {labels} and the same try_number." + else: + log_line = f"found a running pod with labels {labels} but a different try_number." + + # In case of failed pods, should reattach the first time, but only once + # as the task will have already failed. 
+ if self.reattach_on_restart and not pod.metadata.labels.get("already_checked"): + log_line += " Will attach to this pod and monitor instead of starting new one" + self.log.info(log_line) + self.pod = pod + final_state, result = self.monitor_launched_pod(launcher, pod) + else: + log_line += f"creating pod with labels {labels} and launcher {launcher}" + self.log.info(log_line) + final_state, _, result = self.create_new_pod_for_operator(labels, launcher) + return final_state, result + + @staticmethod + def _get_pod_identifying_label_string(labels) -> str: + filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != 'try_number'} + return ','.join([label_id + '=' + label for label_id, label in sorted(filtered_labels.items())]) + + @staticmethod + def _try_numbers_match(context, pod) -> bool: + return pod.metadata.labels['try_number'] == context['ti'].try_number + + def _set_name(self, name): + if name is None: + return None + validate_key(name, max_length=220) + return re.sub(r'[^a-z0-9.-]+', '-', name.lower()) + + def create_pod_request_obj(self) -> k8s.V1Pod: + """ + Creates a V1Pod based on user parameters. Note that a `pod` or `pod_template_file` + will supersede all other values. + + """ + self.log.debug("Creating pod for K8sPodOperator task %s", self.task_id) + if self.pod_template_file: + self.log.debug("Pod template file found, will parse for base pod") + pod_template = pod_generator.PodGenerator.deserialize_model_file(self.pod_template_file) + if self.full_pod_spec: + pod_template = PodGenerator.reconcile_pods(pod_template, self.full_pod_spec) + elif self.full_pod_spec: + pod_template = self.full_pod_spec + else: + pod_template = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="name")) + + pod = k8s.V1Pod( + api_version="v1", + kind="Pod", + metadata=k8s.V1ObjectMeta( + namespace=self.namespace, + labels=self.labels, + name=PodGenerator.make_unique_pod_id(self.name), + annotations=self.annotations, + ), + spec=k8s.V1PodSpec( + node_selector=self.node_selector, + affinity=self.affinity, + tolerations=self.tolerations, + init_containers=self.init_containers, + containers=[ + k8s.V1Container( + image=self.image, + name="base", + command=self.cmds, + ports=self.ports, + image_pull_policy=self.image_pull_policy, + resources=self.k8s_resources, + volume_mounts=self.volume_mounts, + args=self.arguments, + env=self.env_vars, + env_from=self.env_from, + ) + ], + image_pull_secrets=self.image_pull_secrets, + service_account_name=self.service_account_name, + host_network=self.hostnetwork, + security_context=self.security_context, + dns_policy=self.dnspolicy, + scheduler_name=self.schedulername, + restart_policy='Never', + priority_class_name=self.priority_class_name, + volumes=self.volumes, + ), + ) + + pod = PodGenerator.reconcile_pods(pod_template, pod) + + for secret in self.secrets: + self.log.debug("Adding secret to task %s", self.task_id) + pod = secret.attach_to_pod(pod) + if self.do_xcom_push: + self.log.debug("Adding xcom sidecar to task %s", self.task_id) + pod = PodGenerator.add_xcom_sidecar(pod) + return pod + + def create_new_pod_for_operator(self, labels, launcher) -> Tuple[State, k8s.V1Pod, Optional[str]]: + """ + Creates a new pod and monitors for duration of task + + :param labels: labels used to track pod + :param launcher: pod launcher that will manage launching and monitoring pods + :return: + """ + if not (self.full_pod_spec or self.pod_template_file): + # Add Airflow Version to the label + # And a label to identify that pod is launched by 
KubernetesPodOperator + self.log.debug("Adding k8spodoperator labels to pod before launch for task %s", self.task_id) + self.labels.update( + { + 'airflow_version': airflow_version.replace('+', '-'), + 'kubernetes_pod_operator': 'True', + } + ) + self.labels.update(labels) + self.pod.metadata.labels = self.labels + self.log.debug("Starting pod:\n%s", yaml.safe_dump(self.pod.to_dict())) + try: + launcher.start_pod(self.pod, startup_timeout=self.startup_timeout_seconds) + final_state, result = launcher.monitor_pod(pod=self.pod, get_logs=self.get_logs) + except AirflowException: + if self.log_events_on_failure: + for event in launcher.read_pod_events(self.pod).items: + self.log.error("Pod Event: %s - %s", event.reason, event.message) + raise + finally: + if self.is_delete_operator_pod: + self.log.debug("Deleting pod for task %s", self.task_id) + launcher.delete_pod(self.pod) + return final_state, self.pod, result + + def patch_already_checked(self, pod: k8s.V1Pod): + """Add an "already tried annotation to ensure we only retry once""" + pod.metadata.labels["already_checked"] = "True" + body = PodGenerator.serialize_pod(pod) + self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body) + + def monitor_launched_pod(self, launcher, pod) -> Tuple[State, Optional[str]]: + """ + Monitors a pod to completion that was created by a previous KubernetesPodOperator + + :param launcher: pod launcher that will manage launching and monitoring pods + :param pod: podspec used to find pod using k8s API + :return: + """ + try: + (final_state, result) = launcher.monitor_pod(pod, get_logs=self.get_logs) + finally: + if self.is_delete_operator_pod: + launcher.delete_pod(pod) + if final_state != State.SUCCESS: + if self.log_events_on_failure: + for event in launcher.read_pod_events(pod).items: + self.log.error("Pod Event: %s - %s", event.reason, event.message) + self.patch_already_checked(self.pod) + raise AirflowException(f'Pod returned a failure: {final_state}') + return final_state, result + + def on_kill(self) -> None: + if self.pod: + pod: k8s.V1Pod = self.pod + namespace = pod.metadata.namespace + name = pod.metadata.name + kwargs = {} + if self.termination_grace_period is not None: + kwargs = {"grace_period_seconds": self.termination_grace_period} + self.client.delete_namespaced_pod(name=name, namespace=namespace, **kwargs) diff --git a/reference/providers/cncf/kubernetes/operators/spark_kubernetes.py b/reference/providers/cncf/kubernetes/operators/spark_kubernetes.py new file mode 100644 index 0000000..202cae3 --- /dev/null +++ b/reference/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
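A small sketch of the reattach mechanism implemented above: on a scheduler restart the operator rebuilds the pod-identifying label selector (dropping try_number so a retry can still find the pod) and lists pods with it before deciding whether to attach or launch. The label values below are made up, and in real runs they are first passed through make_safe_label_value; _get_pod_identifying_label_string is an internal helper, called here only for illustration:

from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator

labels = {
    "dag_id": "example_kubernetes_operator",
    "task_id": "task",
    "execution_date": "2022-07-28T0000",
    "try_number": "1",
}
selector = KubernetesPodOperator._get_pod_identifying_label_string(labels)
# -> "dag_id=example_kubernetes_operator,execution_date=2022-07-28T0000,task_id=task"
# This is the label_selector handed to client.list_namespaced_pod() in execute().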
+from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook +from airflow.utils.decorators import apply_defaults + + +class SparkKubernetesOperator(BaseOperator): + """ + Creates sparkApplication object in kubernetes cluster: + + .. seealso:: + For more detail about Spark Application Object have a look at the reference: + https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.1.0-2.4.5/docs/api-docs.md#sparkapplication + + :param application_file: filepath to kubernetes custom_resource_definition of sparkApplication + :type application_file: str + :param namespace: kubernetes namespace to put sparkApplication + :type namespace: str + :param kubernetes_conn_id: the connection to Kubernetes cluster + :type kubernetes_conn_id: str + :param api_group: kubernetes api group of sparkApplication + :type api_group: str + :param api_version: kubernetes api version of sparkApplication + :type api_version: str + """ + + template_fields = ["application_file", "namespace"] + template_ext = ("yaml", "yml", "json") + ui_color = "#f4a460" + + @apply_defaults + def __init__( + self, + *, + application_file: str, + namespace: Optional[str] = None, + kubernetes_conn_id: str = "kubernetes_default", + api_group: str = "sparkoperator.k8s.io", + api_version: str = "v1beta2", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.application_file = application_file + self.namespace = namespace + self.kubernetes_conn_id = kubernetes_conn_id + self.api_group = api_group + self.api_version = api_version + + def execute(self, context): + self.log.info("Creating sparkApplication") + hook = KubernetesHook(conn_id=self.kubernetes_conn_id) + response = hook.create_custom_object( + group=self.api_group, + version=self.api_version, + plural="sparkapplications", + body=self.application_file, + namespace=self.namespace, + ) + return response diff --git a/reference/providers/cncf/kubernetes/provider.yaml b/reference/providers/cncf/kubernetes/provider.yaml new file mode 100644 index 0000000..70ff4cd --- /dev/null +++ b/reference/providers/cncf/kubernetes/provider.yaml @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
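A minimal sketch of submitting a SparkApplication with the SparkKubernetesOperator defined above; the manifest filename and namespace are illustrative assumptions:

from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator

# Hypothetical submit task: application_file is templated (see template_fields above) and is
# passed as the body of the "sparkapplications" custom object created in execute().
spark_pi_submit = SparkKubernetesOperator(
    task_id="spark_pi_submit",
    application_file="spark_pi.yaml",         # assumed SparkApplication manifest kept alongside the DAG
    namespace="default",                      # assumed namespace
    kubernetes_conn_id="kubernetes_default",
)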
+ +--- +package-name: apache-airflow-providers-cncf-kubernetes +name: Kubernetes +description: | + `Kubernetes `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Kubernetes + external-doc-url: https://kubernetes.io/ + how-to-guide: + - /docs/apache-airflow-providers-cncf-kubernetes/operators.rst + logo: /integration-logos/kubernetes/Kubernetes.png + tags: [software] + - integration-name: Spark on Kubernetes + external-doc-url: https://github.com/GoogleCloudPlatform/spark-on-k8s-operator + logo: /integration-logos/kubernetes/Spark-On-Kubernetes.png + tags: [software] + +operators: + - integration-name: Kubernetes + python-modules: + - airflow.providers.cncf.kubernetes.operators.kubernetes_pod + - airflow.providers.cncf.kubernetes.operators.spark_kubernetes + +sensors: + - integration-name: Kubernetes + python-modules: + - airflow.providers.cncf.kubernetes.sensors.spark_kubernetes + +hooks: + - integration-name: Kubernetes + python-modules: + - airflow.providers.cncf.kubernetes.hooks.kubernetes + +hook-class-names: + - airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook diff --git a/reference/providers/cncf/kubernetes/sensors/__init__.py b/reference/providers/cncf/kubernetes/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/cncf/kubernetes/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/cncf/kubernetes/sensors/spark_kubernetes.py b/reference/providers/cncf/kubernetes/sensors/spark_kubernetes.py new file mode 100644 index 0000000..e2470a3 --- /dev/null +++ b/reference/providers/cncf/kubernetes/sensors/spark_kubernetes.py @@ -0,0 +1,131 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from typing import Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from kubernetes import client + + +class SparkKubernetesSensor(BaseSensorOperator): + """ + Checks sparkApplication object in kubernetes cluster: + + .. seealso:: + For more detail about Spark Application Object have a look at the reference: + https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.1.0-2.4.5/docs/api-docs.md#sparkapplication + + :param application_name: spark Application resource name + :type application_name: str + :param namespace: the kubernetes namespace where the sparkApplication reside in + :type namespace: str + :param kubernetes_conn_id: the connection to Kubernetes cluster + :type kubernetes_conn_id: str + :param attach_log: determines whether logs for driver pod should be appended to the sensor log + :type attach_log: bool + :param api_group: kubernetes api group of sparkApplication + :type api_group: str + :param api_version: kubernetes api version of sparkApplication + :type api_version: str + """ + + template_fields = ("application_name", "namespace") + FAILURE_STATES = ("FAILED", "UNKNOWN") + SUCCESS_STATES = ("COMPLETED",) + + @apply_defaults + def __init__( + self, + *, + application_name: str, + attach_log: bool = False, + namespace: Optional[str] = None, + kubernetes_conn_id: str = "kubernetes_default", + api_group: str = "sparkoperator.k8s.io", + api_version: str = "v1beta2", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.application_name = application_name + self.attach_log = attach_log + self.namespace = namespace + self.kubernetes_conn_id = kubernetes_conn_id + self.hook = KubernetesHook(conn_id=self.kubernetes_conn_id) + self.api_group = api_group + self.api_version = api_version + + def _log_driver(self, application_state: str, response: dict) -> None: + if not self.attach_log: + return + status_info = response["status"] + if "driverInfo" not in status_info: + return + driver_info = status_info["driverInfo"] + if "podName" not in driver_info: + return + driver_pod_name = driver_info["podName"] + namespace = response["metadata"]["namespace"] + log_method = ( + self.log.error + if application_state in self.FAILURE_STATES + else self.log.info + ) + try: + log = "" + for line in self.hook.get_pod_logs(driver_pod_name, namespace=namespace): + log += line.decode() + log_method(log) + except client.rest.ApiException as e: + self.log.warning( + "Could not read logs for pod %s. 
It may have been disposed.\n" + "Make sure timeToLiveSeconds is set on your SparkApplication spec.\n" + "underlying exception: %s", + driver_pod_name, + e, + ) + + def poke(self, context: Dict) -> bool: + self.log.info("Poking: %s", self.application_name) + response = self.hook.get_custom_object( + group=self.api_group, + version=self.api_version, + plural="sparkapplications", + name=self.application_name, + namespace=self.namespace, + ) + try: + application_state = response["status"]["applicationState"]["state"] + except KeyError: + return False + if ( + self.attach_log + and application_state in self.FAILURE_STATES + self.SUCCESS_STATES + ): + self._log_driver(application_state, response) + if application_state in self.FAILURE_STATES: + raise AirflowException( + f"Spark application failed with state: {application_state}" + ) + elif application_state in self.SUCCESS_STATES: + self.log.info("Spark application ended successfully") + return True + else: + self.log.info("Spark application is still in state: %s", application_state) + return False diff --git a/reference/providers/databricks/CHANGELOG.rst b/reference/providers/databricks/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/databricks/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/databricks/__init__.py b/reference/providers/databricks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/databricks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
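The SparkKubernetesSensor above is commonly paired with a SparkKubernetesOperator submit task: the operator's execute() returns the created object, which Airflow pushes to XCom by default, so the sensor can resolve the generated application name. A sketch, assuming a submit task with task_id "spark_pi_submit" like the one outlined earlier:

from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import SparkKubernetesSensor

# Hypothetical monitor task: waits for the submitted SparkApplication to reach COMPLETED,
# raising AirflowException on FAILED/UNKNOWN and appending driver logs when attach_log=True.
spark_pi_monitor = SparkKubernetesSensor(
    task_id="spark_pi_monitor",
    namespace="default",
    application_name="{{ task_instance.xcom_pull(task_ids='spark_pi_submit')['metadata']['name'] }}",
    kubernetes_conn_id="kubernetes_default",
    attach_log=True,
)
# spark_pi_submit >> spark_pi_monitor   # set downstream of the submit task in the DAG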
diff --git a/reference/providers/databricks/example_dags/__init__.py b/reference/providers/databricks/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/databricks/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/databricks/example_dags/example_databricks.py b/reference/providers/databricks/example_dags/example_databricks.py new file mode 100644 index 0000000..6836d20 --- /dev/null +++ b/reference/providers/databricks/example_dags/example_databricks.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example DAG which uses the DatabricksSubmitRunOperator. +In this example, we create two tasks which execute sequentially. +The first task is to run a notebook at the workspace path "/test" +and the second task is to run a JAR uploaded to DBFS. Both, +tasks use new clusters. + +Because we have set a downstream dependency on the notebook task, +the spark jar task will NOT run until the notebook task completes +successfully. + +The definition of a successful run is if the run has a result_state of "SUCCESS". 
+For more information about the state of a run refer to +https://docs.databricks.com/api/latest/jobs.html#runstate +""" + +from airflow import DAG +from airflow.providers.databricks.operators.databricks import ( + DatabricksSubmitRunOperator, +) +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", + "email": ["airflow@example.com"], + "depends_on_past": False, +} + +with DAG( + dag_id="example_databricks_operator", + default_args=default_args, + schedule_interval="@daily", + start_date=days_ago(2), + tags=["example"], +) as dag: + new_cluster = { + "spark_version": "2.1.0-db3-scala2.11", + "node_type_id": "r3.xlarge", + "aws_attributes": {"availability": "ON_DEMAND"}, + "num_workers": 8, + } + + notebook_task_params = { + "new_cluster": new_cluster, + "notebook_task": { + "notebook_path": "/Users/airflow@example.com/PrepareData", + }, + } + # [START howto_operator_databricks_json] + # Example of using the JSON parameter to initialize the operator. + notebook_task = DatabricksSubmitRunOperator( + task_id="notebook_task", json=notebook_task_params + ) + # [END howto_operator_databricks_json] + + # [START howto_operator_databricks_named] + # Example of using the named parameters of DatabricksSubmitRunOperator + # to initialize the operator. + spark_jar_task = DatabricksSubmitRunOperator( + task_id="spark_jar_task", + new_cluster=new_cluster, + spark_jar_task={"main_class_name": "com.example.ProcessData"}, + libraries=[{"jar": "dbfs:/lib/etl-0.1.jar"}], + ) + # [END howto_operator_databricks_named] + notebook_task >> spark_jar_task diff --git a/reference/providers/databricks/hooks/__init__.py b/reference/providers/databricks/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/databricks/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/databricks/hooks/databricks.py b/reference/providers/databricks/hooks/databricks.py new file mode 100644 index 0000000..ac09efb --- /dev/null +++ b/reference/providers/databricks/hooks/databricks.py @@ -0,0 +1,385 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Databricks hook. + +This hook enable the submitting and running of jobs to the Databricks platform. Internally the +operators talk to the ``api/2.0/jobs/runs/submit`` +`endpoint `_. +""" +from time import sleep +from urllib.parse import urlparse + +import requests +from airflow import __version__ +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from requests import PreparedRequest +from requests import exceptions as requests_exceptions +from requests.auth import AuthBase + +RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart") +START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start") +TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete") + +RUN_NOW_ENDPOINT = ("POST", "api/2.0/jobs/run-now") +SUBMIT_RUN_ENDPOINT = ("POST", "api/2.0/jobs/runs/submit") +GET_RUN_ENDPOINT = ("GET", "api/2.0/jobs/runs/get") +CANCEL_RUN_ENDPOINT = ("POST", "api/2.0/jobs/runs/cancel") +USER_AGENT_HEADER = {"user-agent": f"airflow-{__version__}"} + +INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install") +UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall") + + +class RunState: + """Utility class for the run state concept of Databricks runs.""" + + def __init__( + self, life_cycle_state: str, result_state: str, state_message: str + ) -> None: + self.life_cycle_state = life_cycle_state + self.result_state = result_state + self.state_message = state_message + + @property + def is_terminal(self) -> bool: + """True if the current state is a terminal state.""" + if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES: + raise AirflowException( + ( + "Unexpected life cycle state: {}: If the state has " + "been introduced recently, please check the Databricks user " + "guide for troubleshooting information" + ).format(self.life_cycle_state) + ) + return self.life_cycle_state in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR") + + @property + def is_successful(self) -> bool: + """True if the result state is SUCCESS""" + return self.result_state == "SUCCESS" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RunState): + return NotImplemented + return ( + self.life_cycle_state == other.life_cycle_state + and self.result_state == other.result_state + and self.state_message == other.state_message + ) + + def __repr__(self) -> str: + return str(self.__dict__) + + +class DatabricksHook(BaseHook): # noqa + """ + Interact with Databricks. + + :param databricks_conn_id: The name of the databricks connection to use. + :type databricks_conn_id: str + :param timeout_seconds: The amount of time in seconds the requests library + will wait before timing-out. + :type timeout_seconds: int + :param retry_limit: The number of times to retry the connection in case of + service outages. + :type retry_limit: int + :param retry_delay: The number of seconds to wait between retries (it + might be a floating point number). 
+ :type retry_delay: float + """ + + conn_name_attr = "databricks_conn_id" + default_conn_name = "databricks_default" + conn_type = "databricks" + hook_name = "Databricks" + + def __init__( + self, + databricks_conn_id: str = default_conn_name, + timeout_seconds: int = 180, + retry_limit: int = 3, + retry_delay: float = 1.0, + ) -> None: + super().__init__() + self.databricks_conn_id = databricks_conn_id + self.databricks_conn = self.get_connection(databricks_conn_id) + self.timeout_seconds = timeout_seconds + if retry_limit < 1: + raise ValueError("Retry limit must be greater than equal to 1") + self.retry_limit = retry_limit + self.retry_delay = retry_delay + + @staticmethod + def _parse_host(host: str) -> str: + """ + The purpose of this function is to be robust to improper connections + settings provided by users, specifically in the host field. + + For example -- when users supply ``https://xx.cloud.databricks.com`` as the + host, we must strip out the protocol to get the host.:: + + h = DatabricksHook() + assert h._parse_host('https://xx.cloud.databricks.com') == \ + 'xx.cloud.databricks.com' + + In the case where users supply the correct ``xx.cloud.databricks.com`` as the + host, this function is a no-op.:: + + assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com' + + """ + urlparse_host = urlparse(host).hostname + if urlparse_host: + # In this case, host = https://xx.cloud.databricks.com + return urlparse_host + else: + # In this case, host = xx.cloud.databricks.com + return host + + def _do_api_call(self, endpoint_info, json): + """ + Utility function to perform an API call with retries + + :param endpoint_info: Tuple of method and endpoint + :type endpoint_info: tuple[string, string] + :param json: Parameters for this API call. + :type json: dict + :return: If the api call returns a OK status code, + this function returns the response in JSON. Otherwise, + we throw an AirflowException. + :rtype: dict + """ + method, endpoint = endpoint_info + + if "token" in self.databricks_conn.extra_dejson: + self.log.info("Using token auth. ") + auth = _TokenAuth(self.databricks_conn.extra_dejson["token"]) + if "host" in self.databricks_conn.extra_dejson: + host = self._parse_host(self.databricks_conn.extra_dejson["host"]) + else: + host = self.databricks_conn.host + else: + self.log.info("Using basic auth. ") + auth = (self.databricks_conn.login, self.databricks_conn.password) + host = self.databricks_conn.host + + url = f"https://{self._parse_host(host)}/{endpoint}" + + if method == "GET": + request_func = requests.get + elif method == "POST": + request_func = requests.post + elif method == "PATCH": + request_func = requests.patch + else: + raise AirflowException("Unexpected HTTP Method: " + method) + + attempt_num = 1 + while True: + try: + response = request_func( + url, + json=json if method in ("POST", "PATCH") else None, + params=json if method == "GET" else None, + auth=auth, + headers=USER_AGENT_HEADER, + timeout=self.timeout_seconds, + ) + response.raise_for_status() + return response.json() + except requests_exceptions.RequestException as e: + if not _retryable_error(e): + # In this case, the user probably made a mistake. + # Don't retry. + raise AirflowException( + f"Response: {e.response.content}, Status Code: {e.response.status_code}" + ) + + self._log_request_error(attempt_num, e) + + if attempt_num == self.retry_limit: + raise AirflowException( + ( + "API requests to Databricks failed {} times. " + "Giving up." 
+ ).format(self.retry_limit) + ) + + attempt_num += 1 + sleep(self.retry_delay) + + def _log_request_error(self, attempt_num: int, error: str) -> None: + self.log.error( + "Attempt %s API Request to Databricks failed with reason: %s", + attempt_num, + error, + ) + + def run_now(self, json: dict) -> str: + """ + Utility function to call the ``api/2.0/jobs/run-now`` endpoint. + + :param json: The data used in the body of the request to the ``run-now`` endpoint. + :type json: dict + :return: the run_id as a string + :rtype: str + """ + response = self._do_api_call(RUN_NOW_ENDPOINT, json) + return response["run_id"] + + def submit_run(self, json: dict) -> str: + """ + Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint. + + :param json: The data used in the body of the request to the ``submit`` endpoint. + :type json: dict + :return: the run_id as a string + :rtype: str + """ + response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json) + return response["run_id"] + + def get_run_page_url(self, run_id: str) -> str: + """ + Retrieves run_page_url. + + :param run_id: id of the run + :return: URL of the run page + """ + json = {"run_id": run_id} + response = self._do_api_call(GET_RUN_ENDPOINT, json) + return response["run_page_url"] + + def get_job_id(self, run_id: str) -> str: + """ + Retrieves job_id from run_id. + + :param run_id: id of the run + :type run_id: str + :return: Job id for given Databricks run + """ + json = {"run_id": run_id} + response = self._do_api_call(GET_RUN_ENDPOINT, json) + return response["job_id"] + + def get_run_state(self, run_id: str) -> RunState: + """ + Retrieves run state of the run. + + :param run_id: id of the run + :return: state of the run + """ + json = {"run_id": run_id} + response = self._do_api_call(GET_RUN_ENDPOINT, json) + state = response["state"] + life_cycle_state = state["life_cycle_state"] + # result_state may not be in the state if not terminal + result_state = state.get("result_state", None) + state_message = state["state_message"] + return RunState(life_cycle_state, result_state, state_message) + + def cancel_run(self, run_id: str) -> None: + """ + Cancels the run. + + :param run_id: id of the run + """ + json = {"run_id": run_id} + self._do_api_call(CANCEL_RUN_ENDPOINT, json) + + def restart_cluster(self, json: dict) -> None: + """ + Restarts the cluster. + + :param json: json dictionary containing cluster specification. + """ + self._do_api_call(RESTART_CLUSTER_ENDPOINT, json) + + def start_cluster(self, json: dict) -> None: + """ + Starts the cluster. + + :param json: json dictionary containing cluster specification. + """ + self._do_api_call(START_CLUSTER_ENDPOINT, json) + + def terminate_cluster(self, json: dict) -> None: + """ + Terminates the cluster. + + :param json: json dictionary containing cluster specification. + """ + self._do_api_call(TERMINATE_CLUSTER_ENDPOINT, json) + + def install(self, json: dict) -> None: + """ + Install libraries on the cluster. + + Utility function to call the ``2.0/libraries/install`` endpoint. + + :param json: json dictionary containing cluster_id and an array of library + :type json: dict + """ + self._do_api_call(INSTALL_LIBS_ENDPOINT, json) + + def uninstall(self, json: dict) -> None: + """ + Uninstall libraries on the cluster. + + Utility function to call the ``2.0/libraries/uninstall`` endpoint. 
+ + :param json: json dictionary containing cluster_id and an array of library + :type json: dict + """ + self._do_api_call(UNINSTALL_LIBS_ENDPOINT, json) + + +def _retryable_error(exception) -> bool: + return ( + isinstance( + exception, + (requests_exceptions.ConnectionError, requests_exceptions.Timeout), + ) + or exception.response is not None + and exception.response.status_code >= 500 + ) + + +RUN_LIFE_CYCLE_STATES = [ + "PENDING", + "RUNNING", + "TERMINATING", + "TERMINATED", + "SKIPPED", + "INTERNAL_ERROR", +] + + +class _TokenAuth(AuthBase): + """ + Helper class for requests Auth field. AuthBase requires you to implement the __call__ + magic function. + """ + + def __init__(self, token: str) -> None: + self.token = token + + def __call__(self, r: PreparedRequest) -> PreparedRequest: + r.headers["Authorization"] = "Bearer " + self.token + return r diff --git a/reference/providers/databricks/operators/__init__.py b/reference/providers/databricks/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/databricks/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/databricks/operators/databricks.py b/reference/providers/databricks/operators/databricks.py new file mode 100644 index 0000000..20344a3 --- /dev/null +++ b/reference/providers/databricks/operators/databricks.py @@ -0,0 +1,526 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
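Beyond the operators that wrap it, the DatabricksHook above can be used directly; a sketch assuming a configured "databricks_default" connection and illustrative cluster/notebook values:

from airflow.providers.databricks.hooks.databricks import DatabricksHook

hook = DatabricksHook(databricks_conn_id="databricks_default", retry_limit=3, retry_delay=2.0)

# Submit an ad-hoc run via api/2.0/jobs/runs/submit; the payload keys mirror the endpoint's schema.
run_id = hook.submit_run(
    {
        "run_name": "adhoc-example",               # assumed run name
        "new_cluster": {
            "spark_version": "7.3.x-scala2.12",    # assumed runtime version
            "node_type_id": "i3.xlarge",           # assumed node type
            "num_workers": 2,
        },
        "notebook_task": {"notebook_path": "/Users/someone@example.com/PrepareData"},  # assumed path
    }
)

# Inspect the run; RunState.is_terminal / is_successful are defined above.
state = hook.get_run_state(run_id)
print(state.life_cycle_state, state.result_state, hook.get_run_page_url(run_id))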
+# +"""This module contains Databricks operators.""" + +import time +from typing import Any, Dict, List, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.databricks.hooks.databricks import DatabricksHook +from airflow.utils.decorators import apply_defaults + +XCOM_RUN_ID_KEY = "run_id" +XCOM_RUN_PAGE_URL_KEY = "run_page_url" + + +def _deep_string_coerce(content, json_path: str = "json") -> Union[str, list, dict]: + """ + Coerces content or all values of content if it is a dict to a string. The + function will throw if content contains non-string or non-numeric types. + + The reason why we have this function is because the ``self.json`` field must be a + dict with only string values. This is because ``render_template`` will fail + for numerical values. + """ + coerce = _deep_string_coerce + if isinstance(content, str): + return content + elif isinstance( + content, + ( + int, + float, + ), + ): + # Databricks can tolerate either numeric or string types in the API backend. + return str(content) + elif isinstance(content, (list, tuple)): + return [coerce(e, f"{json_path}[{i}]") for i, e in enumerate(content)] + elif isinstance(content, dict): + return {k: coerce(v, f"{json_path}[{k}]") for k, v in list(content.items())} + else: + param_type = type(content) + msg = f"Type {param_type} used for parameter {json_path} is not a number or a string" + raise AirflowException(msg) + + +def _handle_databricks_operator_execution(operator, hook, log, context) -> None: + """ + Handles the Airflow + Databricks lifecycle logic for a Databricks operator + + :param operator: Databricks operator being handled + :param context: Airflow context + """ + if operator.do_xcom_push: + context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id) + log.info("Run submitted with run_id: %s", operator.run_id) + run_page_url = hook.get_run_page_url(operator.run_id) + if operator.do_xcom_push: + context["ti"].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url) + + log.info("View run status, Spark UI, and logs at %s", run_page_url) + while True: + run_state = hook.get_run_state(operator.run_id) + if run_state.is_terminal: + if run_state.is_successful: + log.info("%s completed successfully.", operator.task_id) + log.info("View run status, Spark UI, and logs at %s", run_page_url) + return + else: + error_message = ( + f"{operator.task_id} failed with terminal state: {run_state}" + ) + raise AirflowException(error_message) + else: + log.info("%s in run state: %s", operator.task_id, run_state) + log.info("View run status, Spark UI, and logs at %s", run_page_url) + log.info("Sleeping for %s seconds.", operator.polling_period_seconds) + time.sleep(operator.polling_period_seconds) + + +class DatabricksSubmitRunOperator(BaseOperator): + """ + Submits a Spark job run to Databricks using the + `api/2.0/jobs/runs/submit + `_ + API endpoint. + + There are two ways to instantiate this operator. + + In the first way, you can take the JSON payload that you typically use + to call the ``api/2.0/jobs/runs/submit`` endpoint and pass it directly + to our ``DatabricksSubmitRunOperator`` through the ``json`` parameter. 
+ For example :: + + json = { + 'new_cluster': { + 'spark_version': '2.1.0-db3-scala2.11', + 'num_workers': 2 + }, + 'notebook_task': { + 'notebook_path': '/Users/airflow@example.com/PrepareData', + }, + } + notebook_run = DatabricksSubmitRunOperator(task_id='notebook_run', json=json) + + Another way to accomplish the same thing is to use the named parameters + of the ``DatabricksSubmitRunOperator`` directly. Note that there is exactly + one named parameter for each top level parameter in the ``runs/submit`` + endpoint. In this method, your code would look like this: :: + + new_cluster = { + 'spark_version': '2.1.0-db3-scala2.11', + 'num_workers': 2 + } + notebook_task = { + 'notebook_path': '/Users/airflow@example.com/PrepareData', + } + notebook_run = DatabricksSubmitRunOperator( + task_id='notebook_run', + new_cluster=new_cluster, + notebook_task=notebook_task) + + In the case where both the json parameter **AND** the named parameters + are provided, they will be merged together. If there are conflicts during the merge, + the named parameters will take precedence and override the top level ``json`` keys. + + Currently the named parameters that ``DatabricksSubmitRunOperator`` supports are + - ``spark_jar_task`` + - ``notebook_task`` + - ``spark_python_task`` + - ``spark_submit_task`` + - ``new_cluster`` + - ``existing_cluster_id`` + - ``libraries`` + - ``run_name`` + - ``timeout_seconds`` + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DatabricksSubmitRunOperator` + + :param json: A JSON object containing API parameters which will be passed + directly to the ``api/2.0/jobs/runs/submit`` endpoint. The other named parameters + (i.e. ``spark_jar_task``, ``notebook_task``..) to this operator will + be merged with this json dictionary if they are provided. + If there are conflicts during the merge, the named parameters will + take precedence and override the top level json keys. (templated) + + .. seealso:: + For more information about templating see :ref:`jinja-templating`. + https://docs.databricks.com/api/latest/jobs.html#runs-submit + :type json: dict + :param spark_jar_task: The main class and parameters for the JAR task. Note that + the actual JAR is specified in the ``libraries``. + *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` + *OR* ``spark_submit_task`` should be specified. + This field will be templated. + + .. seealso:: + https://docs.databricks.com/api/latest/jobs.html#jobssparkjartask + :type spark_jar_task: dict + :param notebook_task: The notebook path and parameters for the notebook task. + *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` + *OR* ``spark_submit_task`` should be specified. + This field will be templated. + + .. seealso:: + https://docs.databricks.com/api/latest/jobs.html#jobsnotebooktask + :type notebook_task: dict + :param spark_python_task: The python file path and parameters to run the python file with. + *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` + *OR* ``spark_submit_task`` should be specified. + This field will be templated. + + .. seealso:: + https://docs.databricks.com/api/latest/jobs.html#jobssparkpythontask + :type spark_python_task: dict + :param spark_submit_task: Parameters needed to run a spark-submit command. + *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` + *OR* ``spark_submit_task`` should be specified. + This field will be templated. + + .. 
seealso:: + https://docs.databricks.com/api/latest/jobs.html#jobssparksubmittask + :type spark_submit_task: dict + :param new_cluster: Specs for a new cluster on which this task will be run. + *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified. + This field will be templated. + + .. seealso:: + https://docs.databricks.com/api/latest/jobs.html#jobsclusterspecnewcluster + :type new_cluster: dict + :param existing_cluster_id: ID for existing cluster on which to run this task. + *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified. + This field will be templated. + :type existing_cluster_id: str + :param libraries: Libraries which this run will use. + This field will be templated. + + .. seealso:: + https://docs.databricks.com/api/latest/libraries.html#managedlibrarieslibrary + :type libraries: list of dicts + :param run_name: The run name used for this task. + By default this will be set to the Airflow ``task_id``. This ``task_id`` is a + required parameter of the superclass ``BaseOperator``. + This field will be templated. + :type run_name: str + :param timeout_seconds: The timeout for this run. By default a value of 0 is used + which means to have no timeout. + This field will be templated. + :type timeout_seconds: int32 + :param databricks_conn_id: The name of the Airflow connection to use. + By default and in the common case this will be ``databricks_default``. To use + token based authentication, provide the key ``token`` in the extra field for the + connection and create the key ``host`` and leave the ``host`` field empty. + :type databricks_conn_id: str + :param polling_period_seconds: Controls the rate which we poll for the result of + this run. By default the operator will poll every 30 seconds. + :type polling_period_seconds: int + :param databricks_retry_limit: Amount of times retry if the Databricks backend is + unreachable. Its value must be greater than or equal to 1. + :type databricks_retry_limit: int + :param databricks_retry_delay: Number of seconds to wait between retries (it + might be a floating point number). + :type databricks_retry_delay: float + :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. 
+ :type do_xcom_push: bool + """ + + # Used in airflow.models.BaseOperator + template_fields = ("json",) + # Databricks brand color (blue) under white text + ui_color = "#1CB1C2" + ui_fgcolor = "#fff" + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + json: Optional[Any] = None, + spark_jar_task: Optional[Dict[str, str]] = None, + notebook_task: Optional[Dict[str, str]] = None, + spark_python_task: Optional[Dict[str, Union[str, List[str]]]] = None, + spark_submit_task: Optional[Dict[str, List[str]]] = None, + new_cluster: Optional[Dict[str, object]] = None, + existing_cluster_id: Optional[str] = None, + libraries: Optional[List[Dict[str, str]]] = None, + run_name: Optional[str] = None, + timeout_seconds: Optional[int] = None, + databricks_conn_id: str = "databricks_default", + polling_period_seconds: int = 30, + databricks_retry_limit: int = 3, + databricks_retry_delay: int = 1, + do_xcom_push: bool = False, + **kwargs, + ) -> None: + """Creates a new ``DatabricksSubmitRunOperator``.""" + super().__init__(**kwargs) + self.json = json or {} + self.databricks_conn_id = databricks_conn_id + self.polling_period_seconds = polling_period_seconds + self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay + if spark_jar_task is not None: + self.json["spark_jar_task"] = spark_jar_task + if notebook_task is not None: + self.json["notebook_task"] = notebook_task + if spark_python_task is not None: + self.json["spark_python_task"] = spark_python_task + if spark_submit_task is not None: + self.json["spark_submit_task"] = spark_submit_task + if new_cluster is not None: + self.json["new_cluster"] = new_cluster + if existing_cluster_id is not None: + self.json["existing_cluster_id"] = existing_cluster_id + if libraries is not None: + self.json["libraries"] = libraries + if run_name is not None: + self.json["run_name"] = run_name + if timeout_seconds is not None: + self.json["timeout_seconds"] = timeout_seconds + if "run_name" not in self.json: + self.json["run_name"] = run_name or kwargs["task_id"] + + self.json = _deep_string_coerce(self.json) + # This variable will be used in case our task gets killed. + self.run_id = None + self.do_xcom_push = do_xcom_push + + def _get_hook(self) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + ) + + def execute(self, context): + hook = self._get_hook() + self.run_id = hook.submit_run(self.json) + _handle_databricks_operator_execution(self, hook, self.log, context) + + def on_kill(self): + hook = self._get_hook() + hook.cancel_run(self.run_id) + self.log.info( + "Task: %s with run_id: %s was requested to be cancelled.", + self.task_id, + self.run_id, + ) + + +class DatabricksRunNowOperator(BaseOperator): + """ + Runs an existing Spark job run to Databricks using the + `api/2.0/jobs/run-now + `_ + API endpoint. + + There are two ways to instantiate this operator. + + In the first way, you can take the JSON payload that you typically use + to call the ``api/2.0/jobs/run-now`` endpoint and pass it directly + to our ``DatabricksRunNowOperator`` through the ``json`` parameter. 
+ For example :: + + json = { + "job_id": 42, + "notebook_params": { + "dry-run": "true", + "oldest-time-to-consider": "1457570074236" + } + } + + notebook_run = DatabricksRunNowOperator(task_id='notebook_run', json=json) + + Another way to accomplish the same thing is to use the named parameters + of the ``DatabricksRunNowOperator`` directly. Note that there is exactly + one named parameter for each top level parameter in the ``run-now`` + endpoint. In this method, your code would look like this: :: + + job_id=42 + + notebook_params = { + "dry-run": "true", + "oldest-time-to-consider": "1457570074236" + } + + python_params = ["douglas adams", "42"] + + spark_submit_params = ["--class", "org.apache.spark.examples.SparkPi"] + + notebook_run = DatabricksRunNowOperator( + job_id=job_id, + notebook_params=notebook_params, + python_params=python_params, + spark_submit_params=spark_submit_params + ) + + In the case where both the json parameter **AND** the named parameters + are provided, they will be merged together. If there are conflicts during the merge, + the named parameters will take precedence and override the top level ``json`` keys. + + Currently the named parameters that ``DatabricksRunNowOperator`` supports are + - ``job_id`` + - ``json`` + - ``notebook_params`` + - ``python_params`` + - ``spark_submit_params`` + + + :param job_id: the job_id of the existing Databricks job. + This field will be templated. + + .. seealso:: + https://docs.databricks.com/api/latest/jobs.html#run-now + :type job_id: str + :param json: A JSON object containing API parameters which will be passed + directly to the ``api/2.0/jobs/run-now`` endpoint. The other named parameters + (i.e. ``notebook_params``, ``spark_submit_params``..) to this operator will + be merged with this json dictionary if they are provided. + If there are conflicts during the merge, the named parameters will + take precedence and override the top level json keys. (templated) + + .. seealso:: + For more information about templating see :ref:`jinja-templating`. + https://docs.databricks.com/api/latest/jobs.html#run-now + :type json: dict + :param notebook_params: A dict from keys to values for jobs with notebook task, + e.g. "notebook_params": {"name": "john doe", "age": "35"}. + The map is passed to the notebook and will be accessible through the + dbutils.widgets.get function. See Widgets for more information. + If not specified upon run-now, the triggered run will use the + job’s base parameters. notebook_params cannot be + specified in conjunction with jar_params. The json representation + of this field (i.e. {"notebook_params":{"name":"john doe","age":"35"}}) + cannot exceed 10,000 bytes. + This field will be templated. + + .. seealso:: + https://docs.databricks.com/user-guide/notebooks/widgets.html + :type notebook_params: dict + :param python_params: A list of parameters for jobs with python tasks, + e.g. "python_params": ["john doe", "35"]. + The parameters will be passed to python file as command line parameters. + If specified upon run-now, it would overwrite the parameters specified in + job setting. + The json representation of this field (i.e. {"python_params":["john doe","35"]}) + cannot exceed 10,000 bytes. + This field will be templated. + + .. seealso:: + https://docs.databricks.com/api/latest/jobs.html#run-now + :type python_params: list[str] + :param spark_submit_params: A list of parameters for jobs with spark submit task, + e.g. "spark_submit_params": ["--class", "org.apache.spark.examples.SparkPi"]. 
+ The parameters will be passed to spark-submit script as command line parameters. + If specified upon run-now, it would overwrite the parameters specified + in job setting. + The json representation of this field cannot exceed 10,000 bytes. + This field will be templated. + + .. seealso:: + https://docs.databricks.com/api/latest/jobs.html#run-now + :type spark_submit_params: list[str] + :param timeout_seconds: The timeout for this run. By default a value of 0 is used + which means to have no timeout. + This field will be templated. + :type timeout_seconds: int32 + :param databricks_conn_id: The name of the Airflow connection to use. + By default and in the common case this will be ``databricks_default``. To use + token based authentication, provide the key ``token`` in the extra field for the + connection and create the key ``host`` and leave the ``host`` field empty. + :type databricks_conn_id: str + :param polling_period_seconds: Controls the rate which we poll for the result of + this run. By default the operator will poll every 30 seconds. + :type polling_period_seconds: int + :param databricks_retry_limit: Amount of times retry if the Databricks backend is + unreachable. Its value must be greater than or equal to 1. + :type databricks_retry_limit: int + :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. + :type do_xcom_push: bool + """ + + # Used in airflow.models.BaseOperator + template_fields = ("json",) + # Databricks brand color (blue) under white text + ui_color = "#1CB1C2" + ui_fgcolor = "#fff" + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + job_id: Optional[str] = None, + json: Optional[Any] = None, + notebook_params: Optional[Dict[str, str]] = None, + python_params: Optional[List[str]] = None, + spark_submit_params: Optional[List[str]] = None, + databricks_conn_id: str = "databricks_default", + polling_period_seconds: int = 30, + databricks_retry_limit: int = 3, + databricks_retry_delay: int = 1, + do_xcom_push: bool = False, + **kwargs, + ) -> None: + """Creates a new ``DatabricksRunNowOperator``.""" + super().__init__(**kwargs) + self.json = json or {} + self.databricks_conn_id = databricks_conn_id + self.polling_period_seconds = polling_period_seconds + self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay + + if job_id is not None: + self.json["job_id"] = job_id + if notebook_params is not None: + self.json["notebook_params"] = notebook_params + if python_params is not None: + self.json["python_params"] = python_params + if spark_submit_params is not None: + self.json["spark_submit_params"] = spark_submit_params + + self.json = _deep_string_coerce(self.json) + # This variable will be used in case our task gets killed. 
+ self.run_id = None + self.do_xcom_push = do_xcom_push + + def _get_hook(self) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + ) + + def execute(self, context): + hook = self._get_hook() + self.run_id = hook.run_now(self.json) + _handle_databricks_operator_execution(self, hook, self.log, context) + + def on_kill(self): + hook = self._get_hook() + hook.cancel_run(self.run_id) + self.log.info( + "Task: %s with run_id: %s was requested to be cancelled.", + self.task_id, + self.run_id, + ) diff --git a/reference/providers/databricks/provider.yaml b/reference/providers/databricks/provider.yaml new file mode 100644 index 0000000..d787b5c --- /dev/null +++ b/reference/providers/databricks/provider.yaml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-databricks +name: Databricks +description: | + `Databricks `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Databricks + external-doc-url: https://databricks.com/ + how-to-guide: + - /docs/apache-airflow-providers-databricks/operators.rst + logo: /integration-logos/databricks/Databricks.png + tags: [service] +operators: + - integration-name: Databricks + python-modules: + - airflow.providers.databricks.operators.databricks + +hooks: + - integration-name: Databricks + python-modules: + - airflow.providers.databricks.hooks.databricks + +hook-class-names: + - airflow.providers.databricks.hooks.databricks.DatabricksHook diff --git a/reference/providers/datadog/CHANGELOG.rst b/reference/providers/datadog/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/datadog/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. 
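The Databricks operators above authenticate through an Airflow connection; as their docstrings describe, token auth expects a "token" (and optionally "host") key in the connection extras. A sketch of such a connection with placeholder values, persisted for example with "airflow connections add" or by adding it to the metadata database:

import json

from airflow.models import Connection

# Placeholder credentials for illustration only.
databricks_conn = Connection(
    conn_id="databricks_default",
    conn_type="databricks",
    extra=json.dumps(
        {
            "host": "xx.cloud.databricks.com",    # workspace host, parsed by DatabricksHook._parse_host
            "token": "<personal-access-token>",   # triggers the _TokenAuth bearer authentication above
        }
    ),
)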
diff --git a/reference/providers/datadog/__init__.py b/reference/providers/datadog/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/datadog/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/datadog/hooks/__init__.py b/reference/providers/datadog/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/datadog/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/datadog/hooks/datadog.py b/reference/providers/datadog/hooks/datadog.py new file mode 100644 index 0000000..b49ef48 --- /dev/null +++ b/reference/providers/datadog/hooks/datadog.py @@ -0,0 +1,181 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import time +from typing import Any, Dict, List, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.utils.log.logging_mixin import LoggingMixin +from datadog import api, initialize + + +class DatadogHook(BaseHook, LoggingMixin): + """ + Uses datadog API to send metrics of practically anything measurable, + so it's possible to track # of db records inserted/deleted, records read + from file and many other useful metrics. + + Depends on the datadog API, which has to be deployed on the same server where + Airflow runs. + + :param datadog_conn_id: The connection to datadog, containing metadata for api keys. + :param datadog_conn_id: str + """ + + def __init__(self, datadog_conn_id: str = "datadog_default") -> None: + super().__init__() + conn = self.get_connection(datadog_conn_id) + self.api_key = conn.extra_dejson.get("api_key", None) + self.app_key = conn.extra_dejson.get("app_key", None) + self.source_type_name = conn.extra_dejson.get("source_type_name", None) + + # If the host is populated, it will use that hostname instead. + # for all metric submissions. + self.host = conn.host + + if self.api_key is None: + raise AirflowException( + "api_key must be specified in the Datadog connection details" + ) + + self.log.info("Setting up api keys for Datadog") + initialize(api_key=self.api_key, app_key=self.app_key) + + def validate_response(self, response: Dict[str, Any]) -> None: + """Validate Datadog response""" + if response["status"] != "ok": + self.log.error("Datadog returned: %s", response) + raise AirflowException("Error status received from Datadog") + + def send_metric( + self, + metric_name: str, + datapoint: Union[float, int], + tags: Optional[List[str]] = None, + type_: Optional[str] = None, + interval: Optional[int] = None, + ) -> Dict[str, Any]: + """ + Sends a single datapoint metric to DataDog + + :param metric_name: The name of the metric + :type metric_name: str + :param datapoint: A single integer or float related to the metric + :type datapoint: int or float + :param tags: A list of tags associated with the metric + :type tags: list + :param type_: Type of your metric: gauge, rate, or count + :type type_: str + :param interval: If the type of the metric is rate or count, define the corresponding interval + :type interval: int + """ + response = api.Metric.send( + metric=metric_name, + points=datapoint, + host=self.host, + tags=tags, + type=type_, + interval=interval, + ) + + self.validate_response(response) + return response + + def query_metric( + self, query: str, from_seconds_ago: int, to_seconds_ago: int + ) -> Dict[str, Any]: + """ + Queries datadog for a specific metric, potentially with some + function applied to it and returns the results. + + :param query: The datadog query to execute (see datadog docs) + :type query: str + :param from_seconds_ago: How many seconds ago to start querying for. + :type from_seconds_ago: int + :param to_seconds_ago: Up to how many seconds ago to query for. 
+ :type to_seconds_ago: int + """ + now = int(time.time()) + + response = api.Metric.query( + start=now - from_seconds_ago, end=now - to_seconds_ago, query=query + ) + + self.validate_response(response) + return response + + # pylint: disable=too-many-arguments + def post_event( + self, + title: str, + text: str, + aggregation_key: Optional[str] = None, + alert_type: Optional[str] = None, + date_happened: Optional[int] = None, + handle: Optional[str] = None, + priority: Optional[str] = None, + related_event_id: Optional[int] = None, + tags: Optional[List[str]] = None, + device_name: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """ + Posts an event to datadog (processing finished, potentially alerts, other issues) + Think about this as a means to maintain persistence of alerts, rather than + alerting itself. + + :param title: The title of the event + :type title: str + :param text: The body of the event (more information) + :type text: str + :param aggregation_key: Key that can be used to aggregate this event in a stream + :type aggregation_key: str + :param alert_type: The alert type for the event, one of + ["error", "warning", "info", "success"] + :type alert_type: str + :param date_happened: POSIX timestamp of the event; defaults to now + :type date_happened: int + :handle: User to post the event as; defaults to owner of the application key used + to submit. + :param handle: str + :param priority: Priority to post the event as. ("normal" or "low", defaults to "normal") + :type priority: str + :param related_event_id: Post event as a child of the given event + :type related_event_id: id + :param tags: List of tags to apply to the event + :type tags: list[str] + :param device_name: device_name to post the event with + :type device_name: list + """ + response = api.Event.create( + title=title, + text=text, + aggregation_key=aggregation_key, + alert_type=alert_type, + date_happened=date_happened, + handle=handle, + priority=priority, + related_event_id=related_event_id, + tags=tags, + host=self.host, + device_name=device_name, + source_type_name=self.source_type_name, + ) + + self.validate_response(response) + return response diff --git a/reference/providers/datadog/provider.yaml b/reference/providers/datadog/provider.yaml new file mode 100644 index 0000000..be1d434 --- /dev/null +++ b/reference/providers/datadog/provider.yaml @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
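With the hook above complete, a brief usage sketch before the provider metadata below. The metric, query, and event values are illustrative, and the import assumes the installed apache-airflow-providers-datadog package rather than this reference copy.

from airflow.providers.datadog.hooks.datadog import DatadogHook

hook = DatadogHook(datadog_conn_id="datadog_default")

# Gauge-style datapoint, tagged for later filtering.
hook.send_metric(
    metric_name="etl.rows_loaded",
    datapoint=1024,
    tags=["env:dev"],
    type_="gauge",
)

# Average of the same metric over the last ten minutes.
series = hook.query_metric(
    query="avg:etl.rows_loaded{env:dev}",
    from_seconds_ago=600,
    to_seconds_ago=0,
)

# Event for the Datadog event stream (the stream DatadogSensor later polls).
hook.post_event(
    title="ETL finished",
    text="Nightly load completed",
    alert_type="success",
    tags=["env:dev"],
)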
+ +--- +package-name: apache-airflow-providers-datadog +name: Datadog +description: | + `Datadog `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Datadog + external-doc-url: https://www.datadoghq.com/ + logo: /integration-logos/datadog/datadog.png + tags: [service] + +sensors: + - integration-name: Datadog + python-modules: + - airflow.providers.datadog.sensors.datadog + +hooks: + - integration-name: Datadog + python-modules: + - airflow.providers.datadog.hooks.datadog diff --git a/reference/providers/datadog/sensors/__init__.py b/reference/providers/datadog/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/datadog/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/datadog/sensors/datadog.py b/reference/providers/datadog/sensors/datadog.py new file mode 100644 index 0000000..c894c80 --- /dev/null +++ b/reference/providers/datadog/sensors/datadog.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Callable, Dict, List, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.datadog.hooks.datadog import DatadogHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from datadog import api + + +class DatadogSensor(BaseSensorOperator): + """ + A sensor to listen, with a filter, to datadog event streams and determine + if some event was emitted. + + Depends on the datadog API, which has to be deployed on the same server where + Airflow runs. + + :param datadog_conn_id: The connection to datadog, containing metadata for api keys. 
+ :param datadog_conn_id: str + """ + + ui_color = "#66c3dd" + + @apply_defaults + def __init__( + self, + *, + datadog_conn_id: str = "datadog_default", + from_seconds_ago: int = 3600, + up_to_seconds_from_now: int = 0, + priority: Optional[str] = None, + sources: Optional[str] = None, + tags: Optional[List[str]] = None, + response_check: Optional[Callable[[Dict[str, Any]], bool]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.datadog_conn_id = datadog_conn_id + self.from_seconds_ago = from_seconds_ago + self.up_to_seconds_from_now = up_to_seconds_from_now + self.priority = priority + self.sources = sources + self.tags = tags + self.response_check = response_check + + def poke(self, context: Dict[str, Any]) -> bool: + # This instantiates the hook, but doesn't need it further, + # because the API authenticates globally (unfortunately), + # but for airflow this shouldn't matter too much, because each + # task instance runs in its own process anyway. + DatadogHook(datadog_conn_id=self.datadog_conn_id) + + response = api.Event.query( + start=self.from_seconds_ago, + end=self.up_to_seconds_from_now, + priority=self.priority, + sources=self.sources, + tags=self.tags, + ) + + if isinstance(response, dict) and response.get("status", "ok") != "ok": + self.log.error("Unexpected Datadog result: %s", response) + raise AirflowException("Datadog returned unexpected result") + + if self.response_check: + # run content check on response + return self.response_check(response) + + # If no check was inserted, assume any event that matched yields true. + return len(response) > 0 diff --git a/reference/providers/dependencies.json b/reference/providers/dependencies.json new file mode 100644 index 0000000..e4f100d --- /dev/null +++ b/reference/providers/dependencies.json @@ -0,0 +1,88 @@ +{ + "airbyte": [ + "http" + ], + "amazon": [ + "apache.hive", + "exasol", + "ftp", + "google", + "imap", + "mongo", + "mysql", + "postgres", + "ssh" + ], + "apache.beam": [ + "google" + ], + "apache.druid": [ + "apache.hive" + ], + "apache.hive": [ + "amazon", + "microsoft.mssql", + "mysql", + "presto", + "samba", + "vertica" + ], + "apache.livy": [ + "http" + ], + "dingding": [ + "http" + ], + "discord": [ + "http" + ], + "google": [ + "amazon", + "apache.beam", + "apache.cassandra", + "cncf.kubernetes", + "facebook", + "microsoft.azure", + "microsoft.mssql", + "mysql", + "oracle", + "postgres", + "presto", + "salesforce", + "sftp", + "ssh" + ], + "hashicorp": [ + "google" + ], + "microsoft.azure": [ + "google", + "oracle" + ], + "microsoft.mssql": [ + "odbc" + ], + "mysql": [ + "amazon", + "presto", + "vertica" + ], + "opsgenie": [ + "http" + ], + "postgres": [ + "amazon" + ], + "salesforce": [ + "tableau" + ], + "sftp": [ + "ssh" + ], + "slack": [ + "http" + ], + "snowflake": [ + "slack" + ] +} diff --git a/reference/providers/dingding/CHANGELOG.rst b/reference/providers/dingding/CHANGELOG.rst new file mode 100644 index 0000000..27ce592 --- /dev/null +++ b/reference/providers/dingding/CHANGELOG.rst @@ -0,0 +1,38 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Replace deprecated doc links to the correct one (#14429)`` + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/dingding/__init__.py b/reference/providers/dingding/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/dingding/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/dingding/example_dags/__init__.py b/reference/providers/dingding/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/dingding/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/dingding/example_dags/example_dingding.py b/reference/providers/dingding/example_dags/example_dingding.py new file mode 100644 index 0000000..ad4faf0 --- /dev/null +++ b/reference/providers/dingding/example_dags/example_dingding.py @@ -0,0 +1,224 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example dag for using the DingdingOperator. +""" +from datetime import timedelta + +from airflow import DAG +from airflow.providers.dingding.operators.dingding import DingdingOperator +from airflow.utils.dates import days_ago + +args = { + "owner": "airflow", + "retries": 3, +} + + +# [START howto_operator_dingding_failure_callback] +def failure_callback(context): + """ + The function that will be executed on failure. + + :param context: The context of the executed task. + :type context: dict + """ + message = ( + "AIRFLOW TASK FAILURE TIPS:\n" + "DAG: {}\n" + "TASKS: {}\n" + "Reason: {}\n".format( + context["task_instance"].dag_id, + context["task_instance"].task_id, + context["exception"], + ) + ) + return DingdingOperator( + task_id="dingding_success_callback", + dingding_conn_id="dingding_default", + message_type="text", + message=message, + at_all=True, + ).execute(context) + + +args["on_failure_callback"] = failure_callback +# [END howto_operator_dingding_failure_callback] + +with DAG( + dag_id="example_dingding_operator", + default_args=args, + schedule_interval="@once", + dagrun_timeout=timedelta(minutes=60), + start_date=days_ago(2), + tags=["example"], +) as dag: + + # [START howto_operator_dingding] + text_msg_remind_none = DingdingOperator( + task_id="text_msg_remind_none", + dingding_conn_id="dingding_default", + message_type="text", + message="Airflow dingding text message remind none", + at_mobiles=None, + at_all=False, + ) + # [END howto_operator_dingding] + + text_msg_remind_specific = DingdingOperator( + task_id="text_msg_remind_specific", + dingding_conn_id="dingding_default", + message_type="text", + message="Airflow dingding text message remind specific users", + at_mobiles=["156XXXXXXXX", "130XXXXXXXX"], + at_all=False, + ) + + text_msg_remind_include_invalid = DingdingOperator( + task_id="text_msg_remind_include_invalid", + dingding_conn_id="dingding_default", + message_type="text", + message="Airflow dingding text message remind users including invalid", + # 123 is invalid user or user not in the group + at_mobiles=["156XXXXXXXX", "123"], + at_all=False, + ) + + # [START howto_operator_dingding_remind_users] + text_msg_remind_all = DingdingOperator( + task_id="text_msg_remind_all", + dingding_conn_id="dingding_default", + message_type="text", + message="Airflow dingding text message remind all users in group", + # list of user phone/email here in the group + # when at_all is specific will cover at_mobiles + at_mobiles=["156XXXXXXXX", "130XXXXXXXX"], + at_all=True, + ) + # [END howto_operator_dingding_remind_users] + + link_msg = DingdingOperator( + task_id="link_msg", + dingding_conn_id="dingding_default", + message_type="link", + message={ + "title": "Airflow dingding link message", + "text": "Airflow official documentation link", + "messageUrl": "https://airflow.apache.org", + "picURL": "https://airflow.apache.org/_images/pin_large.png", + }, + ) + + # [START howto_operator_dingding_rich_text] + markdown_msg = DingdingOperator( + task_id="markdown_msg", + dingding_conn_id="dingding_default", + message_type="markdown", + message={ + 
"title": "Airflow dingding markdown message", + "text": "# Markdown message title\n" + "content content .. \n" + "### sub-title\n" + "![logo](https://airflow.apache.org/_images/pin_large.png)", + }, + at_mobiles=["156XXXXXXXX"], + at_all=False, + ) + # [END howto_operator_dingding_rich_text] + + single_action_card_msg = DingdingOperator( + task_id="single_action_card_msg", + dingding_conn_id="dingding_default", + message_type="actionCard", + message={ + "title": "Airflow dingding single actionCard message", + "text": "Airflow dingding single actionCard message\n" + "![logo](https://airflow.apache.org/_images/pin_large.png)\n" + "This is a official logo in Airflow website.", + "hideAvatar": "0", + "btnOrientation": "0", + "singleTitle": "read more", + "singleURL": "https://airflow.apache.org", + }, + ) + + multi_action_card_msg = DingdingOperator( + task_id="multi_action_card_msg", + dingding_conn_id="dingding_default", + message_type="actionCard", + message={ + "title": "Airflow dingding multi actionCard message", + "text": "Airflow dingding multi actionCard message\n" + "![logo](https://airflow.apache.org/_images/pin_large.png)\n" + "Airflow documentation and GitHub", + "hideAvatar": "0", + "btnOrientation": "0", + "btns": [ + { + "title": "Airflow Documentation", + "actionURL": "https://airflow.apache.org", + }, + { + "title": "Airflow GitHub", + "actionURL": "https://github.com/apache/airflow", + }, + ], + }, + ) + + feed_card_msg = DingdingOperator( + task_id="feed_card_msg", + dingding_conn_id="dingding_default", + message_type="feedCard", + message={ + "links": [ + { + "title": "Airflow DAG feed card", + "messageURL": "https://airflow.apache.org/docs/apache-airflow/stable/ui.html", + "picURL": "https://airflow.apache.org/_images/dags.png", + }, + { + "title": "Airflow tree feed card", + "messageURL": "https://airflow.apache.org/docs/apache-airflow/stable/ui.html", + "picURL": "https://airflow.apache.org/_images/tree.png", + }, + { + "title": "Airflow graph feed card", + "messageURL": "https://airflow.apache.org/docs/apache-airflow/stable/ui.html", + "picURL": "https://airflow.apache.org/_images/graph.png", + }, + ] + }, + ) + + msg_failure_callback = DingdingOperator( + task_id="msg_failure_callback", + dingding_conn_id="dingding_default", + message_type="not_support_msg_type", + message="", + ) + + [ + text_msg_remind_none, + text_msg_remind_specific, + text_msg_remind_include_invalid, + text_msg_remind_all, + ] >> link_msg >> markdown_msg >> [ + single_action_card_msg, + multi_action_card_msg, + ] >> feed_card_msg >> msg_failure_callback diff --git a/reference/providers/dingding/hooks/__init__.py b/reference/providers/dingding/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/dingding/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/dingding/hooks/dingding.py b/reference/providers/dingding/hooks/dingding.py new file mode 100644 index 0000000..8f51df9 --- /dev/null +++ b/reference/providers/dingding/hooks/dingding.py @@ -0,0 +1,131 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from typing import List, Optional, Union + +import requests +from airflow.exceptions import AirflowException +from airflow.providers.http.hooks.http import HttpHook +from requests import Session + + +class DingdingHook(HttpHook): + """ + This hook allows you send Dingding message using Dingding custom bot. + Get Dingding token from conn_id.password. And prefer set domain to + conn_id.host, if not will use default ``https://oapi.dingtalk.com``. + + For more detail message in + `Dingding custom bot `_ + + :param dingding_conn_id: The name of the Dingding connection to use + :type dingding_conn_id: str + :param message_type: Message type you want to send to Dingding, support five type so far + including text, link, markdown, actionCard, feedCard + :type message_type: str + :param message: The message send to Dingding chat group + :type message: str or dict + :param at_mobiles: Remind specific users with this message + :type at_mobiles: list[str] + :param at_all: Remind all people in group or not. If True, will overwrite ``at_mobiles`` + :type at_all: bool + """ + + def __init__( + self, + dingding_conn_id="dingding_default", + message_type: str = "text", + message: Optional[Union[str, dict]] = None, + at_mobiles: Optional[List[str]] = None, + at_all: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__(http_conn_id=dingding_conn_id, *args, **kwargs) # type: ignore[misc] + self.message_type = message_type + self.message = message + self.at_mobiles = at_mobiles + self.at_all = at_all + + def _get_endpoint(self) -> str: + """Get Dingding endpoint for sending message.""" + conn = self.get_connection(self.http_conn_id) + token = conn.password + if not token: + raise AirflowException( + "Dingding token is requests but get nothing, check you conn_id configuration." 
+ ) + return f"robot/send?access_token={token}" + + def _build_message(self) -> str: + """ + Build different type of Dingding message + As most commonly used type, text message just need post message content + rather than a dict like ``{'content': 'message'}`` + """ + if self.message_type in ["text", "markdown"]: + data = { + "msgtype": self.message_type, + self.message_type: {"content": self.message} + if self.message_type == "text" + else self.message, + "at": {"atMobiles": self.at_mobiles, "isAtAll": self.at_all}, + } + else: + data = {"msgtype": self.message_type, self.message_type: self.message} + return json.dumps(data) + + def get_conn(self, headers: Optional[dict] = None) -> Session: + """ + Overwrite HttpHook get_conn because just need base_url and headers and + not don't need generic params + + :param headers: additional headers to be passed through as a dictionary + :type headers: dict + """ + conn = self.get_connection(self.http_conn_id) + self.base_url = conn.host if conn.host else "https://oapi.dingtalk.com" + session = requests.Session() + if headers: + session.headers.update(headers) + return session + + def send(self) -> None: + """Send Dingding message""" + support_type = ["text", "link", "markdown", "actionCard", "feedCard"] + if self.message_type not in support_type: + raise ValueError( + "DingdingWebhookHook only support {} " + "so far, but receive {}".format(support_type, self.message_type) + ) + + data = self._build_message() + self.log.info("Sending Dingding type %s message %s", self.message_type, data) + resp = self.run( + endpoint=self._get_endpoint(), + data=data, + headers={"Content-Type": "application/json"}, + ) + + # Dingding success send message will with errcode equal to 0 + if int(resp.json().get("errcode")) != 0: + raise AirflowException( + "Send Dingding message failed, receive error " f"message {resp.text}" + ) + self.log.info("Success Send Dingding message") diff --git a/reference/providers/dingding/operators/__init__.py b/reference/providers/dingding/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/dingding/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/dingding/operators/dingding.py b/reference/providers/dingding/operators/dingding.py new file mode 100644 index 0000000..ee57cca --- /dev/null +++ b/reference/providers/dingding/operators/dingding.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.dingding.hooks.dingding import DingdingHook +from airflow.utils.decorators import apply_defaults + + +class DingdingOperator(BaseOperator): + """ + This operator allows you send Dingding message using Dingding custom bot. + Get Dingding token from conn_id.password. And prefer set domain to + conn_id.host, if not will use default ``https://oapi.dingtalk.com``. + + For more detail message in + `Dingding custom bot `_ + + :param dingding_conn_id: The name of the Dingding connection to use + :type dingding_conn_id: str + :param message_type: Message type you want to send to Dingding, support five type so far + including text, link, markdown, actionCard, feedCard + :type message_type: str + :param message: The message send to Dingding chat group + :type message: str or dict + :param at_mobiles: Remind specific users with this message + :type at_mobiles: list[str] + :param at_all: Remind all people in group or not. If True, will overwrite ``at_mobiles`` + :type at_all: bool + """ + + template_fields = ("message",) + ui_color = "#4ea4d4" # Dingding icon color + + @apply_defaults + def __init__( + self, + *, + dingding_conn_id: str = "dingding_default", + message_type: str = "text", + message: Union[str, dict, None] = None, + at_mobiles: Optional[List[str]] = None, + at_all: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dingding_conn_id = dingding_conn_id + self.message_type = message_type + self.message = message + self.at_mobiles = at_mobiles + self.at_all = at_all + + def execute(self, context) -> None: + self.log.info("Sending Dingding message.") + hook = DingdingHook( + self.dingding_conn_id, + self.message_type, + self.message, + self.at_mobiles, + self.at_all, + ) + hook.send() diff --git a/reference/providers/dingding/provider.yaml b/reference/providers/dingding/provider.yaml new file mode 100644 index 0000000..3592f32 --- /dev/null +++ b/reference/providers/dingding/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
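With the Dingding hook and operator in place, a minimal sketch of sending a message directly, outside a DAG run. It assumes a "dingding_default" connection whose password holds the custom-bot token; the phone number is a placeholder. For the markdown type, _build_message above passes the dict through unchanged, while a plain text message is wrapped as {"content": ...}.

from airflow.providers.dingding.hooks.dingding import DingdingHook

# Markdown message to the group bot.
DingdingHook(
    dingding_conn_id="dingding_default",
    message_type="markdown",
    message={"title": "Nightly ETL", "text": "# Done\nAll tasks succeeded"},
    at_mobiles=["156XXXXXXXX"],
    at_all=False,
).send()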
+ +--- +package-name: apache-airflow-providers-dingding +name: Dingding +description: | + `Dingding `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Dingding + external-doc-url: https://oapi.dingtalk.com + how-to-guide: + - /docs/apache-airflow-providers-dingding/operators.rst + logo: /integration-logos/dingding/Dingding.png + tags: [service] + +operators: + - integration-name: Dingding + python-modules: + - airflow.providers.dingding.operators.dingding + +hooks: + - integration-name: Dingding + python-modules: + - airflow.providers.dingding.hooks.dingding diff --git a/reference/providers/discord/CHANGELOG.rst b/reference/providers/discord/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/discord/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/discord/__init__.py b/reference/providers/discord/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/discord/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/discord/hooks/__init__.py b/reference/providers/discord/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/discord/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/discord/hooks/discord_webhook.py b/reference/providers/discord/hooks/discord_webhook.py new file mode 100644 index 0000000..586a7bc --- /dev/null +++ b/reference/providers/discord/hooks/discord_webhook.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import json +import re +from typing import Any, Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.http.hooks.http import HttpHook + + +class DiscordWebhookHook(HttpHook): + """ + This hook allows you to post messages to Discord using incoming webhooks. + Takes a Discord connection ID with a default relative webhook endpoint. The + default endpoint can be overridden using the webhook_endpoint parameter + (https://discordapp.com/developers/docs/resources/webhook). + + Each Discord webhook can be pre-configured to use a specific username and + avatar_url. You can override these defaults in this hook. 
+ + :param http_conn_id: Http connection ID with host as "https://discord.com/api/" and + default webhook endpoint in the extra field in the form of + {"webhook_endpoint": "webhooks/{webhook.id}/{webhook.token}"} + :type http_conn_id: str + :param webhook_endpoint: Discord webhook endpoint in the form of + "webhooks/{webhook.id}/{webhook.token}" + :type webhook_endpoint: str + :param message: The message you want to send to your Discord channel + (max 2000 characters) + :type message: str + :param username: Override the default username of the webhook + :type username: str + :param avatar_url: Override the default avatar of the webhook + :type avatar_url: str + :param tts: Is a text-to-speech message + :type tts: bool + :param proxy: Proxy to use to make the Discord webhook call + :type proxy: str + """ + + def __init__( + self, + http_conn_id: Optional[str] = None, + webhook_endpoint: Optional[str] = None, + message: str = "", + username: Optional[str] = None, + avatar_url: Optional[str] = None, + tts: bool = False, + proxy: Optional[str] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.http_conn_id: Any = http_conn_id + self.webhook_endpoint = self._get_webhook_endpoint( + http_conn_id, webhook_endpoint + ) + self.message = message + self.username = username + self.avatar_url = avatar_url + self.tts = tts + self.proxy = proxy + + def _get_webhook_endpoint( + self, http_conn_id: Optional[str], webhook_endpoint: Optional[str] + ) -> str: + """ + Given a Discord http_conn_id, return the default webhook endpoint or override if a + webhook_endpoint is manually supplied. + + :param http_conn_id: The provided connection ID + :param webhook_endpoint: The manually provided webhook endpoint + :return: Webhook endpoint (str) to use + """ + if webhook_endpoint: + endpoint = webhook_endpoint + elif http_conn_id: + conn = self.get_connection(http_conn_id) + extra = conn.extra_dejson + endpoint = extra.get("webhook_endpoint", "") + else: + raise AirflowException( + "Cannot get webhook endpoint: No valid Discord webhook endpoint or http_conn_id supplied." + ) + + # make sure endpoint matches the expected Discord webhook format + if not re.match("^webhooks/[0-9]+/[a-zA-Z0-9_-]+$", endpoint): + raise AirflowException( + 'Expected Discord webhook endpoint in the form of "webhooks/{webhook.id}/{webhook.token}".' + ) + + return endpoint + + def _build_discord_payload(self) -> str: + """ + Construct the Discord JSON payload. All relevant parameters are combined here + to a valid Discord JSON payload. + + :return: Discord payload (str) to send + """ + payload: Dict[str, Any] = {} + + if self.username: + payload["username"] = self.username + if self.avatar_url: + payload["avatar_url"] = self.avatar_url + + payload["tts"] = self.tts + + if len(self.message) <= 2000: + payload["content"] = self.message + else: + raise AirflowException( + "Discord message length must be 2000 or fewer characters." 
+ ) + + return json.dumps(payload) + + def execute(self) -> None: + """Execute the Discord webhook call""" + proxies = {} + if self.proxy: + # we only need https proxy for Discord + proxies = {"https": self.proxy} + + discord_payload = self._build_discord_payload() + + self.run( + endpoint=self.webhook_endpoint, + data=discord_payload, + headers={"Content-type": "application/json"}, + extra_options={"proxies": proxies}, + ) diff --git a/reference/providers/discord/operators/__init__.py b/reference/providers/discord/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/discord/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/discord/operators/discord_webhook.py b/reference/providers/discord/operators/discord_webhook.py new file mode 100644 index 0000000..9d2ea26 --- /dev/null +++ b/reference/providers/discord/operators/discord_webhook.py @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook +from airflow.providers.http.operators.http import SimpleHttpOperator +from airflow.utils.decorators import apply_defaults + + +class DiscordWebhookOperator(SimpleHttpOperator): + """ + This operator allows you to post messages to Discord using incoming webhooks. + Takes a Discord connection ID with a default relative webhook endpoint. The + default endpoint can be overridden using the webhook_endpoint parameter + (https://discordapp.com/developers/docs/resources/webhook). + + Each Discord webhook can be pre-configured to use a specific username and + avatar_url. You can override these defaults in this operator. 
+ + :param http_conn_id: Http connection ID with host as "https://discord.com/api/" and + default webhook endpoint in the extra field in the form of + {"webhook_endpoint": "webhooks/{webhook.id}/{webhook.token}"} + :type http_conn_id: str + :param webhook_endpoint: Discord webhook endpoint in the form of + "webhooks/{webhook.id}/{webhook.token}" + :type webhook_endpoint: str + :param message: The message you want to send to your Discord channel + (max 2000 characters). (templated) + :type message: str + :param username: Override the default username of the webhook. (templated) + :type username: str + :param avatar_url: Override the default avatar of the webhook + :type avatar_url: str + :param tts: Is a text-to-speech message + :type tts: bool + :param proxy: Proxy to use to make the Discord webhook call + :type proxy: str + """ + + template_fields = ["username", "message"] + + @apply_defaults + def __init__( + self, + *, + http_conn_id: Optional[str] = None, + webhook_endpoint: Optional[str] = None, + message: str = "", + username: Optional[str] = None, + avatar_url: Optional[str] = None, + tts: bool = False, + proxy: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(endpoint=webhook_endpoint, **kwargs) + + if not http_conn_id: + raise AirflowException("No valid Discord http_conn_id supplied.") + + self.http_conn_id = http_conn_id + self.webhook_endpoint = webhook_endpoint + self.message = message + self.username = username + self.avatar_url = avatar_url + self.tts = tts + self.proxy = proxy + self.hook: Optional[DiscordWebhookHook] = None + + def execute(self, context: Dict) -> None: + """Call the DiscordWebhookHook to post message""" + self.hook = DiscordWebhookHook( + self.http_conn_id, + self.webhook_endpoint, + self.message, + self.username, + self.avatar_url, + self.tts, + self.proxy, + ) + self.hook.execute() diff --git a/reference/providers/discord/provider.yaml b/reference/providers/discord/provider.yaml new file mode 100644 index 0000000..e65831a --- /dev/null +++ b/reference/providers/discord/provider.yaml @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
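Before the provider metadata below, a usage sketch for the operator above. It assumes an HTTP connection named "discord_webhook" with host https://discord.com/api/ and a {"webhook_endpoint": "webhooks/{webhook.id}/{webhook.token}"} extra; the DAG and task names are illustrative only.

from airflow import DAG
from airflow.providers.discord.operators.discord_webhook import DiscordWebhookOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="discord_notify_sample",
    schedule_interval=None,
    start_date=days_ago(1),
    tags=["example"],
) as dag:
    notify = DiscordWebhookOperator(
        task_id="notify_channel",
        http_conn_id="discord_webhook",
        message="Run {{ ds }} finished",  # templated; capped at 2000 characters
        username="airflow-bot",
        tts=False,
    )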
+ +--- +package-name: apache-airflow-providers-discord +name: Discord +description: | + `Discord `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Discord + external-doc-url: https://discordapp.com + logo: /integration-logos/discord/Discord.png + tags: [service] + +operators: + - integration-name: Discord + python-modules: + - airflow.providers.discord.operators.discord_webhook + +hooks: + - integration-name: Discord + python-modules: + - airflow.providers.discord.hooks.discord_webhook diff --git a/reference/providers/docker/CHANGELOG.rst b/reference/providers/docker/CHANGELOG.rst new file mode 100644 index 0000000..27e8ef1 --- /dev/null +++ b/reference/providers/docker/CHANGELOG.rst @@ -0,0 +1,45 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + +1.0.1 +..... + +Updated documentation and readme files. + +Bug fixes +~~~~~~~~~ + +* ``Remove failed DockerOperator tasks with auto_remove=True (#13532) (#13993)`` +* ``Fix error on DockerSwarmOperator with auto_remove True (#13532) (#13852)`` + + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/docker/__init__.py b/reference/providers/docker/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/docker/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/docker/example_dags/__init__.py b/reference/providers/docker/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/docker/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/docker/example_dags/example_docker.py b/reference/providers/docker/example_dags/example_docker.py new file mode 100644 index 0000000..9903c84 --- /dev/null +++ b/reference/providers/docker/example_dags/example_docker.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import timedelta + +from airflow import DAG +from airflow.operators.bash import BashOperator +from airflow.providers.docker.operators.docker import DockerOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), +} + +dag = DAG( + "docker_sample", + default_args=default_args, + schedule_interval=timedelta(minutes=10), + start_date=days_ago(2), +) + +t1 = BashOperator(task_id="print_date", bash_command="date", dag=dag) + +t2 = BashOperator(task_id="sleep", bash_command="sleep 5", retries=3, dag=dag) + +t3 = DockerOperator( + api_version="1.19", + docker_url="tcp://localhost:2375", # Set your docker URL + command="/bin/sleep 30", + image="centos:latest", + network_mode="bridge", + task_id="docker_op_tester", + dag=dag, +) + + +t4 = BashOperator(task_id="print_hello", bash_command='echo "hello world!!!"', dag=dag) + + +t1 >> t2 +t1 >> t3 +t3 >> t4 diff --git a/reference/providers/docker/example_dags/example_docker_copy_data.py b/reference/providers/docker/example_dags/example_docker_copy_data.py new file mode 100644 index 0000000..46bbbda --- /dev/null +++ b/reference/providers/docker/example_dags/example_docker_copy_data.py @@ -0,0 +1,119 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring +""" +This sample "listen to directory". move the new file and print it, +using docker-containers. +The following operators are being used: DockerOperator, +BashOperator & ShortCircuitOperator. +TODO: Review the workflow, change it accordingly to + your environment & enable the code. +""" + +from datetime import timedelta + +from airflow import DAG +from airflow.operators.bash import BashOperator +from airflow.operators.python import ShortCircuitOperator +from airflow.providers.docker.operators.docker import DockerOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), +} + +dag = DAG( + "docker_sample_copy_data", + default_args=default_args, + schedule_interval=timedelta(minutes=10), + start_date=days_ago(2), +) + +locate_file_cmd = """ + sleep 10 + find {{params.source_location}} -type f -printf "%f\n" | head -1 +""" + +t_view = BashOperator( + task_id="view_file", + bash_command=locate_file_cmd, + do_xcom_push=True, + params={"source_location": "/your/input_dir/path"}, + dag=dag, +) + + +def is_data_available(*args, **kwargs): + """Return True if data exists in XCom table for view_file task, false otherwise.""" + ti = kwargs["ti"] + data = ti.xcom_pull(key=None, task_ids="view_file") + return not data == "" + + +t_is_data_available = ShortCircuitOperator( + task_id="check_if_data_available", python_callable=is_data_available, dag=dag +) + +t_move = DockerOperator( + api_version="1.19", + docker_url="tcp://localhost:2375", # replace it with swarm/docker endpoint + image="centos:latest", + network_mode="bridge", + volumes=[ + "/your/host/input_dir/path:/your/input_dir/path", + "/your/host/output_dir/path:/your/output_dir/path", + ], + command=[ + "/bin/bash", + "-c", + "/bin/sleep 30; " + "/bin/mv {{params.source_location}}/{{ ti.xcom_pull('view_file') }} {{params.target_location}};" + "/bin/echo '{{params.target_location}}/{{ ti.xcom_pull('view_file') }}';", + ], + task_id="move_data", + do_xcom_push=True, + params={ + "source_location": "/your/input_dir/path", + "target_location": "/your/output_dir/path", + }, + dag=dag, +) + +print_templated_cmd = """ + cat {{ ti.xcom_pull('move_data') }} +""" + +t_print = DockerOperator( + api_version="1.19", + docker_url="tcp://localhost:2375", + image="centos:latest", + volumes=["/your/host/output_dir/path:/your/output_dir/path"], + command=print_templated_cmd, + task_id="print", + dag=dag, +) + +t_view.set_downstream(t_is_data_available) +t_is_data_available.set_downstream(t_move) +t_move.set_downstream(t_print) diff --git a/reference/providers/docker/example_dags/example_docker_swarm.py b/reference/providers/docker/example_dags/example_docker_swarm.py new file mode 100644 index 0000000..8d90241 --- /dev/null +++ b/reference/providers/docker/example_dags/example_docker_swarm.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import timedelta + +from airflow import DAG +from airflow.providers.docker.operators.docker_swarm import DockerSwarmOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, +} + +dag = DAG( + "docker_swarm_sample", + default_args=default_args, + schedule_interval=timedelta(minutes=10), + start_date=days_ago(1), + catchup=False, +) + +with dag as dag: + t1 = DockerSwarmOperator( + api_version="auto", + docker_url="tcp://localhost:2375", # Set your docker URL + command="/bin/sleep 10", + image="centos:latest", + auto_remove=True, + task_id="sleep_with_swarm", + ) diff --git a/reference/providers/docker/hooks/__init__.py b/reference/providers/docker/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/docker/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/docker/hooks/docker.py b/reference/providers/docker/hooks/docker.py new file mode 100644 index 0000000..6cc8ba1 --- /dev/null +++ b/reference/providers/docker/hooks/docker.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
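The DockerHook defined below authenticates against a registry using an Airflow connection: host (plus optional port) is the registry URL, login and password are the credentials, and "email" and "reauth" may be supplied in the extras. A sketch of such a connection with placeholder values:

import json

from airflow.models import Connection

registry_conn = Connection(
    conn_id="docker_default",
    conn_type="docker",
    host="registry.example.com",  # shown as "Registry URL" in the UI relabeling below
    port=5000,                    # optional; appended to the host when present
    login="ci-bot",
    password="<REGISTRY_TOKEN>",
    extra=json.dumps({"email": "ci@example.com", "reauth": "no"}),
)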
+from typing import Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.utils.log.logging_mixin import LoggingMixin +from docker import APIClient +from docker.errors import APIError + + +class DockerHook(BaseHook, LoggingMixin): + """ + Interact with a Docker Daemon or Registry. + + :param docker_conn_id: ID of the Airflow connection where + credentials and extra configuration are stored + :type docker_conn_id: str + """ + + conn_name_attr = "docker_conn_id" + default_conn_name = "docker_default" + conn_type = "docker" + hook_name = "Docker" + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["schema"], + "relabeling": { + "host": "Registry URL", + "login": "Username", + }, + } + + def __init__( + self, + docker_conn_id: str = default_conn_name, + base_url: Optional[str] = None, + version: Optional[str] = None, + tls: Optional[str] = None, + ) -> None: + super().__init__() + if not base_url: + raise AirflowException("No Docker base URL provided") + if not version: + raise AirflowException("No Docker API version provided") + + conn = self.get_connection(docker_conn_id) + if not conn.host: + raise AirflowException("No Docker URL provided") + if not conn.login: + raise AirflowException("No username provided") + extra_options = conn.extra_dejson + + self.__base_url = base_url + self.__version = version + self.__tls = tls + if conn.port: + self.__registry = f"{conn.host}:{conn.port}" + else: + self.__registry = conn.host + self.__username = conn.login + self.__password = conn.password + self.__email = extra_options.get("email") + self.__reauth = extra_options.get("reauth") != "no" + + def get_conn(self) -> APIClient: + client = APIClient( + base_url=self.__base_url, version=self.__version, tls=self.__tls + ) + self.__login(client) + return client + + def __login(self, client) -> int: + self.log.debug("Logging into Docker") + try: + client.login( + username=self.__username, + password=self.__password, + registry=self.__registry, + email=self.__email, + reauth=self.__reauth, + ) + self.log.debug("Login successful") + except APIError as docker_error: + self.log.error("Docker login failed: %s", str(docker_error)) + raise AirflowException(f"Docker login failed: {docker_error}") diff --git a/reference/providers/docker/operators/__init__.py b/reference/providers/docker/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/docker/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
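A minimal usage sketch for the DockerHook shown above, outside of any operator. This is not part of the commit; it assumes an Airflow connection named "docker_default" exists with the registry host, username and password filled in, and that a Docker daemon is reachable on the local socket.

# Hedged sketch: connection id, daemon URL and API version are assumptions.
from airflow.providers.docker.hooks.docker import DockerHook

hook = DockerHook(
    docker_conn_id="docker_default",        # registry credentials stored in Airflow
    base_url="unix://var/run/docker.sock",  # assumed local daemon endpoint
    version="auto",                         # let docker-py negotiate the API version
)
client = hook.get_conn()                    # logs into the registry, returns docker.APIClient
print(client.version().get("Version"))      # quick connectivity check against the daemon

Because the hook raises AirflowException when the connection has no host or login, the sketch only works once that connection is configured; the tls argument is left at its default (no TLS).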
diff --git a/reference/providers/docker/operators/docker.py b/reference/providers/docker/operators/docker.py new file mode 100644 index 0000000..1da41c7 --- /dev/null +++ b/reference/providers/docker/operators/docker.py @@ -0,0 +1,363 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Implements Docker operator""" +import ast +from tempfile import TemporaryDirectory +from typing import Dict, Iterable, List, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.docker.hooks.docker import DockerHook +from airflow.utils.decorators import apply_defaults +from docker import APIClient, tls + + +# pylint: disable=too-many-instance-attributes +class DockerOperator(BaseOperator): + """ + Execute a command inside a docker container. + + A temporary directory is created on the host and + mounted into a container to allow storing files + that together exceed the default disk size of 10GB in a container. + The path to the mounted directory can be accessed + via the environment variable ``AIRFLOW_TMP_DIR``. + + If a login to a private registry is required prior to pulling the image, a + Docker connection needs to be configured in Airflow and the connection ID + be provided with the parameter ``docker_conn_id``. + + :param image: Docker image from which to create the container. + If image tag is omitted, "latest" will be used. + :type image: str + :param api_version: Remote API version. Set to ``auto`` to automatically + detect the server's version. + :type api_version: str + :param command: Command to be run in the container. (templated) + :type command: str or list + :param container_name: Name of the container. Optional (templated) + :type container_name: str or None + :param cpus: Number of CPUs to assign to the container. + This value gets multiplied with 1024. See + https://docs.docker.com/engine/reference/run/#cpu-share-constraint + :type cpus: float + :param docker_url: URL of the host running the docker daemon. + Default is unix://var/run/docker.sock + :type docker_url: str + :param environment: Environment variables to set in the container. (templated) + :type environment: dict + :param private_environment: Private environment variables to set in the container. + These are not templated, and hidden from the website. + :type private_environment: dict + :param force_pull: Pull the docker image on every run. Default is False. + :type force_pull: bool + :param mem_limit: Maximum amount of memory the container can use. + Either a float value, which represents the limit in bytes, + or a string like ``128m`` or ``1g``. + :type mem_limit: float or str + :param host_tmp_dir: Specify the location of the temporary directory on the host which will + be mapped to tmp_dir. 
If not provided defaults to using the standard system temp directory. + :type host_tmp_dir: str + :param network_mode: Network mode for the container. + :type network_mode: str + :param tls_ca_cert: Path to a PEM-encoded certificate authority + to secure the docker connection. + :type tls_ca_cert: str + :param tls_client_cert: Path to the PEM-encoded certificate + used to authenticate docker client. + :type tls_client_cert: str + :param tls_client_key: Path to the PEM-encoded key used to authenticate docker client. + :type tls_client_key: str + :param tls_hostname: Hostname to match against + the docker server certificate or False to disable the check. + :type tls_hostname: str or bool + :param tls_ssl_version: Version of SSL to use when communicating with docker daemon. + :type tls_ssl_version: str + :param tmp_dir: Mount point inside the container to + a temporary directory created on the host by the operator. + The path is also made available via the environment variable + ``AIRFLOW_TMP_DIR`` inside the container. + :type tmp_dir: str + :param user: Default user inside the docker container. + :type user: int or str + :param volumes: List of volumes to mount into the container, e.g. + ``['/host/path:/container/path', '/host/path2:/container/path2:ro']``. + :type volumes: list + :param working_dir: Working directory to + set on the container (equivalent to the -w switch the docker client) + :type working_dir: str + :param xcom_all: Push all the stdout or just the last line. + The default is False (last line). + :type xcom_all: bool + :param docker_conn_id: ID of the Airflow connection to use + :type docker_conn_id: str + :param dns: Docker custom DNS servers + :type dns: list[str] + :param dns_search: Docker custom DNS search domain + :type dns_search: list[str] + :param auto_remove: Auto-removal of the container on daemon side when the + container's process exits. + The default is False. + :type auto_remove: bool + :param shm_size: Size of ``/dev/shm`` in bytes. The size must be + greater than 0. If omitted uses system default. + :type shm_size: int + :param tty: Allocate pseudo-TTY to the container + This needs to be set see logs of the Docker container. + :type tty: bool + :param privileged: Give extended privileges to this container. 
+ :type privileged: bool + :param cap_add: Include container capabilities + :type cap_add: list[str] + """ + + template_fields = ("command", "environment", "container_name") + template_ext = ( + ".sh", + ".bash", + ) + + # pylint: disable=too-many-arguments,too-many-locals + @apply_defaults + def __init__( + self, + *, + image: str, + api_version: Optional[str] = None, + command: Optional[Union[str, List[str]]] = None, + container_name: Optional[str] = None, + cpus: float = 1.0, + docker_url: str = "unix://var/run/docker.sock", + environment: Optional[Dict] = None, + private_environment: Optional[Dict] = None, + force_pull: bool = False, + mem_limit: Optional[Union[float, str]] = None, + host_tmp_dir: Optional[str] = None, + network_mode: Optional[str] = None, + tls_ca_cert: Optional[str] = None, + tls_client_cert: Optional[str] = None, + tls_client_key: Optional[str] = None, + tls_hostname: Optional[Union[str, bool]] = None, + tls_ssl_version: Optional[str] = None, + tmp_dir: str = "/tmp/airflow", + user: Optional[Union[str, int]] = None, + volumes: Optional[List[str]] = None, + working_dir: Optional[str] = None, + xcom_all: bool = False, + docker_conn_id: Optional[str] = None, + dns: Optional[List[str]] = None, + dns_search: Optional[List[str]] = None, + auto_remove: bool = False, + shm_size: Optional[int] = None, + tty: bool = False, + privileged: bool = False, + cap_add: Optional[Iterable[str]] = None, + extra_hosts: Optional[Dict[str, str]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.api_version = api_version + self.auto_remove = auto_remove + self.command = command + self.container_name = container_name + self.cpus = cpus + self.dns = dns + self.dns_search = dns_search + self.docker_url = docker_url + self.environment = environment or {} + self._private_environment = private_environment or {} + self.force_pull = force_pull + self.image = image + self.mem_limit = mem_limit + self.host_tmp_dir = host_tmp_dir + self.network_mode = network_mode + self.tls_ca_cert = tls_ca_cert + self.tls_client_cert = tls_client_cert + self.tls_client_key = tls_client_key + self.tls_hostname = tls_hostname + self.tls_ssl_version = tls_ssl_version + self.tmp_dir = tmp_dir + self.user = user + self.volumes = volumes or [] + self.working_dir = working_dir + self.xcom_all = xcom_all + self.docker_conn_id = docker_conn_id + self.shm_size = shm_size + self.tty = tty + self.privileged = privileged + self.cap_add = cap_add + self.extra_hosts = extra_hosts + if kwargs.get("xcom_push") is not None: + raise AirflowException( + "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead" + ) + + self.cli = None + self.container = None + + def get_hook(self) -> DockerHook: + """ + Retrieves hook for the operator. 
+ + :return: The Docker Hook + """ + return DockerHook( + docker_conn_id=self.docker_conn_id, + base_url=self.docker_url, + version=self.api_version, + tls=self.__get_tls_config(), + ) + + def _run_image(self) -> Optional[str]: + """Run a Docker container with the provided image""" + self.log.info("Starting docker container from image %s", self.image) + + with TemporaryDirectory( + prefix="airflowtmp", dir=self.host_tmp_dir + ) as host_tmp_dir: + self.volumes.append(f"{host_tmp_dir}:{self.tmp_dir}") + + if not self.cli: + raise Exception("The 'cli' should be initialized before!") + self.container = self.cli.create_container( + command=self.get_command(), + name=self.container_name, + environment={**self.environment, **self._private_environment}, + host_config=self.cli.create_host_config( + auto_remove=False, + binds=self.volumes, + network_mode=self.network_mode, + shm_size=self.shm_size, + dns=self.dns, + dns_search=self.dns_search, + cpu_shares=int(round(self.cpus * 1024)), + mem_limit=self.mem_limit, + cap_add=self.cap_add, + extra_hosts=self.extra_hosts, + privileged=self.privileged, + ), + image=self.image, + user=self.user, + working_dir=self.working_dir, + tty=self.tty, + ) + + lines = self.cli.attach( + container=self.container["Id"], stdout=True, stderr=True, stream=True + ) + + self.cli.start(self.container["Id"]) + + line = "" + for line in lines: + line = line.strip() + if hasattr(line, "decode"): + # Note that lines returned can also be byte sequences so we have to handle decode here + line = line.decode("utf-8") + self.log.info(line) + + result = self.cli.wait(self.container["Id"]) + if result["StatusCode"] != 0: + if self.auto_remove: + self.cli.remove_container(self.container["Id"]) + raise AirflowException("docker container failed: " + repr(result)) + + # duplicated conditional logic because of expensive operation + ret = None + if self.do_xcom_push: + ret = ( + self.cli.logs(container=self.container["Id"]) + if self.xcom_all + else line.encode("utf-8") + ) + + if self.auto_remove: + self.cli.remove_container(self.container["Id"]) + + return ret + + def execute(self, context) -> Optional[str]: + self.cli = self._get_cli() + if not self.cli: + raise Exception("The 'cli' should be initialized before!") + + # Pull the docker image if `force_pull` is set or image does not exist locally + # pylint: disable=too-many-nested-blocks + if self.force_pull or not self.cli.images(name=self.image): + self.log.info("Pulling docker image %s", self.image) + latest_status = {} + for output in self.cli.pull(self.image, stream=True, decode=True): + if isinstance(output, str): + self.log.info("%s", output) + continue + if isinstance(output, dict) and "status" in output: + output_status = output["status"] + if "id" not in output: + self.log.info("%s", output_status) + continue + + output_id = output["id"] + if latest_status.get(output_id) != output_status: + self.log.info("%s: %s", output_id, output_status) + latest_status[output_id] = output_status + + self.environment["AIRFLOW_TMP_DIR"] = self.tmp_dir + return self._run_image() + + def _get_cli(self) -> APIClient: + if self.docker_conn_id: + return self.get_hook().get_conn() + else: + tls_config = self.__get_tls_config() + return APIClient( + base_url=self.docker_url, version=self.api_version, tls=tls_config + ) + + def get_command(self) -> Union[List[str], str]: + """ + Retrieve command(s). 
if command string starts with [, it returns the command list) + + :return: the command (or commands) + :rtype: str | List[str] + """ + if isinstance(self.command, str) and self.command.strip().find("[") == 0: + commands = ast.literal_eval(self.command) + else: + commands = self.command + return commands + + def on_kill(self) -> None: + if self.cli is not None: + self.log.info("Stopping docker container") + self.cli.stop(self.container["Id"]) + + def __get_tls_config(self) -> Optional[tls.TLSConfig]: + tls_config = None + if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key: + # Ignore type error on SSL version here - it is deprecated and type annotation is wrong + # it should be string + tls_config = tls.TLSConfig( + ca_cert=self.tls_ca_cert, + client_cert=(self.tls_client_cert, self.tls_client_key), + verify=True, + ssl_version=self.tls_ssl_version, # noqa + assert_hostname=self.tls_hostname, + ) + self.docker_url = self.docker_url.replace("tcp://", "https://") + return tls_config diff --git a/reference/providers/docker/operators/docker_swarm.py b/reference/providers/docker/operators/docker_swarm.py new file mode 100644 index 0000000..749828f --- /dev/null +++ b/reference/providers/docker/operators/docker_swarm.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Run ephemeral Docker Swarm services""" +from typing import Optional + +import requests +from airflow.exceptions import AirflowException +from airflow.providers.docker.operators.docker import DockerOperator +from airflow.utils.decorators import apply_defaults +from airflow.utils.strings import get_random_string +from docker import types + + +class DockerSwarmOperator(DockerOperator): + """ + Execute a command as an ephemeral docker swarm service. + Example use-case - Using Docker Swarm orchestration to make one-time + scripts highly available. + + A temporary directory is created on the host and + mounted into a container to allow storing files + that together exceed the default disk size of 10GB in a container. + The path to the mounted directory can be accessed + via the environment variable ``AIRFLOW_TMP_DIR``. + + If a login to a private registry is required prior to pulling the image, a + Docker connection needs to be configured in Airflow and the connection ID + be provided with the parameter ``docker_conn_id``. + + :param image: Docker image from which to create the container. + If image tag is omitted, "latest" will be used. + :type image: str + :param api_version: Remote API version. Set to ``auto`` to automatically + detect the server's version. + :type api_version: str + :param auto_remove: Auto-removal of the container on daemon side when the + container's process exits. + The default is False. 
+ :type auto_remove: bool + :param command: Command to be run in the container. (templated) + :type command: str or list + :param docker_url: URL of the host running the docker daemon. + Default is unix://var/run/docker.sock + :type docker_url: str + :param environment: Environment variables to set in the container. (templated) + :type environment: dict + :param force_pull: Pull the docker image on every run. Default is False. + :type force_pull: bool + :param mem_limit: Maximum amount of memory the container can use. + Either a float value, which represents the limit in bytes, + or a string like ``128m`` or ``1g``. + :type mem_limit: float or str + :param tls_ca_cert: Path to a PEM-encoded certificate authority + to secure the docker connection. + :type tls_ca_cert: str + :param tls_client_cert: Path to the PEM-encoded certificate + used to authenticate docker client. + :type tls_client_cert: str + :param tls_client_key: Path to the PEM-encoded key used to authenticate docker client. + :type tls_client_key: str + :param tls_hostname: Hostname to match against + the docker server certificate or False to disable the check. + :type tls_hostname: str or bool + :param tls_ssl_version: Version of SSL to use when communicating with docker daemon. + :type tls_ssl_version: str + :param tmp_dir: Mount point inside the container to + a temporary directory created on the host by the operator. + The path is also made available via the environment variable + ``AIRFLOW_TMP_DIR`` inside the container. + :type tmp_dir: str + :param user: Default user inside the docker container. + :type user: int or str + :param docker_conn_id: ID of the Airflow connection to use + :type docker_conn_id: str + :param tty: Allocate pseudo-TTY to the container of this service + This needs to be set see logs of the Docker container / service. + :type tty: bool + :param enable_logging: Show the application's logs in operator's logs. + Supported only if the Docker engine is using json-file or journald logging drivers. + The `tty` parameter should be set to use this with Python applications. 
+ :type enable_logging: bool + """ + + @apply_defaults + def __init__(self, *, image: str, enable_logging: bool = True, **kwargs) -> None: + super().__init__(image=image, **kwargs) + + self.enable_logging = enable_logging + self.service = None + + def execute(self, context) -> None: + self.cli = self._get_cli() + + self.environment["AIRFLOW_TMP_DIR"] = self.tmp_dir + + return self._run_service() + + def _run_service(self) -> None: + self.log.info("Starting docker service from image %s", self.image) + if not self.cli: + raise Exception("The 'cli' should be initialized before!") + self.service = self.cli.create_service( + types.TaskTemplate( + container_spec=types.ContainerSpec( + image=self.image, + command=self.get_command(), + env=self.environment, + user=self.user, + tty=self.tty, + ), + restart_policy=types.RestartPolicy(condition="none"), + resources=types.Resources(mem_limit=self.mem_limit), + ), + name=f"airflow-{get_random_string()}", + labels={"name": f"airflow__{self.dag_id}__{self.task_id}"}, + ) + + self.log.info("Service started: %s", str(self.service)) + + # wait for the service to start the task + while not self.cli.tasks(filters={"service": self.service["ID"]}): + continue + + if self.enable_logging: + self._stream_logs_to_output() + + while True: + if self._has_service_terminated(): + self.log.info( + "Service status before exiting: %s", self._service_status() + ) + break + + if self.service and self._service_status() == "failed": + if self.auto_remove: + self.cli.remove_service(self.service["ID"]) + raise AirflowException("Service failed: " + repr(self.service)) + elif self.auto_remove: + if not self.service: + raise Exception("The 'service' should be initialized before!") + self.cli.remove_service(self.service["ID"]) + + def _service_status(self) -> Optional[str]: + if not self.cli: + raise Exception("The 'cli' should be initialized before!") + return self.cli.tasks(filters={"service": self.service["ID"]})[0]["Status"][ + "State" + ] + + def _has_service_terminated(self) -> bool: + status = self._service_status() + return status in ["failed", "complete"] + + def _stream_logs_to_output(self) -> None: + if not self.cli: + raise Exception("The 'cli' should be initialized before!") + if not self.service: + raise Exception("The 'service' should be initialized before!") + logs = self.cli.service_logs( + self.service["ID"], follow=True, stdout=True, stderr=True, is_tty=self.tty + ) + line = "" + while True: + try: + log = next(logs) + # TODO: Remove this clause once https://github.com/docker/docker-py/issues/931 is fixed + except requests.exceptions.ConnectionError: + # If the service log stream stopped sending messages, check if it the service has + # terminated. + if self._has_service_terminated(): + break + except StopIteration: + # If the service log stream terminated, stop fetching logs further. 
+ break + else: + try: + log = log.decode() + except UnicodeDecodeError: + continue + if log == "\n": + self.log.info(line) + line = "" + else: + line += log + # flush any remaining log stream + if line: + self.log.info(line) + + def on_kill(self) -> None: + if self.cli is not None: + self.log.info("Removing docker service: %s", self.service["ID"]) + self.cli.remove_service(self.service["ID"]) diff --git a/reference/providers/docker/provider.yaml b/reference/providers/docker/provider.yaml new file mode 100644 index 0000000..f2fabab --- /dev/null +++ b/reference/providers/docker/provider.yaml @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-docker +name: Docker +description: | + `Docker `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Docker + external-doc-url: https://docs.docker.com/ + logo: /integration-logos/docker/Docker.png + tags: [software] + - integration-name: Docker Swarm + external-doc-url: https://docs.docker.com/engine/swarm/ + logo: /integration-logos/docker/Docker-Swarm.png + tags: [software] + +operators: + - integration-name: Docker + python-modules: + - airflow.providers.docker.operators.docker + - integration-name: Docker Swarm + python-modules: + - airflow.providers.docker.operators.docker_swarm + +hooks: + - integration-name: Docker + python-modules: + - airflow.providers.docker.hooks.docker + +hook-class-names: + - airflow.providers.docker.hooks.docker.DockerHook diff --git a/reference/providers/elasticsearch/CHANGELOG.rst b/reference/providers/elasticsearch/CHANGELOG.rst new file mode 100644 index 0000000..7e8e029 --- /dev/null +++ b/reference/providers/elasticsearch/CHANGELOG.rst @@ -0,0 +1,52 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.3 +..... + +Bug fixes +~~~~~~~~~ + +* ``Elasticsearch Provider: Fix logs downloading for tasks (#14686)`` + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + +1.0.1 +..... 
+ +Updated documentation and readme files. + +Bug fixes +~~~~~~~~~ + +* ``Respect LogFormat when using ES logging with Json Format (#13310)`` + + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/elasticsearch/__init__.py b/reference/providers/elasticsearch/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/elasticsearch/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/elasticsearch/hooks/__init__.py b/reference/providers/elasticsearch/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/elasticsearch/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/elasticsearch/hooks/elasticsearch.py b/reference/providers/elasticsearch/hooks/elasticsearch.py new file mode 100644 index 0000000..402cbba --- /dev/null +++ b/reference/providers/elasticsearch/hooks/elasticsearch.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Optional + +from airflow.hooks.dbapi import DbApiHook +from airflow.models.connection import Connection as AirflowConnection +from es.elastic.api import Connection as ESConnection +from es.elastic.api import connect + + +class ElasticsearchHook(DbApiHook): + """Interact with Elasticsearch through the elasticsearch-dbapi.""" + + conn_name_attr = "elasticsearch_conn_id" + default_conn_name = "elasticsearch_default" + conn_type = "elasticsearch" + hook_name = "Elasticsearch" + + def __init__( + self, + schema: str = "http", + connection: Optional[AirflowConnection] = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.schema = schema + self.connection = connection + + def get_conn(self) -> ESConnection: + """Returns a elasticsearch connection object""" + conn_id = getattr(self, self.conn_name_attr) + conn = self.connection or self.get_connection(conn_id) + + conn_args = dict( + host=conn.host, + port=conn.port, + user=conn.login or None, + password=conn.password or None, + scheme=conn.schema or "http", + ) + + if conn.extra_dejson.get("http_compress", False): + conn_args["http_compress"] = bool(["http_compress"]) + + if conn.extra_dejson.get("timeout", False): + conn_args["timeout"] = conn.extra_dejson["timeout"] + + conn = connect(**conn_args) + + return conn + + def get_uri(self) -> str: + conn_id = getattr(self, self.conn_name_attr) + conn = self.connection or self.get_connection(conn_id) + + login = "" + if conn.login: + login = "{conn.login}:{conn.password}@".format(conn=conn) + host = conn.host + if conn.port is not None: + host += f":{conn.port}" + uri = "{conn.conn_type}+{conn.schema}://{login}{host}/".format( + conn=conn, login=login, host=host + ) + + extras_length = len(conn.extra_dejson) + if not extras_length: + return uri + + uri += "?" + + for arg_key, arg_value in conn.extra_dejson.items(): + extras_length -= 1 + uri += f"{arg_key}={arg_value}" + + if extras_length: + uri += "&" + + return uri diff --git a/reference/providers/elasticsearch/log/__init__.py b/reference/providers/elasticsearch/log/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/elasticsearch/log/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/elasticsearch/log/es_task_handler.py b/reference/providers/elasticsearch/log/es_task_handler.py new file mode 100644 index 0000000..044d4c8 --- /dev/null +++ b/reference/providers/elasticsearch/log/es_task_handler.py @@ -0,0 +1,383 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +import sys +from collections import defaultdict +from datetime import datetime +from time import time +from typing import List, Optional, Tuple +from urllib.parse import quote + +# Using `from elasticsearch import *` would break elasticsearch mocking used in unit test. +import elasticsearch +import pendulum +from airflow.configuration import conf +from airflow.models import TaskInstance +from airflow.utils import timezone +from airflow.utils.helpers import parse_template_string +from airflow.utils.log.file_task_handler import FileTaskHandler +from airflow.utils.log.json_formatter import JSONFormatter +from airflow.utils.log.logging_mixin import LoggingMixin +from elasticsearch_dsl import Search + +# Elasticsearch hosted log type +EsLogMsgType = List[Tuple[str, str]] + + +class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin): + """ + ElasticsearchTaskHandler is a python log handler that + reads logs from Elasticsearch. Note logs are not directly + indexed into Elasticsearch. Instead, it flushes logs + into local files. Additional software setup is required + to index the log into Elasticsearch, such as using + Filebeat and Logstash. + To efficiently query and sort Elasticsearch results, we assume each + log message has a field `log_id` consists of ti primary keys: + `log_id = {dag_id}-{task_id}-{execution_date}-{try_number}` + Log messages with specific log_id are sorted based on `offset`, + which is a unique integer indicates log message's order. + Timestamp here are unreliable because multiple log messages + might have the same timestamp. 
+ """ + + PAGE = 0 + MAX_LINE_PER_PAGE = 1000 + LOG_NAME = "Elasticsearch" + + def __init__( # pylint: disable=too-many-arguments + self, + base_log_folder: str, + filename_template: str, + log_id_template: str, + end_of_log_mark: str, + write_stdout: bool, + json_format: bool, + json_fields: str, + host: str = "localhost:9200", + frontend: str = "localhost:5601", + es_kwargs: Optional[dict] = conf.getsection("elasticsearch_configs"), + ): + """ + :param base_log_folder: base folder to store logs locally + :param log_id_template: log id template + :param host: Elasticsearch host name + """ + es_kwargs = es_kwargs or {} + super().__init__(base_log_folder, filename_template) + self.closed = False + + self.log_id_template, self.log_id_jinja_template = parse_template_string( + log_id_template + ) + + self.client = elasticsearch.Elasticsearch([host], **es_kwargs) + + self.frontend = frontend + self.mark_end_on_close = True + self.end_of_log_mark = end_of_log_mark + self.write_stdout = write_stdout + self.json_format = json_format + self.json_fields = [label.strip() for label in json_fields.split(",")] + self.handler = None + self.context_set = False + + def _render_log_id(self, ti: TaskInstance, try_number: int) -> str: + if self.log_id_jinja_template: + jinja_context = ti.get_template_context() + jinja_context["try_number"] = try_number + return self.log_id_jinja_template.render(**jinja_context) + + if self.json_format: + execution_date = self._clean_execution_date(ti.execution_date) + else: + execution_date = ti.execution_date.isoformat() + return self.log_id_template.format( + dag_id=ti.dag_id, + task_id=ti.task_id, + execution_date=execution_date, + try_number=try_number, + ) + + @staticmethod + def _clean_execution_date(execution_date: datetime) -> str: + """ + Clean up an execution date so that it is safe to query in elasticsearch + by removing reserved characters. + # https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-query-string-query.html#_reserved_characters + + :param execution_date: execution date of the dag run. + """ + return execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f") + + @staticmethod + def _group_logs_by_host(logs): + grouped_logs = defaultdict(list) + for log in logs: + key = getattr(log, "host", "default_host") + grouped_logs[key].append(log) + + # return items sorted by timestamp. + result = sorted( + grouped_logs.items(), key=lambda kv: getattr(kv[1][0], "message", "_") + ) + + return result + + def _read_grouped_logs(self): + return True + + def _read( + self, ti: TaskInstance, try_number: int, metadata: Optional[dict] = None + ) -> Tuple[EsLogMsgType, dict]: + """ + Endpoint for streaming log. + + :param ti: task instance object + :param try_number: try_number of the task instance + :param metadata: log metadata, + can be used for steaming log reading and auto-tailing. + :return: a list of tuple with host and log documents, metadata. + """ + if not metadata: + metadata = {"offset": 0} + if "offset" not in metadata: + metadata["offset"] = 0 + + offset = metadata["offset"] + log_id = self._render_log_id(ti, try_number) + + logs = self.es_read(log_id, offset, metadata) + logs_by_host = self._group_logs_by_host(logs) + + next_offset = offset if not logs else logs[-1].offset + + # Ensure a string here. Large offset numbers will get JSON.parsed incorrectly + # on the client. Sending as a string prevents this issue. 
+ # https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER + metadata["offset"] = str(next_offset) + + # end_of_log_mark may contain characters like '\n' which is needed to + # have the log uploaded but will not be stored in elasticsearch. + loading_hosts = [ + item[0] + for item in logs_by_host + if item[-1][-1].message != self.end_of_log_mark.strip() + ] + metadata["end_of_log"] = False if not logs else len(loading_hosts) == 0 + + cur_ts = pendulum.now() + # Assume end of log after not receiving new log for 5 min, + # as executor heartbeat is 1 min and there might be some + # delay before Elasticsearch makes the log available. + if "last_log_timestamp" in metadata: + last_log_ts = timezone.parse(metadata["last_log_timestamp"]) + if ( + cur_ts.diff(last_log_ts).in_minutes() >= 5 + or "max_offset" in metadata + and int(offset) >= int(metadata["max_offset"]) + ): + metadata["end_of_log"] = True + + if int(offset) != int(next_offset) or "last_log_timestamp" not in metadata: + metadata["last_log_timestamp"] = str(cur_ts) + + # If we hit the end of the log, remove the actual end_of_log message + # to prevent it from showing in the UI. + def concat_logs(lines): + log_range = ( + (len(lines) - 1) + if lines[-1].message == self.end_of_log_mark.strip() + else len(lines) + ) + return "\n".join([self._format_msg(lines[i]) for i in range(log_range)]) + + message = [(host, concat_logs(hosted_log)) for host, hosted_log in logs_by_host] + + return message, metadata + + def _format_msg(self, log_line): + """Format ES Record to match settings.LOG_FORMAT when used with json_format""" + # Using formatter._style.format makes it future proof i.e. + # if we change the formatter style from '%' to '{' or '$', this will still work + if self.json_format: + try: + # pylint: disable=protected-access + return self.formatter._style.format(_ESJsonLogFmt(**log_line.to_dict())) + except Exception: # noqa pylint: disable=broad-except + pass + + # Just a safe-guard to preserve backwards-compatibility + return log_line.message + + def es_read(self, log_id: str, offset: str, metadata: dict) -> list: + """ + Returns the logs matching log_id in Elasticsearch and next offset. + Returns '' if no log is found or there was an error. + + :param log_id: the log_id of the log to read. + :type log_id: str + :param offset: the offset start to read log from. + :type offset: str + :param metadata: log metadata, used for steaming log download. + :type metadata: dict + """ + # Offset is the unique key for sorting logs given log_id. 
+ search = ( + Search(using=self.client) + .query("match_phrase", log_id=log_id) + .sort("offset") + ) + + search = search.filter("range", offset={"gt": int(offset)}) + max_log_line = search.count() + if ( + "download_logs" in metadata + and metadata["download_logs"] + and "max_offset" not in metadata + ): + try: + if max_log_line > 0: + metadata["max_offset"] = ( + search[max_log_line - 1].execute()[-1].offset + ) + else: + metadata["max_offset"] = 0 + except Exception: # pylint: disable=broad-except + self.log.exception( + "Could not get current log size with log_id: %s", log_id + ) + + logs = [] + if max_log_line != 0: + try: + + logs = search[ + self.MAX_LINE_PER_PAGE * self.PAGE : self.MAX_LINE_PER_PAGE + ].execute() + except Exception as e: # pylint: disable=broad-except + self.log.exception( + "Could not read log with log_id: %s, error: %s", log_id, str(e) + ) + + return logs + + def set_context(self, ti: TaskInstance) -> None: + """ + Provide task_instance context to airflow task handler. + + :param ti: task instance object + """ + self.mark_end_on_close = not ti.raw + + if self.json_format: + self.formatter = JSONFormatter( + fmt=self.formatter._fmt, # pylint: disable=protected-access + json_fields=self.json_fields, + extras={ + "dag_id": str(ti.dag_id), + "task_id": str(ti.task_id), + "execution_date": self._clean_execution_date(ti.execution_date), + "try_number": str(ti.try_number), + "log_id": self._render_log_id(ti, ti.try_number), + "offset": int(time() * (10 ** 9)), + }, + ) + + if self.write_stdout: + if self.context_set: + # We don't want to re-set up the handler if this logger has + # already been initialized + return + + self.handler = logging.StreamHandler(stream=sys.__stdout__) # type: ignore + self.handler.setLevel(self.level) # type: ignore + self.handler.setFormatter(self.formatter) # type: ignore + else: + super().set_context(ti) + self.context_set = True + + def close(self) -> None: + # When application exit, system shuts down all handlers by + # calling close method. Here we check if logger is already + # closed to prevent uploading the log to remote storage multiple + # times when `logging.shutdown` is called. + if self.closed: + return + + if not self.mark_end_on_close: + self.closed = True + return + + # Case which context of the handler was not set. + if self.handler is None: + self.closed = True + return + + # Reopen the file stream, because FileHandler.close() would be called + # first in logging.shutdown() and the stream in it would be set to None. + if self.handler.stream is None or self.handler.stream.closed: + self.handler.stream = ( + self.handler._open() + ) # pylint: disable=protected-access + + # Mark the end of file using end of log mark, + # so we know where to stop while auto-tailing. + self.handler.stream.write(self.end_of_log_mark) + + if self.write_stdout: + self.handler.close() + sys.stdout = sys.__stdout__ + + super().close() + + self.closed = True + + @property + def log_name(self) -> str: + """The log name""" + return self.LOG_NAME + + def get_external_log_url(self, task_instance: TaskInstance, try_number: int) -> str: + """ + Creates an address for an external log collecting service. + + :param task_instance: task instance object + :type: task_instance: TaskInstance + :param try_number: task instance try_number to read logs from. 
+ :type try_number: Optional[int] + :return: URL to the external log collection service + :rtype: str + """ + log_id = self.log_id_template.format( + dag_id=task_instance.dag_id, + task_id=task_instance.task_id, + execution_date=task_instance.execution_date, + try_number=try_number, + ) + url = "https://" + self.frontend.format(log_id=quote(log_id)) + return url + + +class _ESJsonLogFmt: + """Helper class to read ES Logs and re-format it to match settings.LOG_FORMAT""" + + # A separate class is needed because 'self.formatter._style.format' uses '.__dict__' + def __init__(self, **kwargs): + self.__dict__.update(kwargs) diff --git a/reference/providers/elasticsearch/provider.yaml b/reference/providers/elasticsearch/provider.yaml new file mode 100644 index 0000000..dcef7f0 --- /dev/null +++ b/reference/providers/elasticsearch/provider.yaml @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-elasticsearch +name: Elasticsearch +description: | + `Elasticsearch `__ + +versions: + - 1.0.3 + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Elasticsearch + external-doc-url: https://www.elastic.co/elasticsearch + logo: /integration-logos/elasticsearch/Elasticsearch.png + tags: [software] + +hooks: + - integration-name: Elasticsearch + python-modules: + - airflow.providers.elasticsearch.hooks.elasticsearch + +hook-class-names: + - airflow.providers.elasticsearch.hooks.elasticsearch.ElasticsearchHook diff --git a/reference/providers/exasol/CHANGELOG.rst b/reference/providers/exasol/CHANGELOG.rst new file mode 100644 index 0000000..02e855b --- /dev/null +++ b/reference/providers/exasol/CHANGELOG.rst @@ -0,0 +1,43 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.1.1 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + +1.1.0 +..... + +Updated documentation and readme files. + +Features +~~~~~~~~ + +* ``Add ExasolToS3Operator (#13847)`` + +1.0.0 +..... 
+ +Initial version of the provider. diff --git a/reference/providers/exasol/__init__.py b/reference/providers/exasol/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/exasol/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/exasol/hooks/__init__.py b/reference/providers/exasol/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/exasol/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/exasol/hooks/exasol.py b/reference/providers/exasol/hooks/exasol.py new file mode 100644 index 0000000..3d808a7 --- /dev/null +++ b/reference/providers/exasol/hooks/exasol.py @@ -0,0 +1,227 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from contextlib import closing +from typing import Any, Dict, List, Optional, Tuple, Union + +import pyexasol +from airflow.hooks.dbapi import DbApiHook +from pyexasol import ExaConnection + + +class ExasolHook(DbApiHook): + """ + Interact with Exasol. 
+ You can specify the pyexasol ``compression``, ``encryption``, ``json_lib`` + and ``client_name`` parameters in the extra field of your connection + as ``{"compression": True, "json_lib": "rapidjson", etc}``. + See `pyexasol reference + `_ + for more details. + """ + + conn_name_attr = "exasol_conn_id" + default_conn_name = "exasol_default" + conn_type = "exasol" + hook_name = "Exasol" + supports_autocommit = True + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.schema = kwargs.pop("schema", None) + + def get_conn(self) -> ExaConnection: + conn_id = getattr(self, self.conn_name_attr) + conn = self.get_connection(conn_id) + conn_args = dict( + dsn=f"{conn.host}:{conn.port}", + user=conn.login, + password=conn.password, + schema=self.schema or conn.schema, + ) + # check for parameters in conn.extra + for arg_name, arg_val in conn.extra_dejson.items(): + if arg_name in ["compression", "encryption", "json_lib", "client_name"]: + conn_args[arg_name] = arg_val + + conn = pyexasol.connect(**conn_args) + return conn + + def get_pandas_df( + self, sql: Union[str, list], parameters: Optional[dict] = None, **kwargs + ) -> None: + """ + Executes the sql and returns a pandas dataframe + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :type sql: str or list + :param parameters: The parameters to render the SQL query with. + :type parameters: dict or iterable + :param kwargs: (optional) passed into pyexasol.ExaConnection.export_to_pandas method + :type kwargs: dict + """ + with closing(self.get_conn()) as conn: + conn.export_to_pandas(sql, query_params=parameters, **kwargs) + + def get_records( + self, sql: Union[str, list], parameters: Optional[dict] = None + ) -> List[Union[dict, Tuple[Any, ...]]]: + """ + Executes the sql and returns a set of records. + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :type sql: str or list + :param parameters: The parameters to render the SQL query with. + :type parameters: dict or iterable + """ + with closing(self.get_conn()) as conn: + with closing(conn.execute(sql, parameters)) as cur: + return cur.fetchall() + + def get_first( + self, sql: Union[str, list], parameters: Optional[dict] = None + ) -> Optional[Any]: + """ + Executes the sql and returns the first resulting row. + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :type sql: str or list + :param parameters: The parameters to render the SQL query with. + :type parameters: dict or iterable + """ + with closing(self.get_conn()) as conn: + with closing(conn.execute(sql, parameters)) as cur: + return cur.fetchone() + + def export_to_file( + self, + filename: str, + query_or_table: str, + query_params: Optional[Dict] = None, + export_params: Optional[Dict] = None, + ) -> None: + """ + Exports data to a file. + + :param filename: Path to the file to which the data has to be exported + :type filename: str + :param query_or_table: the sql statement to be executed or table name to export + :type query_or_table: str + :param query_params: Query parameters passed to underlying ``export_to_file`` + method of :class:`~pyexasol.connection.ExaConnection`. + :type query_params: dict + :param export_params: Extra parameters passed to underlying ``export_to_file`` + method of :class:`~pyexasol.connection.ExaConnection`. 
+ :type export_params: dict + """ + self.log.info("Getting data from exasol") + with closing(self.get_conn()) as conn: + conn.export_to_file( + dst=filename, + query_or_table=query_or_table, + query_params=query_params, + export_params=export_params, + ) + self.log.info("Data saved to %s", filename) + + def run( + self, + sql: Union[str, list], + autocommit: bool = False, + parameters: Optional[dict] = None, + ) -> None: + """ + Runs a command or a list of commands. Pass a list of sql + statements to the sql parameter to get them to execute + sequentially + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :type sql: str or list + :param autocommit: What to set the connection's autocommit setting to + before executing the query. + :type autocommit: bool + :param parameters: The parameters to render the SQL query with. + :type parameters: dict or iterable + """ + if isinstance(sql, str): + sql = [sql] + + with closing(self.get_conn()) as conn: + if self.supports_autocommit: + self.set_autocommit(conn, autocommit) + + for query in sql: + self.log.info(query) + with closing(conn.execute(query, parameters)) as cur: + self.log.info(cur.row_count) + # If autocommit was set to False for db that supports autocommit, + # or if db does not supports autocommit, we do a manual commit. + if not self.get_autocommit(conn): + conn.commit() + + def set_autocommit(self, conn, autocommit: bool) -> None: + """ + Sets the autocommit flag on the connection + + :param conn: Connection to set autocommit setting to. + :type conn: connection object + :param autocommit: The autocommit setting to set. + :type autocommit: bool + """ + if not self.supports_autocommit and autocommit: + self.log.warning( + "%s connection doesn't support autocommit but autocommit activated.", + getattr(self, self.conn_name_attr), + ) + conn.set_autocommit(autocommit) + + def get_autocommit(self, conn) -> bool: + """ + Get autocommit setting for the provided connection. + Return True if autocommit is set. + Return False if autocommit is not set or set to False or conn + does not support autocommit. + + :param conn: Connection to get autocommit setting from. + :type conn: connection object + :return: connection autocommit setting. + :rtype: bool + """ + autocommit = conn.attr.get("autocommit") + if autocommit is None: + autocommit = super().get_autocommit(conn) + return autocommit + + @staticmethod + def _serialize_cell(cell, conn=None) -> object: + """ + Exasol will adapt all arguments to the execute() method internally, + hence we return cell without any conversion. + + :param cell: The cell to insert into the table + :type cell: object + :param conn: The database connection + :type conn: connection object + :return: The cell + :rtype: object + """ + return cell diff --git a/reference/providers/exasol/operators/__init__.py b/reference/providers/exasol/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/exasol/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/exasol/operators/exasol.py b/reference/providers/exasol/operators/exasol.py new file mode 100644 index 0000000..df2118f --- /dev/null +++ b/reference/providers/exasol/operators/exasol.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.exasol.hooks.exasol import ExasolHook +from airflow.utils.decorators import apply_defaults + + +class ExasolOperator(BaseOperator): + """ + Executes sql code in a specific Exasol database + + :param sql: the sql code to be executed. (templated) + :type sql: Can receive a str representing a sql statement, + a list of str (sql statements), or reference to a template file. + Template reference are recognized by str ending in '.sql' + :param exasol_conn_id: reference to a specific Exasol database + :type exasol_conn_id: string + :param autocommit: if True, each command is automatically committed. + (default value: False) + :type autocommit: bool + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: dict + :param schema: (optional) name of the schema which overwrite defined one in connection + :type schema: string + """ + + template_fields = ("sql",) + template_ext = (".sql",) + ui_color = "#ededed" + + @apply_defaults + def __init__( + self, + *, + sql: str, + exasol_conn_id: str = "exasol_default", + autocommit: bool = False, + parameters: Optional[dict] = None, + schema: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.exasol_conn_id = exasol_conn_id + self.sql = sql + self.autocommit = autocommit + self.parameters = parameters + self.schema = schema + + def execute(self, context) -> None: + self.log.info("Executing: %s", self.sql) + hook = ExasolHook(exasol_conn_id=self.exasol_conn_id, schema=self.schema) + hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) diff --git a/reference/providers/exasol/provider.yaml b/reference/providers/exasol/provider.yaml new file mode 100644 index 0000000..672c0a8 --- /dev/null +++ b/reference/providers/exasol/provider.yaml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-exasol +name: Exasol +description: | + `Exasol `__ + +versions: + - 1.1.1 + - 1.1.0 + - 1.0.0 + +integrations: + - integration-name: Exasol + external-doc-url: https://docs.exasol.com/home.htm + logo: /integration-logos/exasol/Exasol.png + tags: [software] + +operators: + - integration-name: Exasol + python-modules: + - airflow.providers.exasol.operators.exasol + +hooks: + - integration-name: Exasol + python-modules: + - airflow.providers.exasol.hooks.exasol + +hook-class-names: + - airflow.providers.exasol.hooks.exasol.ExasolHook diff --git a/reference/providers/facebook/CHANGELOG.rst b/reference/providers/facebook/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/facebook/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/facebook/__init__.py b/reference/providers/facebook/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/facebook/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
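The ``provider.yaml`` above registers the Exasol hook and operator. As a minimal, illustrative sketch of wiring the ``ExasolOperator`` into a DAG (the DAG id, task id and SQL statement below are placeholders, and an ``exasol_default`` connection is assumed to exist):

.. code-block:: python

    from airflow import models
    from airflow.providers.exasol.operators.exasol import ExasolOperator
    from airflow.utils import dates

    with models.DAG(
        "example_exasol",  # placeholder DAG id
        schedule_interval=None,
        start_date=dates.days_ago(1),
    ) as dag:
        # ExasolOperator defaults to the "exasol_default" connection; it also accepts
        # a list of statements or a templated path ending in ".sql".
        create_table = ExasolOperator(
            task_id="create_table",
            sql="CREATE TABLE IF NOT EXISTS demo.events (id INT, payload VARCHAR(2000))",
            autocommit=True,
        )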
diff --git a/reference/providers/facebook/ads/__init__.py b/reference/providers/facebook/ads/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/facebook/ads/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/facebook/ads/hooks/__init__.py b/reference/providers/facebook/ads/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/facebook/ads/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/facebook/ads/hooks/ads.py b/reference/providers/facebook/ads/hooks/ads.py new file mode 100644 index 0000000..b0b30d6 --- /dev/null +++ b/reference/providers/facebook/ads/hooks/ads.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Facebook Ads Reporting hooks""" +import time +from enum import Enum +from typing import Any, Dict, List + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from facebook_business.adobjects.adaccount import AdAccount +from facebook_business.adobjects.adreportrun import AdReportRun +from facebook_business.adobjects.adsinsights import AdsInsights +from facebook_business.api import FacebookAdsApi + + +class JobStatus(Enum): + """Available options for facebook async task status""" + + COMPLETED = "Job Completed" + STARTED = "Job Started" + RUNNING = "Job Running" + FAILED = "Job Failed" + SKIPPED = "Job Skipped" + + +class FacebookAdsReportingHook(BaseHook): + """ + Hook for the Facebook Ads API + + .. seealso:: + For more information on the Facebook Ads API, take a look at the API docs: + https://developers.facebook.com/docs/marketing-apis/ + + :param facebook_conn_id: Airflow Facebook Ads connection ID + :type facebook_conn_id: str + :param api_version: The version of Facebook API. Default to v6.0 + :type api_version: str + + """ + + conn_name_attr = "facebook_conn_id" + default_conn_name = "facebook_default" + conn_type = "facebook_social" + hook_name = "Facebook Ads" + + def __init__( + self, + facebook_conn_id: str = default_conn_name, + api_version: str = "v6.0", + ) -> None: + super().__init__() + self.facebook_conn_id = facebook_conn_id + self.api_version = api_version + self.client_required_fields = [ + "app_id", + "app_secret", + "access_token", + "account_id", + ] + + def _get_service(self) -> FacebookAdsApi: + """Returns Facebook Ads Client using a service account""" + config = self.facebook_ads_config + return FacebookAdsApi.init( + app_id=config["app_id"], + app_secret=config["app_secret"], + access_token=config["access_token"], + account_id=config["account_id"], + api_version=self.api_version, + ) + + @cached_property + def facebook_ads_config(self) -> Dict: + """ + Gets Facebook ads connection from meta db and sets + facebook_ads_config attribute with returned config file + """ + self.log.info("Fetching fb connection: %s", self.facebook_conn_id) + conn = self.get_connection(self.facebook_conn_id) + config = conn.extra_dejson + missing_keys = self.client_required_fields - config.keys() + if missing_keys: + message = f"{missing_keys} fields are missing" + raise AirflowException(message) + return config + + def bulk_facebook_report( + self, + params: Dict[str, Any], + fields: List[str], + sleep_time: int = 5, + ) -> List[AdsInsights]: + """ + Pulls data from the Facebook Ads API + + :param fields: List of fields that is obtained from Facebook. Found in AdsInsights.Field class. 
+ https://developers.facebook.com/docs/marketing-api/insights/parameters/v6.0 + :type fields: List[str] + :param params: Parameters that determine the query for Facebook + https://developers.facebook.com/docs/marketing-api/insights/parameters/v6.0 + :type fields: Dict[str, Any] + :param sleep_time: Time to sleep when async call is happening + :type sleep_time: int + + :return: Facebook Ads API response, converted to Facebook Ads Row objects + :rtype: List[AdsInsights] + """ + api = self._get_service() + ad_account = AdAccount(api.get_default_account_id(), api=api) + _async = ad_account.get_insights(params=params, fields=fields, is_async=True) + while True: + request = _async.api_get() + async_status = request[AdReportRun.Field.async_status] + percent = request[AdReportRun.Field.async_percent_completion] + self.log.info( + "%s %s completed, async_status: %s", percent, "%", async_status + ) + if async_status == JobStatus.COMPLETED.value: + self.log.info("Job run completed") + break + if async_status in [JobStatus.SKIPPED.value, JobStatus.FAILED.value]: + message = f"{async_status}. Please retry." + raise AirflowException(message) + time.sleep(sleep_time) + report_run_id = _async.api_get()["report_run_id"] + report_object = AdReportRun(report_run_id, api=api) + insights = report_object.get_insights() + self.log.info("Extracting data from returned Facebook Ads Iterators") + return list(insights) diff --git a/reference/providers/facebook/provider.yaml b/reference/providers/facebook/provider.yaml new file mode 100644 index 0000000..afd197e --- /dev/null +++ b/reference/providers/facebook/provider.yaml @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-facebook +name: Facebook +description: | + `Facebook Ads `__ + +versions: + - 1.0.1 + - 1.0.0 +integrations: + - integration-name: Facebook Ads + external-doc-url: http://business.facebook.com + logo: /integration-logos/facebook/Facebook-Ads.png + tags: [service] + +hooks: + - integration-name: Facebook Ads + python-modules: + - airflow.providers.facebook.ads.hooks.ads + +hook-class-names: + - airflow.providers.facebook.ads.hooks.ads.FacebookAdsReportingHook diff --git a/reference/providers/ftp/CHANGELOG.rst b/reference/providers/ftp/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/ftp/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/ftp/__init__.py b/reference/providers/ftp/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/ftp/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/ftp/hooks/__init__.py b/reference/providers/ftp/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/ftp/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/ftp/hooks/ftp.py b/reference/providers/ftp/hooks/ftp.py new file mode 100644 index 0000000..43398e6 --- /dev/null +++ b/reference/providers/ftp/hooks/ftp.py @@ -0,0 +1,281 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import datetime +import ftplib +import os.path +from typing import Any, List, Optional + +from airflow.hooks.base import BaseHook + + +class FTPHook(BaseHook): + """ + Interact with FTP. + + Errors that may occur throughout but should be handled downstream. + You can specify mode for data transfers in the extra field of your + connection as ``{"passive": "true"}``. + """ + + conn_name_attr = "ftp_conn_id" + default_conn_name = "ftp_default" + conn_type = "ftp" + hook_name = "FTP" + + def __init__(self, ftp_conn_id: str = default_conn_name) -> None: + super().__init__() + self.ftp_conn_id = ftp_conn_id + self.conn: Optional[ftplib.FTP] = None + + def __enter__(self): + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self.conn is not None: + self.close_conn() + + def get_conn(self) -> ftplib.FTP: + """Returns a FTP connection object""" + if self.conn is None: + params = self.get_connection(self.ftp_conn_id) + pasv = params.extra_dejson.get("passive", True) + self.conn = ftplib.FTP(params.host, params.login, params.password) + self.conn.set_pasv(pasv) + + return self.conn + + def close_conn(self): + """ + Closes the connection. An error will occur if the + connection wasn't ever opened. + """ + conn = self.conn + conn.quit() + self.conn = None + + def describe_directory(self, path: str) -> dict: + """ + Returns a dictionary of {filename: {attributes}} for all files + on the remote system (where the MLSD command is supported). + + :param path: full path to the remote directory + :type path: str + """ + conn = self.get_conn() + conn.cwd(path) + files = dict(conn.mlsd()) + return files + + def list_directory(self, path: str) -> List[str]: + """ + Returns a list of files on the remote system. + + :param path: full path to the remote directory to list + :type path: str + """ + conn = self.get_conn() + conn.cwd(path) + + files = conn.nlst() + return files + + def create_directory(self, path: str) -> None: + """ + Creates a directory on the remote system. + + :param path: full path to the remote directory to create + :type path: str + """ + conn = self.get_conn() + conn.mkd(path) + + def delete_directory(self, path: str) -> None: + """ + Deletes a directory on the remote system. + + :param path: full path to the remote directory to delete + :type path: str + """ + conn = self.get_conn() + conn.rmd(path) + + def retrieve_file(self, remote_full_path, local_full_path_or_buffer, callback=None): + """ + Transfers the remote file to a local location. + + If local_full_path_or_buffer is a string path, the file will be put + at that location; if it is a file-like buffer, the file will + be written to the buffer but not closed. + + :param remote_full_path: full path to the remote file + :type remote_full_path: str + :param local_full_path_or_buffer: full path to the local file or a + file-like buffer + :type local_full_path_or_buffer: str or file-like buffer + :param callback: callback which is called each time a block of data + is read. if you do not use a callback, these blocks will be written + to the file or buffer passed in. if you do pass in a callback, note + that writing to a file or buffer will need to be handled inside the + callback. + [default: output_handle.write()] + :type callback: callable + + .. 
code-block:: python + + hook = FTPHook(ftp_conn_id='my_conn') + + remote_path = '/path/to/remote/file' + local_path = '/path/to/local/file' + + # with a custom callback (in this case displaying progress on each read) + def print_progress(percent_progress): + self.log.info('Percent Downloaded: %s%%' % percent_progress) + + total_downloaded = 0 + total_file_size = hook.get_size(remote_path) + output_handle = open(local_path, 'wb') + def write_to_file_with_progress(data): + total_downloaded += len(data) + output_handle.write(data) + percent_progress = (total_downloaded / total_file_size) * 100 + print_progress(percent_progress) + hook.retrieve_file(remote_path, None, callback=write_to_file_with_progress) + + # without a custom callback data is written to the local_path + hook.retrieve_file(remote_path, local_path) + + """ + conn = self.get_conn() + + is_path = isinstance(local_full_path_or_buffer, str) + + # without a callback, default to writing to a user-provided file or + # file-like buffer + if not callback: + if is_path: + output_handle = open(local_full_path_or_buffer, "wb") + else: + output_handle = local_full_path_or_buffer + callback = output_handle.write + else: + output_handle = None + + remote_path, remote_file_name = os.path.split(remote_full_path) + conn.cwd(remote_path) + self.log.info("Retrieving file from FTP: %s", remote_full_path) + conn.retrbinary(f"RETR {remote_file_name}", callback) + self.log.info("Finished retrieving file from FTP: %s", remote_full_path) + + if is_path and output_handle: + output_handle.close() + + def store_file(self, remote_full_path: str, local_full_path_or_buffer: Any) -> None: + """ + Transfers a local file to the remote location. + + If local_full_path_or_buffer is a string path, the file will be read + from that location; if it is a file-like buffer, the file will + be read from the buffer but not closed. + + :param remote_full_path: full path to the remote file + :type remote_full_path: str + :param local_full_path_or_buffer: full path to the local file or a + file-like buffer + :type local_full_path_or_buffer: str or file-like buffer + """ + conn = self.get_conn() + + is_path = isinstance(local_full_path_or_buffer, str) + + if is_path: + input_handle = open(local_full_path_or_buffer, "rb") + else: + input_handle = local_full_path_or_buffer + remote_path, remote_file_name = os.path.split(remote_full_path) + conn.cwd(remote_path) + conn.storbinary(f"STOR {remote_file_name}", input_handle) + + if is_path: + input_handle.close() + + def delete_file(self, path: str) -> None: + """ + Removes a file on the FTP Server. + + :param path: full path to the remote file + :type path: str + """ + conn = self.get_conn() + conn.delete(path) + + def rename(self, from_name: str, to_name: str) -> str: + """ + Rename a file. 
+ + :param from_name: rename file from name + :param to_name: rename file to name + """ + conn = self.get_conn() + return conn.rename(from_name, to_name) + + def get_mod_time(self, path: str) -> datetime.datetime: + """ + Returns a datetime object representing the last time the file was modified + + :param path: remote file path + :type path: str + """ + conn = self.get_conn() + ftp_mdtm = conn.sendcmd("MDTM " + path) + time_val = ftp_mdtm[4:] + # time_val optionally has microseconds + try: + return datetime.datetime.strptime(time_val, "%Y%m%d%H%M%S.%f") + except ValueError: + return datetime.datetime.strptime(time_val, "%Y%m%d%H%M%S") + + def get_size(self, path: str) -> Optional[int]: + """ + Returns the size of a file (in bytes) + + :param path: remote file path + :type path: str + """ + conn = self.get_conn() + size = conn.size(path) + return int(size) if size else None + + +class FTPSHook(FTPHook): + """Interact with FTPS.""" + + def get_conn(self) -> ftplib.FTP: + """Returns a FTPS connection object.""" + if self.conn is None: + params = self.get_connection(self.ftp_conn_id) + pasv = params.extra_dejson.get("passive", True) + + if params.port: + ftplib.FTP_TLS.port = params.port + + self.conn = ftplib.FTP_TLS(params.host, params.login, params.password) + self.conn.set_pasv(pasv) + + return self.conn diff --git a/reference/providers/ftp/provider.yaml b/reference/providers/ftp/provider.yaml new file mode 100644 index 0000000..1a1c6f3 --- /dev/null +++ b/reference/providers/ftp/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-ftp +name: File Transfer Protocol (FTP) +description: | + `File Transfer Protocol (FTP) `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: File Transfer Protocol (FTP) + external-doc-url: https://tools.ietf.org/html/rfc114 + logo: /integration-logos/ftp/FTP.png + tags: [protocol] + +sensors: + - integration-name: File Transfer Protocol (FTP) + python-modules: + - airflow.providers.ftp.sensors.ftp + +hooks: + - integration-name: File Transfer Protocol (FTP) + python-modules: + - airflow.providers.ftp.hooks.ftp + +hook-class-names: + - airflow.providers.ftp.hooks.ftp.FTPHook diff --git a/reference/providers/ftp/sensors/__init__.py b/reference/providers/ftp/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/ftp/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/ftp/sensors/ftp.py b/reference/providers/ftp/sensors/ftp.py new file mode 100644 index 0000000..90cbca4 --- /dev/null +++ b/reference/providers/ftp/sensors/ftp.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import ftplib +import re + +from airflow.providers.ftp.hooks.ftp import FTPHook, FTPSHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class FTPSensor(BaseSensorOperator): + """ + Waits for a file or directory to be present on FTP. + + :param path: Remote file or directory path + :type path: str + :param fail_on_transient_errors: Fail on all errors, + including 4xx transient errors. Default True. 
+ :type fail_on_transient_errors: bool + :param ftp_conn_id: The connection to run the sensor against + :type ftp_conn_id: str + """ + + template_fields = ("path",) + + """Errors that are transient in nature, and where action can be retried""" + transient_errors = [421, 425, 426, 434, 450, 451, 452] + + error_code_pattern = re.compile(r"([\d]+)") + + @apply_defaults + def __init__( + self, + *, + path: str, + ftp_conn_id: str = "ftp_default", + fail_on_transient_errors: bool = True, + **kwargs + ) -> None: + super().__init__(**kwargs) + + self.path = path + self.ftp_conn_id = ftp_conn_id + self.fail_on_transient_errors = fail_on_transient_errors + + def _create_hook(self) -> FTPHook: + """Return connection hook.""" + return FTPHook(ftp_conn_id=self.ftp_conn_id) + + def _get_error_code(self, e): + """Extract error code from ftp exception""" + try: + matches = self.error_code_pattern.match(str(e)) + code = int(matches.group(0)) + return code + except ValueError: + return e + + def poke(self, context: dict) -> bool: + with self._create_hook() as hook: + self.log.info("Poking for %s", self.path) + try: + hook.get_mod_time(self.path) + except ftplib.error_perm as e: + self.log.error("Ftp error encountered: %s", str(e)) + error_code = self._get_error_code(e) + if (error_code != 550) and ( + self.fail_on_transient_errors + or (error_code not in self.transient_errors) + ): + raise e + + return False + + return True + + +class FTPSSensor(FTPSensor): + """Waits for a file or directory to be present on FTP over SSL.""" + + def _create_hook(self) -> FTPHook: + """Return connection hook.""" + return FTPSHook(ftp_conn_id=self.ftp_conn_id) diff --git a/reference/providers/google/CHANGELOG.rst b/reference/providers/google/CHANGELOG.rst new file mode 100644 index 0000000..5ed132b --- /dev/null +++ b/reference/providers/google/CHANGELOG.rst @@ -0,0 +1,149 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +2.1.0 +..... + +Features +~~~~~~~~ + +* ``Corrects order of argument in docstring in GCSHook.download method (#14497)`` +* ``Refactor SQL/BigQuery/Qubole/Druid Check operators (#12677)`` +* ``Add GoogleDriveToLocalOperator (#14191)`` +* ``Add 'exists_ok' flag to BigQueryCreateEmptyTable(Dataset)Operator (#14026)`` +* ``Add materialized view support for BigQuery (#14201)`` +* ``Add BigQueryUpdateTableOperator (#14149)`` +* ``Add param to CloudDataTransferServiceOperator (#14118)`` +* ``Add gdrive_to_gcs operator, drive sensor, additional functionality to drive hook (#13982)`` +* ``Improve GCSToSFTPOperator paths handling (#11284)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fixes to dataproc operators and hook (#14086)`` +* ``#9803 fix bug in copy operation without wildcard (#13919)`` + +2.0.0 +..... 
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +Updated ``google-cloud-*`` libraries +```````````````````````````````````` + +This release of the provider package contains third-party library updates, which may require updating your +DAG files or custom hooks and operators, if you were using objects from those libraries. +Updating of these libraries is necessary to be able to use new features made available by new versions of +the libraries and to obtain bug fixes that are only available for new versions of the library. + +Details are covered in the UPDATING.md files for each library, but there are some details +that you should pay attention to. + + ++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| Library name | Previous constraints | Current constraints | Upgrade Documentation | ++=====================================================================================================+======================+=====================+=======================================================================================================================================+ +| `google-cloud-automl `_ | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | `Upgrading google-cloud-automl `_ | ++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| `google-cloud-bigquery-datatransfer `_ | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | `Upgrading google-cloud-bigquery-datatransfer `_ | ++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| `google-cloud-datacatalog `_ | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | `Upgrading google-cloud-datacatalog `_ | ++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| `google-cloud-dataproc `_ | ``>=1.0.1,<2.0.0`` | ``>=2.2.0,<3.0.0`` | `Upgrading google-cloud-dataproc `_ | ++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| `google-cloud-kms `_ | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | `Upgrading google-cloud-kms `_ | ++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| `google-cloud-logging `_ | ``>=1.14.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | `Upgrading google-cloud-logging `_ | 
++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| `google-cloud-monitoring `_ | ``>=0.34.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | `Upgrading google-cloud-monitoring `_ | ++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| `google-cloud-os-login `_ | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | `Upgrading google-cloud-os-login `_ | ++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| `google-cloud-pubsub `_ | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | `Upgrading google-cloud-pubsub `_ | ++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| `google-cloud-tasks `_ | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | `Upgrading google-cloud-task `_ | ++-----------------------------------------------------------------------------------------------------+----------------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------+ + +The field names use the snake_case convention +````````````````````````````````````````````` + +If your DAG uses an object from the above mentioned libraries passed by XCom, it is necessary to update the +naming convention of the fields that are read. Previously, the fields used the CamelSnake convention, +now the snake_case convention is used. + +**Before:** + +.. code-block:: python + + set_acl_permission = GCSBucketCreateAclEntryOperator( + task_id="gcs-set-acl-permission", + bucket=BUCKET_NAME, + entity="user-{{ task_instance.xcom_pull('get-instance')['persistenceIamIdentity']" + ".split(':', 2)[1] }}", + role="OWNER", + ) + + +**After:** + +.. code-block:: python + + set_acl_permission = GCSBucketCreateAclEntryOperator( + task_id="gcs-set-acl-permission", + bucket=BUCKET_NAME, + entity="user-{{ task_instance.xcom_pull('get-instance')['persistence_iam_identity']" + ".split(':', 2)[1] }}", + role="OWNER", + ) + + +Features +~~~~~~~~ + +* ``Add Apache Beam operators (#12814)`` +* ``Add Google Cloud Workflows Operators (#13366)`` +* ``Replace 'google_cloud_storage_conn_id' by 'gcp_conn_id' when using 'GCSHook' (#13851)`` +* ``Add How To Guide for Dataflow (#13461)`` +* ``Generalize MLEngineStartTrainingJobOperator to custom images (#13318)`` +* ``Add Parquet data type to BaseSQLToGCSOperator (#13359)`` +* ``Add DataprocCreateWorkflowTemplateOperator (#13338)`` +* ``Add OracleToGCS Transfer (#13246)`` +* ``Add timeout option to gcs hook methods. 
(#13156)`` +* ``Add regional support to dataproc workflow template operators (#12907)`` +* ``Add project_id to client inside BigQuery hook update_table method (#13018)`` + +Bug fixes +~~~~~~~~~ + +* ``Fix four bugs in StackdriverTaskHandler (#13784)`` +* ``Decode Remote Google Logs (#13115)`` +* ``Fix and improve GCP BigTable hook and system test (#13896)`` +* ``updated Google DV360 Hook to fix SDF issue (#13703)`` +* ``Fix insert_all method of BigQueryHook to support tables without schema (#13138)`` +* ``Fix Google BigQueryHook method get_schema() (#13136)`` +* ``Fix Data Catalog operators (#13096)`` + + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/google/__init__.py b/reference/providers/google/__init__.py new file mode 100644 index 0000000..7e07016 --- /dev/null +++ b/reference/providers/google/__init__.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import importlib +import logging + +# HACK: +# Sphinx-autoapi doesn't like imports to excluded packages in the main module. +conf = importlib.import_module("airflow.configuration").conf + +PROVIDERS_GOOGLE_VERBOSE_LOGGING: bool = conf.getboolean( + "providers_google", "VERBOSE_LOGGING", fallback=False +) +if PROVIDERS_GOOGLE_VERBOSE_LOGGING: + for logger_name in ["google_auth_httplib2", "httplib2", "googleapiclient"]: + logger = logging.getLogger(logger_name) + logger.handlers += [ + handler + for handler in logging.getLogger().handlers + if handler.name in ["task", "console"] + ] + logger.level = logging.DEBUG + logger.propagate = False + + import httplib2 + + httplib2.debuglevel = 4 diff --git a/reference/providers/google/ads/__init__.py b/reference/providers/google/ads/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/ads/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
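The ``providers_google`` package ``__init__.py`` above only enables the extra HTTP-level logging when ``verbose_logging`` in a ``[providers_google]`` config section evaluates to true. One way to flip that switch, assuming Airflow's usual ``AIRFLOW__<SECTION>__<KEY>`` environment override and that the variable is set before the package is first imported:

.. code-block:: python

    import os

    # Equivalent to adding
    #   [providers_google]
    #   verbose_logging = True
    # to airflow.cfg; it must be present in the scheduler/worker environment
    # before airflow.providers.google is imported.
    os.environ["AIRFLOW__PROVIDERS_GOOGLE__VERBOSE_LOGGING"] = "True"

    from airflow.configuration import conf

    print(conf.getboolean("providers_google", "VERBOSE_LOGGING", fallback=False))  # True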
diff --git a/reference/providers/google/ads/example_dags/__init__.py b/reference/providers/google/ads/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/ads/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/ads/example_dags/example_ads.py b/reference/providers/google/ads/example_dags/example_ads.py new file mode 100644 index 0000000..e38fd47 --- /dev/null +++ b/reference/providers/google/ads/example_dags/example_ads.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that shows how to use GoogleAdsToGcsOperator. 
+""" +import os + +from airflow import models +from airflow.providers.google.ads.operators.ads import GoogleAdsListAccountsOperator +from airflow.providers.google.ads.transfers.ads_to_gcs import GoogleAdsToGcsOperator +from airflow.utils import dates + +# [START howto_google_ads_env_variables] +CLIENT_IDS = ["1111111111", "2222222222"] +BUCKET = os.environ.get("GOOGLE_ADS_BUCKET", "gs://test-google-ads-bucket") +GCS_OBJ_PATH = "folder_name/google-ads-api-results.csv" +GCS_ACCOUNTS_CSV = "folder_name/accounts.csv" +QUERY = """ + SELECT + segments.date, + customer.id, + campaign.id, + ad_group.id, + ad_group_ad.ad.id, + metrics.impressions, + metrics.clicks, + metrics.conversions, + metrics.all_conversions, + metrics.cost_micros + FROM + ad_group_ad + WHERE + segments.date >= '2020-02-01' + AND segments.date <= '2020-02-29' + """ + +FIELDS_TO_EXTRACT = [ + "segments.date.value", + "customer.id.value", + "campaign.id.value", + "ad_group.id.value", + "ad_group_ad.ad.id.value", + "metrics.impressions.value", + "metrics.clicks.value", + "metrics.conversions.value", + "metrics.all_conversions.value", + "metrics.cost_micros.value", +] + +# [END howto_google_ads_env_variables] + +with models.DAG( + "example_google_ads", + schedule_interval=None, # Override to match your needs + start_date=dates.days_ago(1), +) as dag: + # [START howto_google_ads_to_gcs_operator] + run_operator = GoogleAdsToGcsOperator( + client_ids=CLIENT_IDS, + query=QUERY, + attributes=FIELDS_TO_EXTRACT, + obj=GCS_OBJ_PATH, + bucket=BUCKET, + task_id="run_operator", + ) + # [END howto_google_ads_to_gcs_operator] + + # [START howto_ads_list_accounts_operator] + list_accounts = GoogleAdsListAccountsOperator( + task_id="list_accounts", bucket=BUCKET, object_name=GCS_ACCOUNTS_CSV + ) + # [END howto_ads_list_accounts_operator] diff --git a/reference/providers/google/ads/hooks/__init__.py b/reference/providers/google/ads/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/ads/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/ads/hooks/ads.py b/reference/providers/google/ads/hooks/ads.py new file mode 100644 index 0000000..1439b8f --- /dev/null +++ b/reference/providers/google/ads/hooks/ads.py @@ -0,0 +1,221 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Ad hook.""" +from tempfile import NamedTemporaryFile +from typing import IO, Any, Dict, Generator, List + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow import AirflowException +from airflow.hooks.base import BaseHook +from google.ads.google_ads.client import GoogleAdsClient +from google.ads.google_ads.errors import GoogleAdsException +from google.ads.google_ads.v2.types import GoogleAdsRow +from google.api_core.page_iterator import GRPCIterator +from google.auth.exceptions import GoogleAuthError +from googleapiclient.discovery import Resource + + +class GoogleAdsHook(BaseHook): + """ + Hook for the Google Ads API. + + This hook requires two connections: + + - gcp_conn_id - provides service account details (like any other GCP connection) + - google_ads_conn_id - which contains information from Google Ads config.yaml file + in the ``extras``. Example of the ``extras``: + + .. code-block:: json + + { + "google_ads_client": { + "developer_token": "{{ INSERT_TOKEN }}", + "path_to_private_key_file": null, + "delegated_account": "{{ INSERT_DELEGATED_ACCOUNT }}" + } + } + + The ``path_to_private_key_file`` is resolved by the hook using credentials from gcp_conn_id. + https://developers.google.com/google-ads/api/docs/client-libs/python/oauth-service + + .. seealso:: + For more information on how Google Ads authentication flow works take a look at: + https://developers.google.com/google-ads/api/docs/client-libs/python/oauth-service + + .. seealso:: + For more information on the Google Ads API, take a look at the API docs: + https://developers.google.com/google-ads/api/docs/start + + :param gcp_conn_id: The connection ID with the service account details. + :type gcp_conn_id: str + :param google_ads_conn_id: The connection ID with the details of Google Ads config.yaml file. 
+    :type google_ads_conn_id: str
+
+    :return: list of Google Ads Row object(s)
+    :rtype: list[GoogleAdsRow]
+    """
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        google_ads_conn_id: str = "google_ads_default",
+        api_version: str = "v3",
+    ) -> None:
+        super().__init__()
+        self.gcp_conn_id = gcp_conn_id
+        self.google_ads_conn_id = google_ads_conn_id
+        self.api_version = api_version
+        self.google_ads_config: Dict[str, Any] = {}
+
+    @cached_property
+    def _get_service(self) -> Resource:
+        """Connects and authenticates with the Google Ads API using a service account"""
+        with NamedTemporaryFile("w", suffix=".json") as secrets_temp:
+            self._get_config()
+            self._update_config_with_secret(secrets_temp)
+            try:
+                client = GoogleAdsClient.load_from_dict(self.google_ads_config)
+                return client.get_service("GoogleAdsService", version=self.api_version)
+            except GoogleAuthError as e:
+                self.log.error("Google Auth Error: %s", e)
+                raise
+
+    @cached_property
+    def _get_customer_service(self) -> Resource:
+        """Connects and authenticates with the Google Ads API using a service account"""
+        with NamedTemporaryFile("w", suffix=".json") as secrets_temp:
+            self._get_config()
+            self._update_config_with_secret(secrets_temp)
+            try:
+                client = GoogleAdsClient.load_from_dict(self.google_ads_config)
+                return client.get_service("CustomerService", version=self.api_version)
+            except GoogleAuthError as e:
+                self.log.error("Google Auth Error: %s", e)
+                raise
+
+    def _get_config(self) -> None:
+        """
+        Gets google ads connection from meta db and sets google_ads_config attribute with returned config
+        file
+        """
+        conn = self.get_connection(self.google_ads_conn_id)
+        if "google_ads_client" not in conn.extra_dejson:
+            raise AirflowException("google_ads_client not found in extra field")
+
+        self.google_ads_config = conn.extra_dejson["google_ads_client"]
+
+    def _update_config_with_secret(self, secrets_temp: IO[str]) -> None:
+        """
+        Gets Google Cloud secret from connection and saves the contents to the temp file
+        Updates google ads config with file path of the temp file containing the secret
+        Note, the secret must be passed as a file path for Google Ads API
+        """
+        secret_conn = self.get_connection(self.gcp_conn_id)
+        secret = secret_conn.extra_dejson["extra__google_cloud_platform__keyfile_dict"]
+        secrets_temp.write(secret)
+        secrets_temp.flush()
+
+        self.google_ads_config["path_to_private_key_file"] = secrets_temp.name
+
+    def search(
+        self, client_ids: List[str], query: str, page_size: int = 10000, **kwargs
+    ) -> List[GoogleAdsRow]:
+        """
+        Pulls data from the Google Ads API
+
+        :param client_ids: Google Ads client ID(s) to query the API for.
+        :type client_ids: List[str]
+        :param query: Google Ads Query Language query.
+        :type query: str
+        :param page_size: Number of results to return per page. Max 10000.
+ :type page_size: int + + :return: Google Ads API response, converted to Google Ads Row objects + :rtype: list[GoogleAdsRow] + """ + service = self._get_service + iterators = ( + service.search(client_id, query=query, page_size=page_size, **kwargs) + for client_id in client_ids + ) + self.log.info("Fetched Google Ads Iterators") + + return self._extract_rows(iterators) + + def _extract_rows( + self, iterators: Generator[GRPCIterator, None, None] + ) -> List[GoogleAdsRow]: + """ + Convert Google Page Iterator (GRPCIterator) objects to Google Ads Rows + + :param iterators: List of Google Page Iterator (GRPCIterator) objects + :type iterators: generator[GRPCIterator, None, None] + + :return: API response for all clients in the form of Google Ads Row object(s) + :rtype: list[GoogleAdsRow] + """ + try: + self.log.info("Extracting data from returned Google Ads Iterators") + return [row for iterator in iterators for row in iterator] + except GoogleAdsException as e: + self.log.error( + "Request ID %s failed with status %s and includes the following errors:", + e.request_id, + e.error.code().name, + ) + for error in e.failure.errors: + self.log.error("\tError with message: %s.", error.message) + if error.location: + for field_path_element in error.location.field_path_elements: + self.log.error( + "\t\tOn field: %s", field_path_element.field_name + ) + raise + + def list_accessible_customers(self) -> List[str]: + """ + Returns resource names of customers directly accessible by the user authenticating the call. + The resulting list of customers is based on your OAuth credentials. The request returns a list + of all accounts that you are able to act upon directly given your current credentials. This will + not necessarily include all accounts within the account hierarchy; rather, it will only include + accounts where your authenticated user has been added with admin or other rights in the account. + + ..seealso:: + https://developers.google.com/google-ads/api/reference/rpc + + :return: List of names of customers + """ + try: + accessible_customers = ( + self._get_customer_service.list_accessible_customers() + ) + return accessible_customers.resource_names + except GoogleAdsException as ex: + for error in ex.failure.errors: + self.log.error('\tError with message "%s".', error.message) + if error.location: + for field_path_element in error.location.field_path_elements: + self.log.error( + "\t\tOn field: %s", field_path_element.field_name + ) + raise diff --git a/reference/providers/google/ads/operators/__init__.py b/reference/providers/google/ads/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/ads/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
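For orientation, here is a rough sketch of calling the ``GoogleAdsHook`` defined above directly from task code. The connection IDs are the hook defaults, the customer ID and GAQL query are placeholders, and the ``.value`` accessors assume the wrapped field types used elsewhere in this commit (see ``FIELDS_TO_EXTRACT`` in ``example_ads.py``):

.. code-block:: python

    from airflow.providers.google.ads.hooks.ads import GoogleAdsHook


    def fetch_campaign_clicks() -> list:
        hook = GoogleAdsHook(
            gcp_conn_id="google_cloud_default",
            google_ads_conn_id="google_ads_default",
        )
        query = """
            SELECT campaign.id, metrics.clicks
            FROM campaign
            WHERE segments.date DURING LAST_7_DAYS
        """
        # search() fans the query out over one or more customer IDs and
        # flattens the resulting iterators into GoogleAdsRow objects.
        rows = hook.search(client_ids=["1111111111"], query=query, page_size=1000)
        return [(row.campaign.id.value, row.metrics.clicks.value) for row in rows]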
diff --git a/reference/providers/google/ads/operators/ads.py b/reference/providers/google/ads/operators/ads.py new file mode 100644 index 0000000..45aa829 --- /dev/null +++ b/reference/providers/google/ads/operators/ads.py @@ -0,0 +1,121 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Ad to GCS operators.""" +import csv +from tempfile import NamedTemporaryFile +from typing import Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.ads.hooks.ads import GoogleAdsHook +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults + + +class GoogleAdsListAccountsOperator(BaseOperator): + """ + Saves list of customers on GCS in form of a csv file. + + The resulting list of customers is based on your OAuth credentials. The request returns a list + of all accounts that you are able to act upon directly given your current credentials. This will + not necessarily include all accounts within the account hierarchy; rather, it will only include + accounts where your authenticated user has been added with admin or other rights in the account. + + ..seealso:: + https://developers.google.com/google-ads/api/reference/rpc + + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleAdsListAccountsOperator` + + :param bucket: The GCS bucket to upload to + :type bucket: str + :param object_name: GCS path to save the csv file. Must be the full file path (ex. `path/to/file.csv`) + :type object_name: str + :param gcp_conn_id: Airflow Google Cloud connection ID + :type gcp_conn_id: str + :param google_ads_conn_id: Airflow Google Ads connection ID + :type google_ads_conn_id: str + :param page_size: The number of results per API page request. Max 10,000 + :type page_size: int + :param gzip: Option to compress local file or file data for upload + :type gzip: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "bucket", + "object_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + bucket: str, + object_name: str, + gcp_conn_id: str = "google_cloud_default", + google_ads_conn_id: str = "google_ads_default", + gzip: bool = False, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.bucket = bucket + self.object_name = object_name + self.gcp_conn_id = gcp_conn_id + self.google_ads_conn_id = google_ads_conn_id + self.gzip = gzip + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> str: + uri = f"gs://{self.bucket}/{self.object_name}" + + ads_hook = GoogleAdsHook( + gcp_conn_id=self.gcp_conn_id, google_ads_conn_id=self.google_ads_conn_id + ) + + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + with NamedTemporaryFile("w+") as temp_file: + # Download accounts + accounts = ads_hook.list_accessible_customers() + writer = csv.writer(temp_file) + writer.writerows(accounts) + temp_file.flush() + + # Upload to GCS + gcs_hook.upload( + bucket_name=self.bucket, + object_name=self.object_name, + gzip=self.gzip, + filename=temp_file.name, + ) + self.log.info("Uploaded %s to %s", len(accounts), uri) + + return uri diff --git a/reference/providers/google/ads/transfers/__init__.py b/reference/providers/google/ads/transfers/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/ads/transfers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/ads/transfers/ads_to_gcs.py b/reference/providers/google/ads/transfers/ads_to_gcs.py new file mode 100644 index 0000000..bc7dd18 --- /dev/null +++ b/reference/providers/google/ads/transfers/ads_to_gcs.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
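# --- Editor's note: illustrative sketch only; not part of the commit above or of the
# ads_to_gcs.py file that continues below. ---
# A minimal example of wiring the GoogleAdsListAccountsOperator from
# reference/providers/google/ads/operators/ads.py (shown above) into a DAG. The DAG id,
# bucket, and object name are placeholder assumptions; the arguments mirror the constructor
# defined above, and execute() returns the uploaded gs:// URI, so it is available via XCom.
from airflow import models
from airflow.providers.google.ads.operators.ads import GoogleAdsListAccountsOperator
from airflow.utils.dates import days_ago

with models.DAG(
    "example_google_ads_list_accounts",  # hypothetical DAG id
    schedule_interval=None,
    start_date=days_ago(1),
    tags=["example"],
) as dag:
    list_accounts = GoogleAdsListAccountsOperator(
        task_id="list_accounts",
        bucket="my-ads-bucket",                      # placeholder GCS bucket
        object_name="ads/accessible_accounts.csv",   # placeholder object path
        gcp_conn_id="google_cloud_default",
        google_ads_conn_id="google_ads_default",
        gzip=False,
    )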
+ +import csv +from operator import attrgetter +from tempfile import NamedTemporaryFile +from typing import List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.ads.hooks.ads import GoogleAdsHook +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults + + +class GoogleAdsToGcsOperator(BaseOperator): + """ + Fetches the daily results from the Google Ads API for 1-n clients + Converts and saves the data as a temporary CSV file + Uploads the CSV to Google Cloud Storage + + .. seealso:: + For more information on the Google Ads API, take a look at the API docs: + https://developers.google.com/google-ads/api/docs/start + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleAdsToGcsOperator` + + :param client_ids: Google Ads client IDs to query + :type client_ids: List[str] + :param query: Google Ads Query Language API query + :type query: str + :param attributes: List of Google Ads Row attributes to extract + :type attributes: List[str] + :param bucket: The GCS bucket to upload to + :type bucket: str + :param obj: GCS path to save the object. Must be the full file path (ex. `path/to/file.txt`) + :type obj: str + :param gcp_conn_id: Airflow Google Cloud connection ID + :type gcp_conn_id: str + :param google_ads_conn_id: Airflow Google Ads connection ID + :type google_ads_conn_id: str + :param page_size: The number of results per API page request. Max 10,000 + :type page_size: int + :param gzip: Option to compress local file or file data for upload + :type gzip: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "client_ids", + "query", + "attributes", + "bucket", + "obj", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + client_ids: List[str], + query: str, + attributes: List[str], + bucket: str, + obj: str, + gcp_conn_id: str = "google_cloud_default", + google_ads_conn_id: str = "google_ads_default", + page_size: int = 10000, + gzip: bool = False, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.client_ids = client_ids + self.query = query + self.attributes = attributes + self.bucket = bucket + self.obj = obj + self.gcp_conn_id = gcp_conn_id + self.google_ads_conn_id = google_ads_conn_id + self.page_size = page_size + self.gzip = gzip + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + service = GoogleAdsHook( + gcp_conn_id=self.gcp_conn_id, google_ads_conn_id=self.google_ads_conn_id + ) + rows = service.search( + client_ids=self.client_ids, query=self.query, page_size=self.page_size + ) + + try: + getter = attrgetter(*self.attributes) + converted_rows = [getter(row) for row in rows] + except Exception as e: + self.log.error( + "An error occurred in converting the Google Ad Rows. \n Error %s", e + ) + raise + + with NamedTemporaryFile("w", suffix=".csv") as csvfile: + writer = csv.writer(csvfile) + writer.writerows(converted_rows) + csvfile.flush() + + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.upload( + bucket_name=self.bucket, + object_name=self.obj, + filename=csvfile.name, + gzip=self.gzip, + ) + self.log.info("%s uploaded to GCS", self.obj) diff --git a/reference/providers/google/cloud/__init__.py b/reference/providers/google/cloud/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/cloud/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/cloud/_internal_client/__init__.py b/reference/providers/google/cloud/_internal_client/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/cloud/_internal_client/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/cloud/_internal_client/secret_manager_client.py b/reference/providers/google/cloud/_internal_client/secret_manager_client.py new file mode 100644 index 0000000..28f6088 --- /dev/null +++ b/reference/providers/google/cloud/_internal_client/secret_manager_client.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +from typing import Optional + +import google + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.version import version +from google.api_core.exceptions import NotFound, PermissionDenied +from google.api_core.gapic_v1.client_info import ClientInfo +from google.cloud.secretmanager_v1 import SecretManagerServiceClient + +SECRET_ID_PATTERN = r"^[a-zA-Z0-9-_]*$" + + +class _SecretManagerClient(LoggingMixin): + """ + Retrieves Secrets object from Google Cloud Secrets Manager. This is a common class reused between + SecretsManager and Secrets Hook that provides the shared authentication and verification mechanisms. + This class should not be used directly, use SecretsManager or SecretsHook instead + + + :param credentials: Credentials used to authenticate to GCP + :type credentials: google.auth.credentials.Credentials + """ + + def __init__( + self, + credentials: google.auth.credentials.Credentials, + ) -> None: + super().__init__() + self.credentials = credentials + + @staticmethod + def is_valid_secret_name(secret_name: str) -> bool: + """ + Returns true if the secret name is valid. + :param secret_name: name of the secret + :type secret_name: str + :return: + """ + return bool(re.match(SECRET_ID_PATTERN, secret_name)) + + @cached_property + def client(self) -> SecretManagerServiceClient: + """Create an authenticated KMS client""" + _client = SecretManagerServiceClient( + credentials=self.credentials, + client_info=ClientInfo(client_library_version="airflow_v" + version), + ) + return _client + + def get_secret( + self, secret_id: str, project_id: str, secret_version: str = "latest" + ) -> Optional[str]: + """ + Get secret value from the Secret Manager. 
+ + :param secret_id: Secret Key + :type secret_id: str + :param project_id: Project id to use + :type project_id: str + :param secret_version: version of the secret (default is 'latest') + :type secret_version: str + """ + name = self.client.secret_version_path(project_id, secret_id, secret_version) + try: + response = self.client.access_secret_version(name) + value = response.payload.data.decode("UTF-8") + return value + except NotFound: + self.log.error( + "Google Cloud API Call Error (NotFound): Secret ID %s not found.", + secret_id, + ) + return None + except PermissionDenied: + self.log.error( + """Google Cloud API Call Error (PermissionDenied): No access for Secret ID %s. + Did you add 'secretmanager.versions.access' permission?""", + secret_id, + ) + return None diff --git a/reference/providers/google/cloud/example_dags/__init__.py b/reference/providers/google/cloud/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/cloud/example_dags/example_automl_nl_text_classification.py b/reference/providers/google/cloud/example_dags/example_automl_nl_text_classification.py new file mode 100644 index 0000000..3762959 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_automl_nl_text_classification.py @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses Google AutoML services. 
+""" +import os + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") +GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") +GCP_AUTOML_TEXT_CLS_BUCKET = os.environ.get("GCP_AUTOML_TEXT_CLS_BUCKET", "gs://") + +# Example values +DATASET_ID = "" + +# Example model +MODEL = { + "display_name": "auto_model_1", + "dataset_id": DATASET_ID, + "text_classification_model_metadata": {}, +} + +# Example dataset +DATASET = { + "display_name": "test_text_cls_dataset", + "text_classification_dataset_metadata": {"classification_type": "MULTICLASS"}, +} + +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_TEXT_CLS_BUCKET]}} + +extract_object_id = CloudAutoMLHook.extract_object_id + +# Example DAG for AutoML Natural Language Text Classification +with models.DAG( + "example_automl_text_cls", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as example_dag: + create_dataset_task = AutoMLCreateDatasetOperator( + task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION + ) + + dataset_id = ( + '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' + ) + + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + + MODEL["dataset_id"] = dataset_id + + create_model = AutoMLTrainModelOperator( + task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION + ) + + model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" + + delete_model_task = AutoMLDeleteModelOperator( + task_id="delete_model_task", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + delete_datasets_task = AutoMLDeleteDatasetOperator( + task_id="delete_datasets_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/reference/providers/google/cloud/example_dags/example_automl_nl_text_extraction.py b/reference/providers/google/cloud/example_dags/example_automl_nl_text_extraction.py new file mode 100644 index 0000000..b9a2ccf --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_automl_nl_text_extraction.py @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses Google AutoML services. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") +GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") +GCP_AUTOML_TEXT_BUCKET = os.environ.get( + "GCP_AUTOML_TEXT_BUCKET", "gs://cloud-ml-data/NL-entity/dataset.csv" +) + +# Example values +DATASET_ID = "" + +# Example model +MODEL = { + "display_name": "auto_model_1", + "dataset_id": DATASET_ID, + "text_extraction_model_metadata": {}, +} + +# Example dataset +DATASET = {"display_name": "test_text_dataset", "text_extraction_dataset_metadata": {}} + +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_TEXT_BUCKET]}} + +extract_object_id = CloudAutoMLHook.extract_object_id + +# Example DAG for AutoML Natural Language Entities Extraction +with models.DAG( + "example_automl_text", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + user_defined_macros={"extract_object_id": extract_object_id}, + tags=["example"], +) as example_dag: + create_dataset_task = AutoMLCreateDatasetOperator( + task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION + ) + + dataset_id = ( + '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' + ) + + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + + MODEL["dataset_id"] = dataset_id + + create_model = AutoMLTrainModelOperator( + task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION + ) + + model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" + + delete_model_task = AutoMLDeleteModelOperator( + task_id="delete_model_task", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + delete_datasets_task = AutoMLDeleteDatasetOperator( + task_id="delete_datasets_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/reference/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py b/reference/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py new file mode 100644 index 0000000..bccc6b0 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses Google AutoML services. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") +GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") +GCP_AUTOML_SENTIMENT_BUCKET = os.environ.get("GCP_AUTOML_SENTIMENT_BUCKET", "gs://") + +# Example values +DATASET_ID = "" + +# Example model +MODEL = { + "display_name": "auto_model_1", + "dataset_id": DATASET_ID, + "text_sentiment_model_metadata": {}, +} + +# Example dataset +DATASET = { + "display_name": "test_text_sentiment_dataset", + "text_sentiment_dataset_metadata": {"sentiment_max": 10}, +} + +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_SENTIMENT_BUCKET]}} + +extract_object_id = CloudAutoMLHook.extract_object_id + +# Example DAG for AutoML Natural Language Text Sentiment +with models.DAG( + "example_automl_text_sentiment", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + user_defined_macros={"extract_object_id": extract_object_id}, + tags=["example"], +) as example_dag: + create_dataset_task = AutoMLCreateDatasetOperator( + task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION + ) + + dataset_id = ( + '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' + ) + + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + + MODEL["dataset_id"] = dataset_id + + create_model = AutoMLTrainModelOperator( + task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION + ) + + model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" + + delete_model_task = AutoMLDeleteModelOperator( + task_id="delete_model_task", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + delete_datasets_task = AutoMLDeleteDatasetOperator( + task_id="delete_datasets_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/reference/providers/google/cloud/example_dags/example_automl_tables.py b/reference/providers/google/cloud/example_dags/example_automl_tables.py new file mode 100644 index 0000000..b37373f --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_automl_tables.py @@ -0,0 +1,308 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses Google AutoML services. +""" +import os +from copy import deepcopy +from typing import Dict, List + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLBatchPredictOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLDeployModelOperator, + AutoMLGetModelOperator, + AutoMLImportDataOperator, + AutoMLListDatasetOperator, + AutoMLPredictOperator, + AutoMLTablesListColumnSpecsOperator, + AutoMLTablesListTableSpecsOperator, + AutoMLTablesUpdateDatasetOperator, + AutoMLTrainModelOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") +GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") +GCP_AUTOML_DATASET_BUCKET = os.environ.get( + "GCP_AUTOML_DATASET_BUCKET", "gs://cloud-ml-tables-data/bank-marketing.csv" +) +TARGET = os.environ.get("GCP_AUTOML_TARGET", "Deposit") + +# Example values +MODEL_ID = "TBL123456" +DATASET_ID = "TBL123456" + +# Example model +MODEL = { + "display_name": "auto_model_1", + "dataset_id": DATASET_ID, + "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, +} + +# Example dataset +DATASET = { + "display_name": "test_set", + "tables_dataset_metadata": {"target_column_spec_id": ""}, +} + +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_DATASET_BUCKET]}} + +extract_object_id = CloudAutoMLHook.extract_object_id + + +def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: + """ + Using column name returns spec of the column. + """ + for column in columns_specs: + if column["display_name"] == column_name: + return extract_object_id(column) + raise Exception(f"Unknown target column: {column_name}") + + +# Example DAG to create dataset, train model_id and deploy it. 
+with models.DAG( + "example_create_and_deploy", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + user_defined_macros={ + "get_target_column_spec": get_target_column_spec, + "target": TARGET, + "extract_object_id": extract_object_id, + }, + tags=["example"], +) as create_deploy_dag: + # [START howto_operator_automl_create_dataset] + create_dataset_task = AutoMLCreateDatasetOperator( + task_id="create_dataset_task", + dataset=DATASET, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + dataset_id = ( + "{{ task_instance.xcom_pull('create_dataset_task', key='dataset_id') }}" + ) + # [END howto_operator_automl_create_dataset] + + MODEL["dataset_id"] = dataset_id + + # [START howto_operator_automl_import_data] + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + # [END howto_operator_automl_import_data] + + # [START howto_operator_automl_specs] + list_tables_spec_task = AutoMLTablesListTableSpecsOperator( + task_id="list_tables_spec_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_automl_specs] + + # [START howto_operator_automl_column_specs] + list_columns_spec_task = AutoMLTablesListColumnSpecsOperator( + task_id="list_columns_spec_task", + dataset_id=dataset_id, + table_spec_id="{{ extract_object_id(task_instance.xcom_pull('list_tables_spec_task')[0]) }}", + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_automl_column_specs] + + # [START howto_operator_automl_update_dataset] + update = deepcopy(DATASET) + update["name"] = '{{ task_instance.xcom_pull("create_dataset_task")["name"] }}' + update["tables_dataset_metadata"][ # type: ignore + "target_column_spec_id" + ] = "{{ get_target_column_spec(task_instance.xcom_pull('list_columns_spec_task'), target) }}" + + update_dataset_task = AutoMLTablesUpdateDatasetOperator( + task_id="update_dataset_task", + dataset=update, + location=GCP_AUTOML_LOCATION, + ) + # [END howto_operator_automl_update_dataset] + + # [START howto_operator_automl_create_model] + create_model_task = AutoMLTrainModelOperator( + task_id="create_model_task", + model=MODEL, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + model_id = "{{ task_instance.xcom_pull('create_model_task', key='model_id') }}" + # [END howto_operator_automl_create_model] + + # [START howto_operator_automl_delete_model] + delete_model_task = AutoMLDeleteModelOperator( + task_id="delete_model_task", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_automl_delete_model] + + delete_datasets_task = AutoMLDeleteDatasetOperator( + task_id="delete_datasets_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + ( + create_dataset_task # noqa + >> import_dataset_task # noqa + >> list_tables_spec_task # noqa + >> list_columns_spec_task # noqa + >> update_dataset_task # noqa + >> create_model_task # noqa + >> delete_model_task # noqa + >> delete_datasets_task # noqa + ) + + +# Example DAG for AutoML datasets operations +with models.DAG( + "example_automl_dataset", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + user_defined_macros={"extract_object_id": extract_object_id}, +) as example_dag: + create_dataset_task = AutoMLCreateDatasetOperator( + 
task_id="create_dataset_task", + dataset=DATASET, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + dataset_id = ( + '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' + ) + + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + + list_tables_spec_task = AutoMLTablesListTableSpecsOperator( + task_id="list_tables_spec_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + list_columns_spec_task = AutoMLTablesListColumnSpecsOperator( + task_id="list_columns_spec_task", + dataset_id=dataset_id, + table_spec_id="{{ extract_object_id(task_instance.xcom_pull('list_tables_spec_task')[0]) }}", + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + # [START howto_operator_list_dataset] + list_datasets_task = AutoMLListDatasetOperator( + task_id="list_datasets_task", + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_list_dataset] + + # [START howto_operator_delete_dataset] + delete_datasets_task = AutoMLDeleteDatasetOperator( + task_id="delete_datasets_task", + dataset_id="{{ task_instance.xcom_pull('list_datasets_task', key='dataset_id_list') | list }}", + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_delete_dataset] + + ( + create_dataset_task # noqa + >> import_dataset_task # noqa + >> list_tables_spec_task # noqa + >> list_columns_spec_task # noqa + >> list_datasets_task # noqa + >> delete_datasets_task # noqa + ) + +with models.DAG( + "example_gcp_get_deploy", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as get_deploy_dag: + # [START howto_operator_get_model] + get_model_task = AutoMLGetModelOperator( + task_id="get_model_task", + model_id=MODEL_ID, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_get_model] + + # [START howto_operator_deploy_model] + deploy_model_task = AutoMLDeployModelOperator( + task_id="deploy_model_task", + model_id=MODEL_ID, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_deploy_model] + + +with models.DAG( + "example_gcp_predict", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as predict_dag: + # [START howto_operator_prediction] + predict_task = AutoMLPredictOperator( + task_id="predict_task", + model_id=MODEL_ID, + payload={}, # Add your own payload, the used model_id must be deployed + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_prediction] + + # [START howto_operator_batch_prediction] + batch_predict_task = AutoMLBatchPredictOperator( + task_id="batch_predict_task", + model_id=MODEL_ID, + input_config={}, # Add your config + output_config={}, # Add your config + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_batch_prediction] diff --git a/reference/providers/google/cloud/example_dags/example_automl_translation.py b/reference/providers/google/cloud/example_dags/example_automl_translation.py new file mode 100644 index 0000000..c5fb748 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_automl_translation.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses Google AutoML services. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") +GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") +GCP_AUTOML_TRANSLATION_BUCKET = os.environ.get( + "GCP_AUTOML_TRANSLATION_BUCKET", "gs://project-vcm/file" +) + +# Example values +DATASET_ID = "TRL123456789" + +# Example model +MODEL = { + "display_name": "auto_model_1", + "dataset_id": DATASET_ID, + "translation_model_metadata": {}, +} + +# Example dataset +DATASET = { + "display_name": "test_translation_dataset", + "translation_dataset_metadata": { + "source_language_code": "en", + "target_language_code": "es", + }, +} + +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_TRANSLATION_BUCKET]}} + +extract_object_id = CloudAutoMLHook.extract_object_id + + +# Example DAG for AutoML Translation +with models.DAG( + "example_automl_translation", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + user_defined_macros={"extract_object_id": extract_object_id}, + tags=["example"], +) as example_dag: + create_dataset_task = AutoMLCreateDatasetOperator( + task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION + ) + + dataset_id = ( + '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' + ) + + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + + MODEL["dataset_id"] = dataset_id + + create_model = AutoMLTrainModelOperator( + task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION + ) + + model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" + + delete_model_task = AutoMLDeleteModelOperator( + task_id="delete_model_task", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + delete_datasets_task = AutoMLDeleteDatasetOperator( + task_id="delete_datasets_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/reference/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py b/reference/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py new file mode 100644 index 0000000..f330e5f --- /dev/null +++ 
b/reference/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses Google AutoML services. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") +GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") +GCP_AUTOML_VIDEO_BUCKET = os.environ.get( + "GCP_AUTOML_VIDEO_BUCKET", "gs://automl-video-demo-data/hmdb_split1.csv" +) + +# Example values +DATASET_ID = "VCN123455678" + +# Example model +MODEL = { + "display_name": "auto_model_1", + "dataset_id": DATASET_ID, + "video_classification_model_metadata": {}, +} + +# Example dataset +DATASET = { + "display_name": "test_video_dataset", + "video_classification_dataset_metadata": {}, +} + +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_VIDEO_BUCKET]}} + +extract_object_id = CloudAutoMLHook.extract_object_id + + +# Example DAG for AutoML Video Intelligence Classification +with models.DAG( + "example_automl_video", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + user_defined_macros={"extract_object_id": extract_object_id}, + tags=["example"], +) as example_dag: + create_dataset_task = AutoMLCreateDatasetOperator( + task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION + ) + + dataset_id = ( + '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' + ) + + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + + MODEL["dataset_id"] = dataset_id + + create_model = AutoMLTrainModelOperator( + task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION + ) + + model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" + + delete_model_task = AutoMLDeleteModelOperator( + task_id="delete_model_task", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + delete_datasets_task = AutoMLDeleteDatasetOperator( + task_id="delete_datasets_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git 
a/reference/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py b/reference/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py new file mode 100644 index 0000000..2c1377f --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py @@ -0,0 +1,108 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses Google AutoML services. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") +GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") +GCP_AUTOML_TRACKING_BUCKET = os.environ.get( + "GCP_AUTOML_TRACKING_BUCKET", + "gs://automl-video-datasets/youtube_8m_videos_animal_tiny.csv", +) + +# Example values +DATASET_ID = "VOT123456789" + +# Example model +MODEL = { + "display_name": "auto_model_1", + "dataset_id": DATASET_ID, + "video_object_tracking_model_metadata": {}, +} + +# Example dataset +DATASET = { + "display_name": "test_video_tracking_dataset", + "video_object_tracking_dataset_metadata": {}, +} + +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_TRACKING_BUCKET]}} + +extract_object_id = CloudAutoMLHook.extract_object_id + + +# Example DAG for AutoML Video Intelligence Object Tracking +with models.DAG( + "example_automl_video_tracking", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + user_defined_macros={"extract_object_id": extract_object_id}, + tags=["example"], +) as example_dag: + create_dataset_task = AutoMLCreateDatasetOperator( + task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION + ) + + dataset_id = ( + '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' + ) + + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + + MODEL["dataset_id"] = dataset_id + + create_model = AutoMLTrainModelOperator( + task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION + ) + + model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" + + delete_model_task = AutoMLDeleteModelOperator( + task_id="delete_model_task", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + delete_datasets_task = AutoMLDeleteDatasetOperator( + 
task_id="delete_datasets_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/reference/providers/google/cloud/example_dags/example_automl_vision_classification.py b/reference/providers/google/cloud/example_dags/example_automl_vision_classification.py new file mode 100644 index 0000000..ec98e98 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_automl_vision_classification.py @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses Google AutoML services. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") +GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") +GCP_AUTOML_VISION_BUCKET = os.environ.get( + "GCP_AUTOML_VISION_BUCKET", "gs://your-bucket" +) + +# Example values +DATASET_ID = "ICN123455678" + +# Example model +MODEL = { + "display_name": "auto_model_2", + "dataset_id": DATASET_ID, + "image_classification_model_metadata": {"train_budget": 1}, +} + +# Example dataset +DATASET = { + "display_name": "test_vision_dataset", + "image_classification_dataset_metadata": {"classification_type": "MULTILABEL"}, +} + +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_VISION_BUCKET]}} + +extract_object_id = CloudAutoMLHook.extract_object_id + + +# Example DAG for AutoML Vision Classification +with models.DAG( + "example_automl_vision", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + user_defined_macros={"extract_object_id": extract_object_id}, + tags=["example"], +) as example_dag: + create_dataset_task = AutoMLCreateDatasetOperator( + task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION + ) + + dataset_id = ( + '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' + ) + + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + + MODEL["dataset_id"] = dataset_id + + create_model = AutoMLTrainModelOperator( + task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION + ) + + model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" + + delete_model_task = AutoMLDeleteModelOperator( + 
task_id="delete_model_task", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + delete_datasets_task = AutoMLDeleteDatasetOperator( + task_id="delete_datasets_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/reference/providers/google/cloud/example_dags/example_automl_vision_object_detection.py b/reference/providers/google/cloud/example_dags/example_automl_vision_object_detection.py new file mode 100644 index 0000000..7cdc857 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_automl_vision_object_detection.py @@ -0,0 +1,108 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses Google AutoML services. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") +GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") +GCP_AUTOML_DETECTION_BUCKET = os.environ.get( + "GCP_AUTOML_DETECTION_BUCKET", + "gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv", +) + +# Example values +DATASET_ID = "" + +# Example model +MODEL = { + "display_name": "auto_model", + "dataset_id": DATASET_ID, + "image_object_detection_model_metadata": {}, +} + +# Example dataset +DATASET = { + "display_name": "test_detection_dataset", + "image_object_detection_dataset_metadata": {}, +} + +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_DETECTION_BUCKET]}} + +extract_object_id = CloudAutoMLHook.extract_object_id + + +# Example DAG for AutoML Vision Object Detection +with models.DAG( + "example_automl_vision_detection", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + user_defined_macros={"extract_object_id": extract_object_id}, + tags=["example"], +) as example_dag: + create_dataset_task = AutoMLCreateDatasetOperator( + task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION + ) + + dataset_id = ( + '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' + ) + + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + + MODEL["dataset_id"] = dataset_id + + create_model = AutoMLTrainModelOperator( + 
task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION + ) + + model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" + + delete_model_task = AutoMLDeleteModelOperator( + task_id="delete_model_task", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + delete_datasets_task = AutoMLDeleteDatasetOperator( + task_id="delete_datasets_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/reference/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py b/reference/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py new file mode 100644 index 0000000..473f39a --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs import ( + AzureFileShareToGCSOperator, +) + +DEST_GCS_BUCKET = os.environ.get("GCP_GCS_BUCKET", "gs://test-gcs-example-bucket") +AZURE_SHARE_NAME = os.environ.get("AZURE_SHARE_NAME", "test-azure-share") +AZURE_DIRECTORY_NAME = "test-azure-dir" + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), +} + +with DAG( + dag_id="azure_fileshare_to_gcs_example", + default_args=default_args, + schedule_interval=None, + start_date=datetime(2018, 11, 1), + tags=["example"], +) as dag: + # [START howto_operator_azure_fileshare_to_gcs_basic] + sync_azure_files_with_gcs = AzureFileShareToGCSOperator( + task_id="sync_azure_files_with_gcs", + share_name=AZURE_SHARE_NAME, + dest_gcs=DEST_GCS_BUCKET, + directory_name=AZURE_DIRECTORY_NAME, + wasb_conn_id="azure_fileshare_default", + gcp_conn_id="google_cloud_default", + replace=False, + gzip=True, + google_impersonation_chain=None, + ) + # [END howto_operator_azure_fileshare_to_gcs_basic] diff --git a/reference/providers/google/cloud/example_dags/example_bigquery_dts.py b/reference/providers/google/cloud/example_dags/example_bigquery_dts.py new file mode 100644 index 0000000..083cf09 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_bigquery_dts.py @@ -0,0 +1,116 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that creates and deletes Bigquery data transfer configurations. +""" +import os +import time + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery_dts import ( + BigQueryCreateDataTransferOperator, + BigQueryDataTransferServiceStartTransferRunsOperator, + BigQueryDeleteDataTransferConfigOperator, +) +from airflow.providers.google.cloud.sensors.bigquery_dts import ( + BigQueryDataTransferServiceTransferRunSensor, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +BUCKET_URI = os.environ.get( + "GCP_DTS_BUCKET_URI", "gs://cloud-ml-tables-data/bank-marketing.csv" +) +GCP_DTS_BQ_DATASET = os.environ.get("GCP_DTS_BQ_DATASET", "test_dts") +GCP_DTS_BQ_TABLE = os.environ.get("GCP_DTS_BQ_TABLE", "GCS_Test") + +# [START howto_bigquery_dts_create_args] + +# In the case of Airflow, the customer needs to create a transfer +# config with the automatic scheduling disabled and then trigger +# a transfer run using a specialized Airflow operator +schedule_options = {"disable_auto_scheduling": True} + +PARAMS = { + "field_delimiter": ",", + "max_bad_records": "0", + "skip_leading_rows": "1", + "data_path_template": BUCKET_URI, + "destination_table_name_template": GCP_DTS_BQ_TABLE, + "file_format": "CSV", +} + +TRANSFER_CONFIG = { + "destination_dataset_id": GCP_DTS_BQ_DATASET, + "display_name": "GCS Test Config", + "data_source_id": "google_cloud_storage", + "schedule_options": schedule_options, + "params": PARAMS, +} + +# [END howto_bigquery_dts_create_args] + +with models.DAG( + "example_gcp_bigquery_dts", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_bigquery_create_data_transfer] + gcp_bigquery_create_transfer = BigQueryCreateDataTransferOperator( + transfer_config=TRANSFER_CONFIG, + project_id=GCP_PROJECT_ID, + task_id="gcp_bigquery_create_transfer", + ) + + transfer_config_id = "{{ task_instance.xcom_pull('gcp_bigquery_create_transfer', key='transfer_config_id') }}" + # [END howto_bigquery_create_data_transfer] + + # [START howto_bigquery_start_transfer] + gcp_bigquery_start_transfer = BigQueryDataTransferServiceStartTransferRunsOperator( + task_id="gcp_bigquery_start_transfer", + transfer_config_id=transfer_config_id, + requested_run_time={"seconds": int(time.time() + 60)}, + ) + run_id = ( + "{{ task_instance.xcom_pull('gcp_bigquery_start_transfer', key='run_id') }}" + ) + # [END howto_bigquery_start_transfer] + + # [START howto_bigquery_dts_sensor] + gcp_run_sensor = BigQueryDataTransferServiceTransferRunSensor( + task_id="gcp_run_sensor", + transfer_config_id=transfer_config_id, + run_id=run_id, + expected_statuses={"SUCCEEDED"}, + ) + # [END howto_bigquery_dts_sensor] + + # [START howto_bigquery_delete_data_transfer] + gcp_bigquery_delete_transfer = BigQueryDeleteDataTransferConfigOperator( + 
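        # Editor's note: clean-up step. transfer_config_id is the templated
        # XCom value pulled from gcp_bigquery_create_transfer at run time, so
        # this removes the config created at the start of the DAG.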
transfer_config_id=transfer_config_id, task_id="gcp_bigquery_delete_transfer" + ) + # [END howto_bigquery_delete_data_transfer] + + ( + gcp_bigquery_create_transfer # noqa + >> gcp_bigquery_start_transfer # noqa + >> gcp_run_sensor # noqa + >> gcp_bigquery_delete_transfer # noqa + ) diff --git a/reference/providers/google/cloud/example_dags/example_bigquery_operations.py b/reference/providers/google/cloud/example_dags/example_bigquery_operations.py new file mode 100644 index 0000000..1daccea --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_bigquery_operations.py @@ -0,0 +1,246 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google BigQuery service. +""" +import os +import time +from urllib.parse import urlparse + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryCreateExternalTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryDeleteTableOperator, + BigQueryGetDatasetOperator, + BigQueryGetDatasetTablesOperator, + BigQueryPatchDatasetOperator, + BigQueryUpdateDatasetOperator, + BigQueryUpdateTableOperator, + BigQueryUpsertTableOperator, +) +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +BQ_LOCATION = "europe-north1" + +DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_operations") +LOCATION_DATASET_NAME = f"{DATASET_NAME}_location" +DATA_SAMPLE_GCS_URL = os.environ.get( + "GCP_BIGQUERY_DATA_GCS_URL", + "gs://cloud-samples-data/bigquery/us-states/us-states.csv", +) + +DATA_SAMPLE_GCS_URL_PARTS = urlparse(DATA_SAMPLE_GCS_URL) +DATA_SAMPLE_GCS_BUCKET_NAME = DATA_SAMPLE_GCS_URL_PARTS.netloc +DATA_SAMPLE_GCS_OBJECT_NAME = DATA_SAMPLE_GCS_URL_PARTS.path[1:] + + +with models.DAG( + "example_bigquery_operations", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_bigquery_create_table] + create_table = BigQueryCreateEmptyTableOperator( + task_id="create_table", + dataset_id=DATASET_NAME, + table_id="test_table", + schema_fields=[ + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, + ], + ) + # [END howto_operator_bigquery_create_table] + + # [START howto_operator_bigquery_delete_table] + delete_table = BigQueryDeleteTableOperator( + task_id="delete_table", + deletion_dataset_table=f"{PROJECT_ID}.{DATASET_NAME}.test_table", + ) + # [END howto_operator_bigquery_delete_table] + + # [START howto_operator_bigquery_create_view] + create_view = BigQueryCreateEmptyTableOperator( + 
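        # Editor's note: the same operator creates a view when a ``view`` dict
        # (query plus useLegacySql) is passed instead of ``schema_fields``.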
task_id="create_view", + dataset_id=DATASET_NAME, + table_id="test_view", + view={ + "query": f"SELECT * FROM `{PROJECT_ID}.{DATASET_NAME}.test_table`", + "useLegacySql": False, + }, + ) + # [END howto_operator_bigquery_create_view] + + # [START howto_operator_bigquery_delete_view] + delete_view = BigQueryDeleteTableOperator( + task_id="delete_view", + deletion_dataset_table=f"{PROJECT_ID}.{DATASET_NAME}.test_view", + ) + # [END howto_operator_bigquery_delete_view] + + # [START howto_operator_bigquery_create_materialized_view] + create_materialized_view = BigQueryCreateEmptyTableOperator( + task_id="create_materialized_view", + dataset_id=DATASET_NAME, + table_id="test_materialized_view", + materialized_view={ + "query": f"SELECT SUM(salary) AS sum_salary FROM `{PROJECT_ID}.{DATASET_NAME}.test_table`", + "enableRefresh": True, + "refreshIntervalMs": 2000000, + }, + ) + # [END howto_operator_bigquery_create_materialized_view] + + # [START howto_operator_bigquery_delete_materialized_view] + delete_materialized_view = BigQueryDeleteTableOperator( + task_id="delete_materialized_view", + deletion_dataset_table=f"{PROJECT_ID}.{DATASET_NAME}.test_materialized_view", + ) + # [END howto_operator_bigquery_delete_materialized_view] + + # [START howto_operator_bigquery_create_external_table] + create_external_table = BigQueryCreateExternalTableOperator( + task_id="create_external_table", + bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + source_objects=[DATA_SAMPLE_GCS_OBJECT_NAME], + destination_project_dataset_table=f"{DATASET_NAME}.external_table", + skip_leading_rows=1, + schema_fields=[ + {"name": "name", "type": "STRING"}, + {"name": "post_abbr", "type": "STRING"}, + ], + ) + # [END howto_operator_bigquery_create_external_table] + + # [START howto_operator_bigquery_upsert_table] + upsert_table = BigQueryUpsertTableOperator( + task_id="upsert_table", + dataset_id=DATASET_NAME, + table_resource={ + "tableReference": {"tableId": "test_table_id"}, + "expirationTime": (int(time.time()) + 300) * 1000, + }, + ) + # [END howto_operator_bigquery_upsert_table] + + # [START howto_operator_bigquery_create_dataset] + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create-dataset", dataset_id=DATASET_NAME + ) + # [END howto_operator_bigquery_create_dataset] + + # [START howto_operator_bigquery_get_dataset_tables] + get_dataset_tables = BigQueryGetDatasetTablesOperator( + task_id="get_dataset_tables", dataset_id=DATASET_NAME + ) + # [END howto_operator_bigquery_get_dataset_tables] + + # [START howto_operator_bigquery_get_dataset] + get_dataset = BigQueryGetDatasetOperator( + task_id="get-dataset", dataset_id=DATASET_NAME + ) + # [END howto_operator_bigquery_get_dataset] + + get_dataset_result = BashOperator( + task_id="get_dataset_result", + bash_command="echo \"{{ task_instance.xcom_pull('get-dataset')['id'] }}\"", + ) + + # [START howto_operator_bigquery_update_table] + update_table = BigQueryUpdateTableOperator( + task_id="update_table", + dataset_id=DATASET_NAME, + table_id="test_table", + fields=[ + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, + ], + table_resource={ + "friendlyName": "Updated Table", + "description": "Updated Table", + }, + ) + # [END howto_operator_bigquery_update_table] + + # [START howto_operator_bigquery_patch_dataset] + patch_dataset = BigQueryPatchDatasetOperator( + task_id="patch_dataset", + dataset_id=DATASET_NAME, + dataset_resource={ + "friendlyName": "Patched Dataset", + "description": "Patched 
dataset", + }, + ) + # [END howto_operator_bigquery_patch_dataset] + + # [START howto_operator_bigquery_update_dataset] + update_dataset = BigQueryUpdateDatasetOperator( + task_id="update_dataset", + dataset_id=DATASET_NAME, + dataset_resource={"description": "Updated dataset"}, + ) + # [END howto_operator_bigquery_update_dataset] + + # [START howto_operator_bigquery_delete_dataset] + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True + ) + # [END howto_operator_bigquery_delete_dataset] + + create_dataset >> patch_dataset >> update_dataset >> get_dataset >> get_dataset_result >> delete_dataset + + update_dataset >> create_table >> create_view >> create_materialized_view >> update_table >> [ + get_dataset_tables, + delete_view, + ] >> upsert_table >> delete_materialized_view >> delete_table >> delete_dataset + update_dataset >> create_external_table >> delete_dataset + +with models.DAG( + "example_bigquery_operations_location", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +): + create_dataset_with_location = BigQueryCreateEmptyDatasetOperator( + task_id="create_dataset_with_location", + dataset_id=LOCATION_DATASET_NAME, + location=BQ_LOCATION, + ) + + create_table_with_location = BigQueryCreateEmptyTableOperator( + task_id="create_table_with_location", + dataset_id=LOCATION_DATASET_NAME, + table_id="test_table", + schema_fields=[ + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, + ], + ) + + delete_dataset_with_location = BigQueryDeleteDatasetOperator( + task_id="delete_dataset_with_location", + dataset_id=LOCATION_DATASET_NAME, + delete_contents=True, + ) + create_dataset_with_location >> create_table_with_location >> delete_dataset_with_location diff --git a/reference/providers/google/cloud/example_dags/example_bigquery_queries.py b/reference/providers/google/cloud/example_dags/example_bigquery_queries.py new file mode 100644 index 0000000..f0a9395 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_bigquery_queries.py @@ -0,0 +1,210 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google BigQuery service. 
+""" +import os +from datetime import datetime + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCheckOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryExecuteQueryOperator, + BigQueryGetDataOperator, + BigQueryInsertJobOperator, + BigQueryIntervalCheckOperator, + BigQueryValueCheckOperator, +) +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset") +LOCATION = "southamerica-east1" + +TABLE_1 = "table1" +TABLE_2 = "table2" + + +INSERT_DATE = datetime.now().strftime("%Y-%m-%d") + +# [START howto_operator_bigquery_query] +INSERT_ROWS_QUERY = ( + f"INSERT {DATASET_NAME}.{TABLE_1} VALUES " + f"(42, 'monthy python', '{INSERT_DATE}'), " + f"(42, 'fishy fish', '{INSERT_DATE}');" +) +# [END howto_operator_bigquery_query] + +SCHEMA = [ + {"name": "value", "type": "INTEGER", "mode": "REQUIRED"}, + {"name": "name", "type": "STRING", "mode": "NULLABLE"}, + {"name": "ds", "type": "DATE", "mode": "NULLABLE"}, +] + +for location in [None, LOCATION]: + dag_id = ( + "example_bigquery_queries_location" if location else "example_bigquery_queries" + ) + + with models.DAG( + dag_id, + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], + user_defined_macros={"DATASET": DATASET_NAME, "TABLE": TABLE_1}, + ) as dag_with_locations: + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create-dataset", + dataset_id=DATASET_NAME, + location=location, + ) + + create_table_1 = BigQueryCreateEmptyTableOperator( + task_id="create_table_1", + dataset_id=DATASET_NAME, + table_id=TABLE_1, + schema_fields=SCHEMA, + location=location, + ) + + create_table_2 = BigQueryCreateEmptyTableOperator( + task_id="create_table_2", + dataset_id=DATASET_NAME, + table_id=TABLE_2, + schema_fields=SCHEMA, + location=location, + ) + + create_dataset >> [create_table_1, create_table_2] + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True + ) + + # [START howto_operator_bigquery_insert_job] + insert_query_job = BigQueryInsertJobOperator( + task_id="insert_query_job", + configuration={ + "query": { + "query": INSERT_ROWS_QUERY, + "useLegacySql": False, + } + }, + location=location, + ) + # [END howto_operator_bigquery_insert_job] + + # [START howto_operator_bigquery_select_job] + select_query_job = BigQueryInsertJobOperator( + task_id="select_query_job", + configuration={ + "query": { + "query": "{% include 'example_bigquery_query.sql' %}", + "useLegacySql": False, + } + }, + location=location, + ) + # [END howto_operator_bigquery_select_job] + + execute_insert_query = BigQueryExecuteQueryOperator( + task_id="execute_insert_query", + sql=INSERT_ROWS_QUERY, + use_legacy_sql=False, + location=location, + ) + + bigquery_execute_multi_query = BigQueryExecuteQueryOperator( + task_id="execute_multi_query", + sql=[ + f"SELECT * FROM {DATASET_NAME}.{TABLE_2}", + f"SELECT COUNT(*) FROM {DATASET_NAME}.{TABLE_2}", + ], + use_legacy_sql=False, + location=location, + ) + + execute_query_save = BigQueryExecuteQueryOperator( + task_id="execute_query_save", + sql=f"SELECT * FROM {DATASET_NAME}.{TABLE_1}", + use_legacy_sql=False, + destination_dataset_table=f"{DATASET_NAME}.{TABLE_2}", + location=location, + ) + + # [START 
howto_operator_bigquery_get_data] + get_data = BigQueryGetDataOperator( + task_id="get_data", + dataset_id=DATASET_NAME, + table_id=TABLE_1, + max_results=10, + selected_fields="value,name", + location=location, + ) + # [END howto_operator_bigquery_get_data] + + get_data_result = BashOperator( + task_id="get_data_result", + bash_command="echo \"{{ task_instance.xcom_pull('get_data') }}\"", + ) + + # [START howto_operator_bigquery_check] + check_count = BigQueryCheckOperator( + task_id="check_count", + sql=f"SELECT COUNT(*) FROM {DATASET_NAME}.{TABLE_1}", + use_legacy_sql=False, + location=location, + ) + # [END howto_operator_bigquery_check] + + # [START howto_operator_bigquery_value_check] + check_value = BigQueryValueCheckOperator( + task_id="check_value", + sql=f"SELECT COUNT(*) FROM {DATASET_NAME}.{TABLE_1}", + pass_value=4, + use_legacy_sql=False, + location=location, + ) + # [END howto_operator_bigquery_value_check] + + # [START howto_operator_bigquery_interval_check] + check_interval = BigQueryIntervalCheckOperator( + task_id="check_interval", + table=f"{DATASET_NAME}.{TABLE_1}", + days_back=1, + metrics_thresholds={"COUNT(*)": 1.5}, + use_legacy_sql=False, + location=location, + ) + # [END howto_operator_bigquery_interval_check] + + [create_table_1, create_table_2] >> insert_query_job >> select_query_job + + insert_query_job >> execute_insert_query + execute_insert_query >> get_data >> get_data_result >> delete_dataset + execute_insert_query >> execute_query_save >> bigquery_execute_multi_query >> delete_dataset + execute_insert_query >> [ + check_count, + check_value, + check_interval, + ] >> delete_dataset diff --git a/reference/providers/google/cloud/example_dags/example_bigquery_query.sql b/reference/providers/google/cloud/example_dags/example_bigquery_query.sql new file mode 100644 index 0000000..b629f27 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_bigquery_query.sql @@ -0,0 +1,20 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +SELECT * FROM {{ DATASET }}.{{ TABLE }} diff --git a/reference/providers/google/cloud/example_dags/example_bigquery_sensors.py b/reference/providers/google/cloud/example_dags/example_bigquery_sensors.py new file mode 100644 index 0000000..4b33b7e --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_bigquery_sensors.py @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google BigQuery Sensors. +""" +import os +from datetime import datetime + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryExecuteQueryOperator, +) +from airflow.providers.google.cloud.sensors.bigquery import ( + BigQueryTableExistenceSensor, + BigQueryTablePartitionExistenceSensor, +) +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_sensors_dataset") + +TABLE_NAME = "partitioned_table" +INSERT_DATE = datetime.now().strftime("%Y-%m-%d") + +PARTITION_NAME = "{{ ds_nodash }}" + +INSERT_ROWS_QUERY = f"INSERT {DATASET_NAME}.{TABLE_NAME} VALUES " "(42, '{{ ds }}')" + +SCHEMA = [ + {"name": "value", "type": "INTEGER", "mode": "REQUIRED"}, + {"name": "ds", "type": "DATE", "mode": "NULLABLE"}, +] + +dag_id = "example_bigquery_sensors" + +with models.DAG( + dag_id, + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], + user_defined_macros={"DATASET": DATASET_NAME, "TABLE": TABLE_NAME}, + default_args={"project_id": PROJECT_ID}, +) as dag_with_locations: + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create-dataset", dataset_id=DATASET_NAME, project_id=PROJECT_ID + ) + + create_table = BigQueryCreateEmptyTableOperator( + task_id="create_table", + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + schema_fields=SCHEMA, + time_partitioning={ + "type": "DAY", + "field": "ds", + }, + ) + # [START howto_sensor_bigquery_table] + check_table_exists = BigQueryTableExistenceSensor( + task_id="check_table_exists", + project_id=PROJECT_ID, + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + ) + # [END howto_sensor_bigquery_table] + + execute_insert_query = BigQueryExecuteQueryOperator( + task_id="execute_insert_query", sql=INSERT_ROWS_QUERY, use_legacy_sql=False + ) + + # [START howto_sensor_bigquery_table_partition] + check_table_partition_exists = BigQueryTablePartitionExistenceSensor( + task_id="check_table_partition_exists", + project_id=PROJECT_ID, + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + partition_id=PARTITION_NAME, + ) + # [END howto_sensor_bigquery_table_partition] + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True + ) + + create_dataset >> create_table + create_table >> check_table_exists + create_table >> execute_insert_query + execute_insert_query >> check_table_partition_exists + check_table_exists >> delete_dataset + check_table_partition_exists >> delete_dataset diff --git a/reference/providers/google/cloud/example_dags/example_bigquery_to_bigquery.py b/reference/providers/google/cloud/example_dags/example_bigquery_to_bigquery.py new file mode 100644 index 0000000..b872630 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_bigquery_to_bigquery.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google BigQuery service. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, +) +from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import ( + BigQueryToBigQueryOperator, +) +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_transfer") +ORIGIN = "origin" +TARGET = "target" + +with models.DAG( + "example_bigquery_to_bigquery", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + copy_selected_data = BigQueryToBigQueryOperator( + task_id="copy_selected_data", + source_project_dataset_tables=f"{DATASET_NAME}.{ORIGIN}", + destination_project_dataset_table=f"{DATASET_NAME}.{TARGET}", + ) + + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create_dataset", dataset_id=DATASET_NAME + ) + + for table in [ORIGIN, TARGET]: + create_table = BigQueryCreateEmptyTableOperator( + task_id=f"create_{table}_table", + dataset_id=DATASET_NAME, + table_id=table, + schema_fields=[ + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, + ], + ) + create_dataset >> create_table >> copy_selected_data + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True + ) + + copy_selected_data >> delete_dataset diff --git a/reference/providers/google/cloud/example_dags/example_bigquery_to_gcs.py b/reference/providers/google/cloud/example_dags/example_bigquery_to_gcs.py new file mode 100644 index 0000000..7232bba --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_bigquery_to_gcs.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
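# ---------------------------------------------------------------------------
# Editor's note: a minimal, hedged sketch, not part of the upstream example
# below. The example that follows exports a table with the operator's
# defaults (CSV, uncompressed, as of the provider release this copy appears
# to track). Assuming the ``export_format`` and ``compression`` parameters of
# BigQueryToGCSOperator behave as documented for that release, a compressed
# newline-delimited JSON export would look like this; the DAG id is invented
# and the dataset/table/bucket names reuse the defaults from the example
# below.
from airflow import models
from airflow.providers.google.cloud.transfers.bigquery_to_gcs import (
    BigQueryToGCSOperator,
)
from airflow.utils.dates import days_ago

with models.DAG(
    "editor_note_bigquery_to_gcs_compressed",
    schedule_interval=None,
    start_date=days_ago(1),
    tags=["example"],
):
    bigquery_to_gcs_compressed = BigQueryToGCSOperator(
        task_id="bigquery_to_gcs_compressed",
        source_project_dataset_table="test_dataset_transfer.table_42",
        destination_cloud_storage_uris=[
            # The single '*' wildcard lets BigQuery shard large exports
            # across multiple files.
            "gs://test-bigquery-gcs-data/export-bigquery-*.json.gz"
        ],
        export_format="NEWLINE_DELIMITED_JSON",
        compression="GZIP",
    )
# ---------------------------------------------------------------------------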
+ +""" +Example Airflow DAG for Google BigQuery service. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, +) +from airflow.providers.google.cloud.transfers.bigquery_to_gcs import ( + BigQueryToGCSOperator, +) +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_transfer") +DATA_EXPORT_BUCKET_NAME = os.environ.get( + "GCP_BIGQUERY_EXPORT_BUCKET_NAME", "test-bigquery-gcs-data" +) +TABLE = "table_42" + +with models.DAG( + "example_bigquery_to_gcs", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + bigquery_to_gcs = BigQueryToGCSOperator( + task_id="bigquery_to_gcs", + source_project_dataset_table=f"{DATASET_NAME}.{TABLE}", + destination_cloud_storage_uris=[ + f"gs://{DATA_EXPORT_BUCKET_NAME}/export-bigquery.csv" + ], + ) + + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create_dataset", dataset_id=DATASET_NAME + ) + + create_table = BigQueryCreateEmptyTableOperator( + task_id="create_table", + dataset_id=DATASET_NAME, + table_id=TABLE, + schema_fields=[ + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, + ], + ) + create_dataset >> create_table >> bigquery_to_gcs + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True + ) + + bigquery_to_gcs >> delete_dataset diff --git a/reference/providers/google/cloud/example_dags/example_bigquery_transfer.py b/reference/providers/google/cloud/example_dags/example_bigquery_transfer.py new file mode 100644 index 0000000..bb31221 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_bigquery_transfer.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google BigQuery service. 
+""" +import os + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, +) +from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import ( + BigQueryToBigQueryOperator, +) +from airflow.providers.google.cloud.transfers.bigquery_to_gcs import ( + BigQueryToGCSOperator, +) +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_transfer") +DATA_EXPORT_BUCKET_NAME = os.environ.get( + "GCP_BIGQUERY_EXPORT_BUCKET_NAME", "test-bigquery-sample-data" +) +ORIGIN = "origin" +TARGET = "target" + +with models.DAG( + "example_bigquery_transfer", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + copy_selected_data = BigQueryToBigQueryOperator( + task_id="copy_selected_data", + source_project_dataset_tables=f"{DATASET_NAME}.{ORIGIN}", + destination_project_dataset_table=f"{DATASET_NAME}.{TARGET}", + ) + + bigquery_to_gcs = BigQueryToGCSOperator( + task_id="bigquery_to_gcs", + source_project_dataset_table=f"{DATASET_NAME}.{ORIGIN}", + destination_cloud_storage_uris=[ + f"gs://{DATA_EXPORT_BUCKET_NAME}/export-bigquery.csv" + ], + ) + + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create_dataset", dataset_id=DATASET_NAME + ) + + for table in [ORIGIN, TARGET]: + create_table = BigQueryCreateEmptyTableOperator( + task_id=f"create_{table}_table", + dataset_id=DATASET_NAME, + table_id=table, + schema_fields=[ + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, + ], + ) + create_dataset >> create_table >> [copy_selected_data, bigquery_to_gcs] + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True + ) + + [copy_selected_data, bigquery_to_gcs] >> delete_dataset diff --git a/reference/providers/google/cloud/example_dags/example_bigtable.py b/reference/providers/google/cloud/example_dags/example_bigtable.py new file mode 100644 index 0000000..d133e89 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_bigtable.py @@ -0,0 +1,219 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Example Airflow DAG that creates and performs following operations on Cloud Bigtable: +- creates an Instance +- creates a Table +- updates Cluster +- waits for Table replication completeness +- deletes the Table +- deletes the Instance + +This DAG relies on the following environment variables: + +* GCP_PROJECT_ID - Google Cloud project +* CBT_INSTANCE_ID - desired ID of a Cloud Bigtable instance +* CBT_INSTANCE_DISPLAY_NAME - desired human-readable display name of the Instance +* CBT_INSTANCE_TYPE - type of the Instance, e.g. 1 for DEVELOPMENT + See https://googleapis.github.io/google-cloud-python/latest/bigtable/instance.html#google.cloud.bigtable.instance.Instance # noqa E501 # pylint: disable=line-too-long +* CBT_INSTANCE_LABELS - labels to add for the Instance +* CBT_CLUSTER_ID - desired ID of the main Cluster created for the Instance +* CBT_CLUSTER_ZONE - zone in which main Cluster will be created. e.g. europe-west1-b + See available zones: https://cloud.google.com/bigtable/docs/locations +* CBT_CLUSTER_NODES - initial amount of nodes of the Cluster +* CBT_CLUSTER_NODES_UPDATED - amount of nodes for BigtableClusterUpdateOperator +* CBT_CLUSTER_STORAGE_TYPE - storage for the Cluster, e.g. 1 for SSD + See https://googleapis.github.io/google-cloud-python/latest/bigtable/instance.html#google.cloud.bigtable.instance.Instance.cluster # noqa E501 # pylint: disable=line-too-long +* CBT_TABLE_ID - desired ID of the Table +* CBT_POKE_INTERVAL - number of seconds between every attempt of Sensor check + +""" + +import json +from os import getenv + +from airflow import models +from airflow.providers.google.cloud.operators.bigtable import ( + BigtableCreateInstanceOperator, + BigtableCreateTableOperator, + BigtableDeleteInstanceOperator, + BigtableDeleteTableOperator, + BigtableUpdateClusterOperator, + BigtableUpdateInstanceOperator, +) +from airflow.providers.google.cloud.sensors.bigtable import ( + BigtableTableReplicationCompletedSensor, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = getenv("GCP_PROJECT_ID", "example-project") +CBT_INSTANCE_ID = getenv("GCP_BIG_TABLE_INSTANCE_ID", "some-instance-id") +CBT_INSTANCE_DISPLAY_NAME = getenv( + "GCP_BIG_TABLE_INSTANCE_DISPLAY_NAME", "Human-readable name" +) +CBT_INSTANCE_DISPLAY_NAME_UPDATED = getenv( + "GCP_BIG_TABLE_INSTANCE_DISPLAY_NAME_UPDATED", + f"{CBT_INSTANCE_DISPLAY_NAME} - updated", +) +CBT_INSTANCE_TYPE = getenv("GCP_BIG_TABLE_INSTANCE_TYPE", "2") +CBT_INSTANCE_TYPE_PROD = getenv("GCP_BIG_TABLE_INSTANCE_TYPE_PROD", "1") +CBT_INSTANCE_LABELS = getenv("GCP_BIG_TABLE_INSTANCE_LABELS", "{}") +CBT_INSTANCE_LABELS_UPDATED = getenv( + "GCP_BIG_TABLE_INSTANCE_LABELS_UPDATED", '{"env": "prod"}' +) +CBT_CLUSTER_ID = getenv("GCP_BIG_TABLE_CLUSTER_ID", "some-cluster-id") +CBT_CLUSTER_ZONE = getenv("GCP_BIG_TABLE_CLUSTER_ZONE", "europe-west1-b") +CBT_CLUSTER_NODES = getenv("GCP_BIG_TABLE_CLUSTER_NODES", "3") +CBT_CLUSTER_NODES_UPDATED = getenv("GCP_BIG_TABLE_CLUSTER_NODES_UPDATED", "5") +CBT_CLUSTER_STORAGE_TYPE = getenv("GCP_BIG_TABLE_CLUSTER_STORAGE_TYPE", "2") +CBT_TABLE_ID = getenv("GCP_BIG_TABLE_TABLE_ID", "some-table-id") +CBT_POKE_INTERVAL = getenv("GCP_BIG_TABLE_POKE_INTERVAL", "60") + + +with models.DAG( + "example_gcp_bigtable_operators", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_gcp_bigtable_instance_create] + create_instance_task = BigtableCreateInstanceOperator( + project_id=GCP_PROJECT_ID, + 
instance_id=CBT_INSTANCE_ID, + main_cluster_id=CBT_CLUSTER_ID, + main_cluster_zone=CBT_CLUSTER_ZONE, + instance_display_name=CBT_INSTANCE_DISPLAY_NAME, + instance_type=int(CBT_INSTANCE_TYPE), + instance_labels=json.loads(CBT_INSTANCE_LABELS), + cluster_nodes=None, + cluster_storage_type=int(CBT_CLUSTER_STORAGE_TYPE), + task_id="create_instance_task", + ) + create_instance_task2 = BigtableCreateInstanceOperator( + instance_id=CBT_INSTANCE_ID, + main_cluster_id=CBT_CLUSTER_ID, + main_cluster_zone=CBT_CLUSTER_ZONE, + instance_display_name=CBT_INSTANCE_DISPLAY_NAME, + instance_type=int(CBT_INSTANCE_TYPE), + instance_labels=json.loads(CBT_INSTANCE_LABELS), + cluster_nodes=int(CBT_CLUSTER_NODES), + cluster_storage_type=int(CBT_CLUSTER_STORAGE_TYPE), + task_id="create_instance_task2", + ) + create_instance_task >> create_instance_task2 + # [END howto_operator_gcp_bigtable_instance_create] + + # [START howto_operator_gcp_bigtable_instance_update] + update_instance_task = BigtableUpdateInstanceOperator( + instance_id=CBT_INSTANCE_ID, + instance_display_name=CBT_INSTANCE_DISPLAY_NAME_UPDATED, + instance_type=int(CBT_INSTANCE_TYPE_PROD), + instance_labels=json.loads(CBT_INSTANCE_LABELS_UPDATED), + task_id="update_instance_task", + ) + # [END howto_operator_gcp_bigtable_instance_update] + + # [START howto_operator_gcp_bigtable_cluster_update] + cluster_update_task = BigtableUpdateClusterOperator( + project_id=GCP_PROJECT_ID, + instance_id=CBT_INSTANCE_ID, + cluster_id=CBT_CLUSTER_ID, + nodes=int(CBT_CLUSTER_NODES_UPDATED), + task_id="update_cluster_task", + ) + cluster_update_task2 = BigtableUpdateClusterOperator( + instance_id=CBT_INSTANCE_ID, + cluster_id=CBT_CLUSTER_ID, + nodes=int(CBT_CLUSTER_NODES_UPDATED), + task_id="update_cluster_task2", + ) + cluster_update_task >> cluster_update_task2 + # [END howto_operator_gcp_bigtable_cluster_update] + + # [START howto_operator_gcp_bigtable_instance_delete] + delete_instance_task = BigtableDeleteInstanceOperator( + project_id=GCP_PROJECT_ID, + instance_id=CBT_INSTANCE_ID, + task_id="delete_instance_task", + ) + delete_instance_task2 = BigtableDeleteInstanceOperator( + instance_id=CBT_INSTANCE_ID, + task_id="delete_instance_task2", + ) + # [END howto_operator_gcp_bigtable_instance_delete] + + # [START howto_operator_gcp_bigtable_table_create] + create_table_task = BigtableCreateTableOperator( + project_id=GCP_PROJECT_ID, + instance_id=CBT_INSTANCE_ID, + table_id=CBT_TABLE_ID, + task_id="create_table", + ) + create_table_task2 = BigtableCreateTableOperator( + instance_id=CBT_INSTANCE_ID, + table_id=CBT_TABLE_ID, + task_id="create_table_task2", + ) + create_table_task >> create_table_task2 + # [END howto_operator_gcp_bigtable_table_create] + + # [START howto_operator_gcp_bigtable_table_wait_for_replication] + wait_for_table_replication_task = BigtableTableReplicationCompletedSensor( + project_id=GCP_PROJECT_ID, + instance_id=CBT_INSTANCE_ID, + table_id=CBT_TABLE_ID, + poke_interval=int(CBT_POKE_INTERVAL), + timeout=180, + task_id="wait_for_table_replication_task", + ) + wait_for_table_replication_task2 = BigtableTableReplicationCompletedSensor( + instance_id=CBT_INSTANCE_ID, + table_id=CBT_TABLE_ID, + poke_interval=int(CBT_POKE_INTERVAL), + timeout=180, + task_id="wait_for_table_replication_task2", + ) + # [END howto_operator_gcp_bigtable_table_wait_for_replication] + + # [START howto_operator_gcp_bigtable_table_delete] + delete_table_task = BigtableDeleteTableOperator( + project_id=GCP_PROJECT_ID, + instance_id=CBT_INSTANCE_ID, + table_id=CBT_TABLE_ID, + 
task_id="delete_table_task", + ) + delete_table_task2 = BigtableDeleteTableOperator( + instance_id=CBT_INSTANCE_ID, + table_id=CBT_TABLE_ID, + task_id="delete_table_task2", + ) + # [END howto_operator_gcp_bigtable_table_delete] + + wait_for_table_replication_task >> delete_table_task + wait_for_table_replication_task2 >> delete_table_task + wait_for_table_replication_task >> delete_table_task2 + wait_for_table_replication_task2 >> delete_table_task2 + create_instance_task >> create_table_task >> cluster_update_task + cluster_update_task >> update_instance_task >> delete_table_task + create_instance_task2 >> create_table_task2 >> cluster_update_task2 >> delete_table_task2 + + # Only delete instances after all tables are deleted + [ + delete_table_task, + delete_table_task2, + ] >> delete_instance_task >> delete_instance_task2 diff --git a/reference/providers/google/cloud/example_dags/example_cloud_build.py b/reference/providers/google/cloud/example_dags/example_cloud_build.py new file mode 100644 index 0000000..48eeb4c --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_cloud_build.py @@ -0,0 +1,131 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that displays interactions with Google Cloud Build. + +This DAG relies on the following OS environment variables: + +* GCP_PROJECT_ID - Google Cloud Project to use for the Cloud Function. +* GCP_CLOUD_BUILD_ARCHIVE_URL - Path to the zipped source in Google Cloud Storage. + This object must be a gzipped archive file (.tar.gz) containing source to build. +* GCP_CLOUD_BUILD_REPOSITORY_NAME - Name of the Cloud Source Repository. 
+ +""" + +import os +from pathlib import Path + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.cloud_build import ( + CloudBuildCreateBuildOperator, +) +from airflow.utils import dates +from future.backports.urllib.parse import urlparse + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") + +GCP_SOURCE_ARCHIVE_URL = os.environ.get( + "GCP_CLOUD_BUILD_ARCHIVE_URL", "gs://example-bucket/file" +) +GCP_SOURCE_REPOSITORY_NAME = os.environ.get( + "GCP_CLOUD_BUILD_REPOSITORY_NAME", "repository-name" +) + +GCP_SOURCE_ARCHIVE_URL_PARTS = urlparse(GCP_SOURCE_ARCHIVE_URL) +GCP_SOURCE_BUCKET_NAME = GCP_SOURCE_ARCHIVE_URL_PARTS.netloc + +CURRENT_FOLDER = Path(__file__).parent + +# [START howto_operator_gcp_create_build_from_storage_body] +create_build_from_storage_body = { + "source": {"storageSource": GCP_SOURCE_ARCHIVE_URL}, + "steps": [ + { + "name": "gcr.io/cloud-builders/docker", + "args": [ + "build", + "-t", + f"gcr.io/$PROJECT_ID/{GCP_SOURCE_BUCKET_NAME}", + ".", + ], + } + ], + "images": [f"gcr.io/$PROJECT_ID/{GCP_SOURCE_BUCKET_NAME}"], +} +# [END howto_operator_gcp_create_build_from_storage_body] + +# [START howto_operator_create_build_from_repo_body] +create_build_from_repo_body = { + "source": { + "repoSource": {"repoName": GCP_SOURCE_REPOSITORY_NAME, "branchName": "master"} + }, + "steps": [ + { + "name": "gcr.io/cloud-builders/docker", + "args": ["build", "-t", "gcr.io/$PROJECT_ID/$REPO_NAME", "."], + } + ], + "images": ["gcr.io/$PROJECT_ID/$REPO_NAME"], +} +# [END howto_operator_create_build_from_repo_body] + +with models.DAG( + "example_gcp_cloud_build", + default_args=dict(start_date=dates.days_ago(1)), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_operator_create_build_from_storage] + create_build_from_storage = CloudBuildCreateBuildOperator( + task_id="create_build_from_storage", + project_id=GCP_PROJECT_ID, + body=create_build_from_storage_body, + ) + # [END howto_operator_create_build_from_storage] + + # [START howto_operator_create_build_from_storage_result] + create_build_from_storage_result = BashOperator( + bash_command="echo '{{ task_instance.xcom_pull('create_build_from_storage')['images'][0] }}'", + task_id="create_build_from_storage_result", + ) + # [END howto_operator_create_build_from_storage_result] + + create_build_from_repo = CloudBuildCreateBuildOperator( + task_id="create_build_from_repo", + project_id=GCP_PROJECT_ID, + body=create_build_from_repo_body, + ) + + create_build_from_repo_result = BashOperator( + bash_command="echo '{{ task_instance.xcom_pull('create_build_from_repo')['images'][0] }}'", + task_id="create_build_from_repo_result", + ) + + # [START howto_operator_gcp_create_build_from_yaml_body] + create_build_from_file = CloudBuildCreateBuildOperator( + task_id="create_build_from_file", + project_id=GCP_PROJECT_ID, + body=str(CURRENT_FOLDER.joinpath("example_cloud_build.yaml")), + params={"name": "Airflow"}, + ) + # [END howto_operator_gcp_create_build_from_yaml_body] + create_build_from_storage >> create_build_from_storage_result # pylint: disable=pointless-statement + + create_build_from_repo >> create_build_from_repo_result # pylint: disable=pointless-statement diff --git a/reference/providers/google/cloud/example_dags/example_cloud_build.yaml b/reference/providers/google/cloud/example_dags/example_cloud_build.yaml new file mode 100644 index 0000000..7d62a7e --- /dev/null +++ 
b/reference/providers/google/cloud/example_dags/example_cloud_build.yaml @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +--- +steps: + - name: 'ubuntu' + args: ['echo', 'Hello {{ params.name}}'] diff --git a/reference/providers/google/cloud/example_dags/example_cloud_memorystore.py b/reference/providers/google/cloud/example_dags/example_cloud_memorystore.py new file mode 100644 index 0000000..70d58c0 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_cloud_memorystore.py @@ -0,0 +1,334 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google Cloud Memorystore service. 
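One DAG exercises the Redis operators (create, get, list, update, export,
import, failover, scale and delete); a second DAG exercises the Memcached
operators (create, get, list, update, apply parameters and delete).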
+""" +import os + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.cloud_memorystore import ( + CloudMemorystoreCreateInstanceAndImportOperator, + CloudMemorystoreCreateInstanceOperator, + CloudMemorystoreDeleteInstanceOperator, + CloudMemorystoreExportAndDeleteInstanceOperator, + CloudMemorystoreExportInstanceOperator, + CloudMemorystoreFailoverInstanceOperator, + CloudMemorystoreGetInstanceOperator, + CloudMemorystoreImportOperator, + CloudMemorystoreListInstancesOperator, + CloudMemorystoreMemcachedApplyParametersOperator, + CloudMemorystoreMemcachedCreateInstanceOperator, + CloudMemorystoreMemcachedDeleteInstanceOperator, + CloudMemorystoreMemcachedGetInstanceOperator, + CloudMemorystoreMemcachedListInstancesOperator, + CloudMemorystoreMemcachedUpdateInstanceOperator, + CloudMemorystoreMemcachedUpdateParametersOperator, + CloudMemorystoreScaleInstanceOperator, + CloudMemorystoreUpdateInstanceOperator, +) +from airflow.providers.google.cloud.operators.gcs import GCSBucketCreateAclEntryOperator +from airflow.utils import dates +from google.cloud.memcache_v1beta2.types import cloud_memcache +from google.cloud.redis_v1 import FailoverInstanceRequest, Instance + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") + +MEMORYSTORE_REDIS_INSTANCE_NAME = os.environ.get( + "GCP_MEMORYSTORE_REDIS_INSTANCE_NAME", "test-memorystore-redis" +) +MEMORYSTORE_REDIS_INSTANCE_NAME_2 = os.environ.get( + "GCP_MEMORYSTORE_REDIS_INSTANCE_NAME_2", "test-memorystore-redis-2" +) +MEMORYSTORE_REDIS_INSTANCE_NAME_3 = os.environ.get( + "GCP_MEMORYSTORE_REDIS_INSTANCE_NAME_3", "test-memorystore-redis-3" +) +MEMORYSTORE_MEMCACHED_INSTANCE_NAME = os.environ.get( + "GCP_MEMORYSTORE_MEMCACHED_INSTANCE_NAME", "test-memorystore-memcached-1" +) + +BUCKET_NAME = os.environ.get("GCP_MEMORYSTORE_BUCKET", "test-memorystore-bucket") +EXPORT_GCS_URL = f"gs://{BUCKET_NAME}/my-export.rdb" + +# [START howto_operator_instance] +FIRST_INSTANCE = {"tier": Instance.Tier.BASIC, "memory_size_gb": 1} +# [END howto_operator_instance] + +SECOND_INSTANCE = {"tier": Instance.Tier.STANDARD_HA, "memory_size_gb": 3} + +# [START howto_operator_memcached_instance] +MEMCACHED_INSTANCE = { + "name": "", + "node_count": 1, + "node_config": {"cpu_count": 1, "memory_size_mb": 1024}, +} +# [END howto_operator_memcached_instance] + + +with models.DAG( + "gcp_cloud_memorystore_redis", + schedule_interval=None, # Override to match your needs + start_date=dates.days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_create_instance] + create_instance = CloudMemorystoreCreateInstanceOperator( + task_id="create-instance", + location="europe-north1", + instance_id=MEMORYSTORE_REDIS_INSTANCE_NAME, + instance=FIRST_INSTANCE, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_create_instance] + + # [START howto_operator_create_instance_result] + create_instance_result = BashOperator( + task_id="create-instance-result", + bash_command="echo \"{{ task_instance.xcom_pull('create-instance') }}\"", + ) + # [END howto_operator_create_instance_result] + + create_instance_2 = CloudMemorystoreCreateInstanceOperator( + task_id="create-instance-2", + location="europe-north1", + instance_id=MEMORYSTORE_REDIS_INSTANCE_NAME_2, + instance=SECOND_INSTANCE, + project_id=GCP_PROJECT_ID, + ) + + # [START howto_operator_get_instance] + get_instance = CloudMemorystoreGetInstanceOperator( + task_id="get-instance", + location="europe-north1", + 
instance=MEMORYSTORE_REDIS_INSTANCE_NAME, + project_id=GCP_PROJECT_ID, + do_xcom_push=True, + ) + # [END howto_operator_get_instance] + + # [START howto_operator_get_instance_result] + get_instance_result = BashOperator( + task_id="get-instance-result", + bash_command="echo \"{{ task_instance.xcom_pull('get-instance') }}\"", + ) + # [END howto_operator_get_instance_result] + + # [START howto_operator_failover_instance] + failover_instance = CloudMemorystoreFailoverInstanceOperator( + task_id="failover-instance", + location="europe-north1", + instance=MEMORYSTORE_REDIS_INSTANCE_NAME_2, + data_protection_mode=FailoverInstanceRequest.DataProtectionMode.LIMITED_DATA_LOSS, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_failover_instance] + + # [START howto_operator_list_instances] + list_instances = CloudMemorystoreListInstancesOperator( + task_id="list-instances", location="-", page_size=100, project_id=GCP_PROJECT_ID + ) + # [END howto_operator_list_instances] + + # [START howto_operator_list_instances_result] + list_instances_result = BashOperator( + task_id="list-instances-result", + bash_command="echo \"{{ task_instance.xcom_pull('get-instance') }}\"", + ) + # [END howto_operator_list_instances_result] + + # [START howto_operator_update_instance] + update_instance = CloudMemorystoreUpdateInstanceOperator( + task_id="update-instance", + location="europe-north1", + instance_id=MEMORYSTORE_REDIS_INSTANCE_NAME, + project_id=GCP_PROJECT_ID, + update_mask={"paths": ["memory_size_gb"]}, + instance={"memory_size_gb": 2}, + ) + # [END howto_operator_update_instance] + + # [START howto_operator_set_acl_permission] + set_acl_permission = GCSBucketCreateAclEntryOperator( + task_id="gcs-set-acl-permission", + bucket=BUCKET_NAME, + entity="user-{{ task_instance.xcom_pull('get-instance')['persistence_iam_identity']" + ".split(':', 2)[1] }}", + role="OWNER", + ) + # [END howto_operator_set_acl_permission] + + # [START howto_operator_export_instance] + export_instance = CloudMemorystoreExportInstanceOperator( + task_id="export-instance", + location="europe-north1", + instance=MEMORYSTORE_REDIS_INSTANCE_NAME, + output_config={"gcs_destination": {"uri": EXPORT_GCS_URL}}, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_export_instance] + + # [START howto_operator_import_instance] + import_instance = CloudMemorystoreImportOperator( + task_id="import-instance", + location="europe-north1", + instance=MEMORYSTORE_REDIS_INSTANCE_NAME_2, + input_config={"gcs_source": {"uri": EXPORT_GCS_URL}}, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_import_instance] + + # [START howto_operator_delete_instance] + delete_instance = CloudMemorystoreDeleteInstanceOperator( + task_id="delete-instance", + location="europe-north1", + instance=MEMORYSTORE_REDIS_INSTANCE_NAME, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_delete_instance] + + delete_instance_2 = CloudMemorystoreDeleteInstanceOperator( + task_id="delete-instance-2", + location="europe-north1", + instance=MEMORYSTORE_REDIS_INSTANCE_NAME_2, + project_id=GCP_PROJECT_ID, + ) + + # [END howto_operator_create_instance_and_import] + create_instance_and_import = CloudMemorystoreCreateInstanceAndImportOperator( + task_id="create-instance-and-import", + location="europe-north1", + instance_id=MEMORYSTORE_REDIS_INSTANCE_NAME_3, + instance=FIRST_INSTANCE, + input_config={"gcs_source": {"uri": EXPORT_GCS_URL}}, + project_id=GCP_PROJECT_ID, + ) + # [START howto_operator_create_instance_and_import] + + # [START howto_operator_scale_instance] + 
scale_instance = CloudMemorystoreScaleInstanceOperator( + task_id="scale-instance", + location="europe-north1", + instance_id=MEMORYSTORE_REDIS_INSTANCE_NAME_3, + project_id=GCP_PROJECT_ID, + memory_size_gb=3, + ) + # [END howto_operator_scale_instance] + + # [END howto_operator_export_and_delete_instance] + export_and_delete_instance = CloudMemorystoreExportAndDeleteInstanceOperator( + task_id="export-and-delete-instance", + location="europe-north1", + instance=MEMORYSTORE_REDIS_INSTANCE_NAME_3, + output_config={"gcs_destination": {"uri": EXPORT_GCS_URL}}, + project_id=GCP_PROJECT_ID, + ) + # [START howto_operator_export_and_delete_instance] + + create_instance >> get_instance >> get_instance_result + create_instance >> update_instance + create_instance >> create_instance_result + create_instance >> export_instance + create_instance_2 >> import_instance + create_instance >> list_instances >> list_instances_result + list_instances >> delete_instance + export_instance >> update_instance + update_instance >> delete_instance + get_instance >> set_acl_permission >> export_instance + export_instance >> import_instance + export_instance >> delete_instance + failover_instance >> delete_instance_2 + import_instance >> failover_instance + + export_instance >> create_instance_and_import >> scale_instance >> export_and_delete_instance + +with models.DAG( + "gcp_cloud_memorystore_memcached", + schedule_interval=None, # Override to match your needs + start_date=dates.days_ago(1), + tags=["example"], +) as dag_memcache: + # [START howto_operator_create_instance_memcached] + create_memcached_instance = CloudMemorystoreMemcachedCreateInstanceOperator( + task_id="create-instance", + location="europe-north1", + instance_id=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, + instance=MEMCACHED_INSTANCE, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_create_instance_memcached] + + # [START howto_operator_delete_instance_memcached] + delete_memcached_instance = CloudMemorystoreMemcachedDeleteInstanceOperator( + task_id="delete-instance", + location="europe-north1", + instance=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_delete_instance_memcached] + + # [START howto_operator_get_instance_memcached] + get_memcached_instance = CloudMemorystoreMemcachedGetInstanceOperator( + task_id="get-instance", + location="europe-north1", + instance=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_get_instance_memcached] + + # [START howto_operator_list_instances_memcached] + list_memcached_instances = CloudMemorystoreMemcachedListInstancesOperator( + task_id="list-instances", location="-", project_id=GCP_PROJECT_ID + ) + # [END howto_operator_list_instances_memcached] + + # # [START howto_operator_update_instance_memcached] + update_memcached_instance = CloudMemorystoreMemcachedUpdateInstanceOperator( + task_id="update-instance", + location="europe-north1", + instance_id=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, + project_id=GCP_PROJECT_ID, + update_mask=cloud_memcache.field_mask.FieldMask(paths=["node_count"]), + instance={"node_count": 2}, + ) + # [END howto_operator_update_instance_memcached] + + # [START howto_operator_update_and_apply_parameters_memcached] + update_memcached_parameters = CloudMemorystoreMemcachedUpdateParametersOperator( + task_id="update-parameters", + location="europe-north1", + instance_id=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, + project_id=GCP_PROJECT_ID, + update_mask={"paths": ["params"]}, + parameters={"params": 
{"protocol": "ascii", "hash_algorithm": "jenkins"}}, + ) + + apply_memcached_parameters = CloudMemorystoreMemcachedApplyParametersOperator( + task_id="apply-parameters", + location="europe-north1", + instance_id=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, + project_id=GCP_PROJECT_ID, + node_ids=["node-a-1"], + apply_all=False, + ) + + # update_parameters >> apply_parameters + # [END howto_operator_update_and_apply_parameters_memcached] + + create_memcached_instance >> [list_memcached_instances, get_memcached_instance] + create_memcached_instance >> update_memcached_instance >> update_memcached_parameters + update_memcached_parameters >> apply_memcached_parameters >> delete_memcached_instance diff --git a/reference/providers/google/cloud/example_dags/example_cloud_sql.py b/reference/providers/google/cloud/example_dags/example_cloud_sql.py new file mode 100644 index 0000000..0108bd4 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_cloud_sql.py @@ -0,0 +1,375 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that creates, patches and deletes a Cloud SQL instance, and also +creates, patches and deletes a database inside the instance, in Google Cloud. + +This DAG relies on the following OS environment variables +https://airflow.apache.org/concepts.html#variables +* GCP_PROJECT_ID - Google Cloud project for the Cloud SQL instance. +* INSTANCE_NAME - Name of the Cloud SQL instance. +* DB_NAME - Name of the database inside a Cloud SQL instance. 
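+* INSTANCE_NAME2 - Name of the second Cloud SQL instance (the import target).
+* EXPORT_URI - Google Cloud Storage URI that the first instance exports to.
+* IMPORT_URI - Google Cloud Storage URI that the second instance imports from.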
+""" + +import os +from urllib.parse import urlsplit + +from airflow import models +from airflow.providers.google.cloud.operators.cloud_sql import ( + CloudSQLCreateInstanceDatabaseOperator, + CloudSQLCreateInstanceOperator, + CloudSQLDeleteInstanceDatabaseOperator, + CloudSQLDeleteInstanceOperator, + CloudSQLExportInstanceOperator, + CloudSQLImportInstanceOperator, + CloudSQLInstancePatchOperator, + CloudSQLPatchInstanceDatabaseOperator, +) +from airflow.providers.google.cloud.operators.gcs import ( + GCSBucketCreateAclEntryOperator, + GCSObjectCreateAclEntryOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +INSTANCE_NAME = os.environ.get("GCSQL_MYSQL_INSTANCE_NAME", "test-mysql") +INSTANCE_NAME2 = os.environ.get("GCSQL_MYSQL_INSTANCE_NAME2", "test-mysql2") +DB_NAME = os.environ.get("GCSQL_MYSQL_DATABASE_NAME", "testdb") + +EXPORT_URI = os.environ.get("GCSQL_MYSQL_EXPORT_URI", "gs://bucketName/fileName") +IMPORT_URI = os.environ.get("GCSQL_MYSQL_IMPORT_URI", "gs://bucketName/fileName") + +# Bodies below represent Cloud SQL instance resources: +# https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances + +FAILOVER_REPLICA_NAME = INSTANCE_NAME + "-failover-replica" +READ_REPLICA_NAME = INSTANCE_NAME + "-read-replica" + +# [START howto_operator_cloudsql_create_body] +body = { + "name": INSTANCE_NAME, + "settings": { + "tier": "db-n1-standard-1", + "backupConfiguration": { + "binaryLogEnabled": True, + "enabled": True, + "startTime": "05:00", + }, + "activationPolicy": "ALWAYS", + "dataDiskSizeGb": 30, + "dataDiskType": "PD_SSD", + "databaseFlags": [], + "ipConfiguration": { + "ipv4Enabled": True, + "requireSsl": True, + }, + "locationPreference": {"zone": "europe-west4-a"}, + "maintenanceWindow": {"hour": 5, "day": 7, "updateTrack": "canary"}, + "pricingPlan": "PER_USE", + "replicationType": "ASYNCHRONOUS", + "storageAutoResize": True, + "storageAutoResizeLimit": 0, + "userLabels": {"my-key": "my-value"}, + }, + "failoverReplica": {"name": FAILOVER_REPLICA_NAME}, + "databaseVersion": "MYSQL_5_7", + "region": "europe-west4", +} +# [END howto_operator_cloudsql_create_body] + +body2 = { + "name": INSTANCE_NAME2, + "settings": { + "tier": "db-n1-standard-1", + }, + "databaseVersion": "MYSQL_5_7", + "region": "europe-west4", +} + +# [START howto_operator_cloudsql_create_replica] +read_replica_body = { + "name": READ_REPLICA_NAME, + "settings": { + "tier": "db-n1-standard-1", + }, + "databaseVersion": "MYSQL_5_7", + "region": "europe-west4", + "masterInstanceName": INSTANCE_NAME, +} +# [END howto_operator_cloudsql_create_replica] + + +# [START howto_operator_cloudsql_patch_body] +patch_body = { + "name": INSTANCE_NAME, + "settings": { + "dataDiskSizeGb": 35, + "maintenanceWindow": {"hour": 3, "day": 6, "updateTrack": "canary"}, + "userLabels": {"my-key-patch": "my-value-patch"}, + }, +} +# [END howto_operator_cloudsql_patch_body] +# [START howto_operator_cloudsql_export_body] +export_body = { + "exportContext": { + "fileType": "sql", + "uri": EXPORT_URI, + "sqlExportOptions": {"schemaOnly": False}, + } +} +# [END howto_operator_cloudsql_export_body] +# [START howto_operator_cloudsql_import_body] +import_body = {"importContext": {"fileType": "sql", "uri": IMPORT_URI}} +# [END howto_operator_cloudsql_import_body] +# [START howto_operator_cloudsql_db_create_body] +db_create_body = {"instance": INSTANCE_NAME, "name": DB_NAME, "project": GCP_PROJECT_ID} +# [END howto_operator_cloudsql_db_create_body] +# [START 
howto_operator_cloudsql_db_patch_body] +db_patch_body = {"charset": "utf16", "collation": "utf16_general_ci"} +# [END howto_operator_cloudsql_db_patch_body] + +with models.DAG( + "example_gcp_sql", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # ############################################## # + # ### INSTANCES SET UP ######################### # + # ############################################## # + + # [START howto_operator_cloudsql_create] + sql_instance_create_task = CloudSQLCreateInstanceOperator( + project_id=GCP_PROJECT_ID, + body=body, + instance=INSTANCE_NAME, + task_id="sql_instance_create_task", + ) + # [END howto_operator_cloudsql_create] + + sql_instance_create_2_task = CloudSQLCreateInstanceOperator( + project_id=GCP_PROJECT_ID, + body=body2, + instance=INSTANCE_NAME2, + task_id="sql_instance_create_task2", + ) + # [END howto_operator_cloudsql_create] + + sql_instance_read_replica_create = CloudSQLCreateInstanceOperator( + project_id=GCP_PROJECT_ID, + body=read_replica_body, + instance=READ_REPLICA_NAME, + task_id="sql_instance_read_replica_create", + ) + + # ############################################## # + # ### MODIFYING INSTANCE AND ITS DATABASE ###### # + # ############################################## # + + # [START howto_operator_cloudsql_patch] + sql_instance_patch_task = CloudSQLInstancePatchOperator( + project_id=GCP_PROJECT_ID, + body=patch_body, + instance=INSTANCE_NAME, + task_id="sql_instance_patch_task", + ) + # [END howto_operator_cloudsql_patch] + + sql_instance_patch_task2 = CloudSQLInstancePatchOperator( + project_id=GCP_PROJECT_ID, + body=patch_body, + instance=INSTANCE_NAME, + task_id="sql_instance_patch_task2", + ) + + # [START howto_operator_cloudsql_db_create] + sql_db_create_task = CloudSQLCreateInstanceDatabaseOperator( + project_id=GCP_PROJECT_ID, + body=db_create_body, + instance=INSTANCE_NAME, + task_id="sql_db_create_task", + ) + sql_db_create_task2 = CloudSQLCreateInstanceDatabaseOperator( + body=db_create_body, instance=INSTANCE_NAME, task_id="sql_db_create_task2" + ) + # [END howto_operator_cloudsql_db_create] + + # [START howto_operator_cloudsql_db_patch] + sql_db_patch_task = CloudSQLPatchInstanceDatabaseOperator( + project_id=GCP_PROJECT_ID, + body=db_patch_body, + instance=INSTANCE_NAME, + database=DB_NAME, + task_id="sql_db_patch_task", + ) + sql_db_patch_task2 = CloudSQLPatchInstanceDatabaseOperator( + body=db_patch_body, + instance=INSTANCE_NAME, + database=DB_NAME, + task_id="sql_db_patch_task2", + ) + # [END howto_operator_cloudsql_db_patch] + + # ############################################## # + # ### EXPORTING SQL FROM INSTANCE 1 ############ # + # ############################################## # + export_url_split = urlsplit(EXPORT_URI) + + # For export to work we need to add the Cloud SQL instance's Service Account + # write access to the destination GCS bucket. 
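+    # Editor's illustration (comment added, not part of the upstream example): export_url_split
+    # is urlsplit(EXPORT_URI); for "gs://bucketName/fileName" that yields
+    #     export_url_split[1] == "bucketName"   # netloc, used as bucket=
+    #     export_url_split[2] == "/fileName"    # path; [1:] strips the leading "/" (see the import section)
+    # The "user-{{ ... }}" entity is the instance's service account e-mail pulled from the
+    # create task's XCom under the key "service_account_email", prefixed with "user-" as the
+    # GCS ACL entity format requires.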
+ # [START howto_operator_cloudsql_export_gcs_permissions] + sql_gcp_add_bucket_permission_task = GCSBucketCreateAclEntryOperator( + entity="user-{{ task_instance.xcom_pull(" + "'sql_instance_create_task', key='service_account_email') " + "}}", + role="WRITER", + bucket=export_url_split[1], # netloc (bucket) + task_id="sql_gcp_add_bucket_permission_task", + ) + # [END howto_operator_cloudsql_export_gcs_permissions] + + # [START howto_operator_cloudsql_export] + sql_export_task = CloudSQLExportInstanceOperator( + project_id=GCP_PROJECT_ID, + body=export_body, + instance=INSTANCE_NAME, + task_id="sql_export_task", + ) + sql_export_task2 = CloudSQLExportInstanceOperator( + body=export_body, instance=INSTANCE_NAME, task_id="sql_export_task2" + ) + # [END howto_operator_cloudsql_export] + + # ############################################## # + # ### IMPORTING SQL TO INSTANCE 2 ############## # + # ############################################## # + import_url_split = urlsplit(IMPORT_URI) + + # For import to work we need to add the Cloud SQL instance's Service Account + # read access to the target GCS object. + # [START howto_operator_cloudsql_import_gcs_permissions] + sql_gcp_add_object_permission_task = GCSObjectCreateAclEntryOperator( + entity="user-{{ task_instance.xcom_pull(" + "'sql_instance_create_task2', key='service_account_email')" + " }}", + role="READER", + bucket=import_url_split[1], # netloc (bucket) + object_name=import_url_split[2][1:], # path (strip first '/') + task_id="sql_gcp_add_object_permission_task", + ) + + # For import to work we also need to add the Cloud SQL instance's Service Account + # write access to the whole bucket!. + sql_gcp_add_bucket_permission_2_task = GCSBucketCreateAclEntryOperator( + entity="user-{{ task_instance.xcom_pull(" + "'sql_instance_create_task2', key='service_account_email') " + "}}", + role="WRITER", + bucket=import_url_split[1], # netloc + task_id="sql_gcp_add_bucket_permission_2_task", + ) + # [END howto_operator_cloudsql_import_gcs_permissions] + + # [START howto_operator_cloudsql_import] + sql_import_task = CloudSQLImportInstanceOperator( + project_id=GCP_PROJECT_ID, + body=import_body, + instance=INSTANCE_NAME2, + task_id="sql_import_task", + ) + sql_import_task2 = CloudSQLImportInstanceOperator( + body=import_body, instance=INSTANCE_NAME2, task_id="sql_import_task2" + ) + # [END howto_operator_cloudsql_import] + + # ############################################## # + # ### DELETING A DATABASE FROM AN INSTANCE ##### # + # ############################################## # + + # [START howto_operator_cloudsql_db_delete] + sql_db_delete_task = CloudSQLDeleteInstanceDatabaseOperator( + project_id=GCP_PROJECT_ID, + instance=INSTANCE_NAME, + database=DB_NAME, + task_id="sql_db_delete_task", + ) + sql_db_delete_task2 = CloudSQLDeleteInstanceDatabaseOperator( + instance=INSTANCE_NAME, database=DB_NAME, task_id="sql_db_delete_task2" + ) + # [END howto_operator_cloudsql_db_delete] + + # ############################################## # + # ### INSTANCES TEAR DOWN ###################### # + # ############################################## # + + # [START howto_operator_cloudsql_replicas_delete] + sql_instance_failover_replica_delete_task = CloudSQLDeleteInstanceOperator( + project_id=GCP_PROJECT_ID, + instance=FAILOVER_REPLICA_NAME, + task_id="sql_instance_failover_replica_delete_task", + ) + + sql_instance_read_replica_delete_task = CloudSQLDeleteInstanceOperator( + project_id=GCP_PROJECT_ID, + instance=READ_REPLICA_NAME, + 
task_id="sql_instance_read_replica_delete_task", + ) + # [END howto_operator_cloudsql_replicas_delete] + + # [START howto_operator_cloudsql_delete] + sql_instance_delete_task = CloudSQLDeleteInstanceOperator( + project_id=GCP_PROJECT_ID, + instance=INSTANCE_NAME, + task_id="sql_instance_delete_task", + ) + sql_instance_delete_task2 = CloudSQLDeleteInstanceOperator( + instance=INSTANCE_NAME2, task_id="sql_instance_delete_task2" + ) + # [END howto_operator_cloudsql_delete] + + sql_instance_delete_2_task = CloudSQLDeleteInstanceOperator( + project_id=GCP_PROJECT_ID, + instance=INSTANCE_NAME2, + task_id="sql_instance_delete_2_task", + ) + + ( + sql_instance_create_task # noqa + >> sql_instance_create_2_task # noqa + >> sql_instance_read_replica_create # noqa + >> sql_instance_patch_task # noqa + >> sql_instance_patch_task2 # noqa + >> sql_db_create_task # noqa + >> sql_db_create_task2 # noqa + >> sql_db_patch_task # noqa + >> sql_db_patch_task2 # noqa + >> sql_gcp_add_bucket_permission_task # noqa + >> sql_export_task # noqa + >> sql_export_task2 # noqa + >> sql_gcp_add_object_permission_task # noqa + >> sql_gcp_add_bucket_permission_2_task # noqa + >> sql_import_task # noqa + >> sql_import_task2 # noqa + >> sql_db_delete_task # noqa + >> sql_db_delete_task2 # noqa + >> sql_instance_failover_replica_delete_task # noqa + >> sql_instance_read_replica_delete_task # noqa + >> sql_instance_delete_task # noqa + >> sql_instance_delete_2_task # noqa + ) diff --git a/reference/providers/google/cloud/example_dags/example_cloud_sql_query.py b/reference/providers/google/cloud/example_dags/example_cloud_sql_query.py new file mode 100644 index 0000000..664ac0c --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_cloud_sql_query.py @@ -0,0 +1,308 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that performs query in a Cloud SQL instance. 
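+
+The same list of SQL statements is executed through a set of gcpcloudsql:// connections that
+cover the combinations of Postgres/MySQL, proxy vs. direct connection, TCP vs. UNIX socket,
+and SSL vs. non-SSL, all defined below via AIRFLOW_CONN_* environment variables.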
+ +This DAG relies on the following OS environment variables + +* GCP_PROJECT_ID - Google Cloud project for the Cloud SQL instance +* GCP_REGION - Google Cloud region where the database is created +* +* GCSQL_POSTGRES_INSTANCE_NAME - Name of the postgres Cloud SQL instance +* GCSQL_POSTGRES_USER - Name of the postgres database user +* GCSQL_POSTGRES_PASSWORD - Password of the postgres database user +* GCSQL_POSTGRES_PUBLIC_IP - Public IP of the Postgres database +* GCSQL_POSTGRES_PUBLIC_PORT - Port of the postgres database +* +* GCSQL_MYSQL_INSTANCE_NAME - Name of the postgres Cloud SQL instance +* GCSQL_MYSQL_USER - Name of the mysql database user +* GCSQL_MYSQL_PASSWORD - Password of the mysql database user +* GCSQL_MYSQL_PUBLIC_IP - Public IP of the mysql database +* GCSQL_MYSQL_PUBLIC_PORT - Port of the mysql database +""" +import os +import subprocess +from os.path import expanduser +from urllib.parse import quote_plus + +from airflow import models +from airflow.providers.google.cloud.operators.cloud_sql import ( + CloudSQLExecuteQueryOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCP_REGION = os.environ.get("GCP_REGION", "europe-west-1b") + +GCSQL_POSTGRES_INSTANCE_NAME_QUERY = os.environ.get( + "GCSQL_POSTGRES_INSTANCE_NAME_QUERY", "testpostgres" +) +GCSQL_POSTGRES_DATABASE_NAME = os.environ.get( + "GCSQL_POSTGRES_DATABASE_NAME", "postgresdb" +) +GCSQL_POSTGRES_USER = os.environ.get("GCSQL_POSTGRES_USER", "postgres_user") +GCSQL_POSTGRES_PASSWORD = os.environ.get("GCSQL_POSTGRES_PASSWORD", "password") +GCSQL_POSTGRES_PUBLIC_IP = os.environ.get("GCSQL_POSTGRES_PUBLIC_IP", "0.0.0.0") +GCSQL_POSTGRES_PUBLIC_PORT = os.environ.get("GCSQL_POSTGRES_PUBLIC_PORT", 5432) +GCSQL_POSTGRES_CLIENT_CERT_FILE = os.environ.get( + "GCSQL_POSTGRES_CLIENT_CERT_FILE", ".key/postgres-client-cert.pem" +) +GCSQL_POSTGRES_CLIENT_KEY_FILE = os.environ.get( + "GCSQL_POSTGRES_CLIENT_KEY_FILE", ".key/postgres-client-key.pem" +) +GCSQL_POSTGRES_SERVER_CA_FILE = os.environ.get( + "GCSQL_POSTGRES_SERVER_CA_FILE", ".key/postgres-server-ca.pem" +) + +GCSQL_MYSQL_INSTANCE_NAME_QUERY = os.environ.get( + "GCSQL_MYSQL_INSTANCE_NAME_QUERY", "testmysql" +) +GCSQL_MYSQL_DATABASE_NAME = os.environ.get("GCSQL_MYSQL_DATABASE_NAME", "mysqldb") +GCSQL_MYSQL_USER = os.environ.get("GCSQL_MYSQL_USER", "mysql_user") +GCSQL_MYSQL_PASSWORD = os.environ.get("GCSQL_MYSQL_PASSWORD", "password") +GCSQL_MYSQL_PUBLIC_IP = os.environ.get("GCSQL_MYSQL_PUBLIC_IP", "0.0.0.0") +GCSQL_MYSQL_PUBLIC_PORT = os.environ.get("GCSQL_MYSQL_PUBLIC_PORT", 3306) +GCSQL_MYSQL_CLIENT_CERT_FILE = os.environ.get( + "GCSQL_MYSQL_CLIENT_CERT_FILE", ".key/mysql-client-cert.pem" +) +GCSQL_MYSQL_CLIENT_KEY_FILE = os.environ.get( + "GCSQL_MYSQL_CLIENT_KEY_FILE", ".key/mysql-client-key.pem" +) +GCSQL_MYSQL_SERVER_CA_FILE = os.environ.get( + "GCSQL_MYSQL_SERVER_CA_FILE", ".key/mysql-server-ca.pem" +) + +SQL = [ + "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)", + "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)", # shows warnings logged + "INSERT INTO TABLE_TEST VALUES (0)", + "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)", + "DROP TABLE TABLE_TEST", + "DROP TABLE TABLE_TEST2", +] + + +# [START howto_operator_cloudsql_query_connections] + +HOME_DIR = expanduser("~") + + +def get_absolute_path(path): + """ + Returns absolute path. 
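+    Relative paths are resolved against the user's home directory (HOME_DIR).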
+ """ + if path.startswith("/"): + return path + else: + return os.path.join(HOME_DIR, path) + + +postgres_kwargs = dict( + user=quote_plus(GCSQL_POSTGRES_USER), + password=quote_plus(GCSQL_POSTGRES_PASSWORD), + public_port=GCSQL_POSTGRES_PUBLIC_PORT, + public_ip=quote_plus(GCSQL_POSTGRES_PUBLIC_IP), + project_id=quote_plus(GCP_PROJECT_ID), + location=quote_plus(GCP_REGION), + instance=quote_plus(GCSQL_POSTGRES_INSTANCE_NAME_QUERY), + database=quote_plus(GCSQL_POSTGRES_DATABASE_NAME), + client_cert_file=quote_plus(get_absolute_path(GCSQL_POSTGRES_CLIENT_CERT_FILE)), + client_key_file=quote_plus(get_absolute_path(GCSQL_POSTGRES_CLIENT_KEY_FILE)), + server_ca_file=quote_plus(get_absolute_path(GCSQL_POSTGRES_SERVER_CA_FILE)), +) + +# The connections below are created using one of the standard approaches - via environment +# variables named AIRFLOW_CONN_* . The connections can also be created in the database +# of AIRFLOW (using command line or UI). + +# Postgres: connect via proxy over TCP +os.environ["AIRFLOW_CONN_PROXY_POSTGRES_TCP"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=postgres&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=True&" + "sql_proxy_use_tcp=True".format(**postgres_kwargs) +) + +# Postgres: connect via proxy over UNIX socket (specific proxy version) +os.environ["AIRFLOW_CONN_PROXY_POSTGRES_SOCKET"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=postgres&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=True&" + "sql_proxy_version=v1.13&" + "sql_proxy_use_tcp=False".format(**postgres_kwargs) +) + +# Postgres: connect directly via TCP (non-SSL) +os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=postgres&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" + "use_ssl=False".format(**postgres_kwargs) +) + +# Postgres: connect directly via TCP (SSL) +os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=postgres&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" + "use_ssl=True&" + "sslcert={client_cert_file}&" + "sslkey={client_key_file}&" + "sslrootcert={server_ca_file}".format(**postgres_kwargs) +) + +mysql_kwargs = dict( + user=quote_plus(GCSQL_MYSQL_USER), + password=quote_plus(GCSQL_MYSQL_PASSWORD), + public_port=GCSQL_MYSQL_PUBLIC_PORT, + public_ip=quote_plus(GCSQL_MYSQL_PUBLIC_IP), + project_id=quote_plus(GCP_PROJECT_ID), + location=quote_plus(GCP_REGION), + instance=quote_plus(GCSQL_MYSQL_INSTANCE_NAME_QUERY), + database=quote_plus(GCSQL_MYSQL_DATABASE_NAME), + client_cert_file=quote_plus(get_absolute_path(GCSQL_MYSQL_CLIENT_CERT_FILE)), + client_key_file=quote_plus(get_absolute_path(GCSQL_MYSQL_CLIENT_KEY_FILE)), + server_ca_file=quote_plus(get_absolute_path(GCSQL_MYSQL_SERVER_CA_FILE)), +) + +# MySQL: connect via proxy over TCP (specific proxy version) +os.environ["AIRFLOW_CONN_PROXY_MYSQL_TCP"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" 
+ "database_type=mysql&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=True&" + "sql_proxy_version=v1.13&" + "sql_proxy_use_tcp=True".format(**mysql_kwargs) +) + +# MySQL: connect via proxy over UNIX socket using pre-downloaded Cloud Sql Proxy binary +try: + sql_proxy_binary_path = ( + subprocess.check_output(["which", "cloud_sql_proxy"]).decode("utf-8").rstrip() + ) +except subprocess.CalledProcessError: + sql_proxy_binary_path = "/tmp/anyhow_download_cloud_sql_proxy" + +os.environ["AIRFLOW_CONN_PROXY_MYSQL_SOCKET"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=mysql&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=True&" + "sql_proxy_binary_path={sql_proxy_binary_path}&" + "sql_proxy_use_tcp=False".format( + sql_proxy_binary_path=quote_plus(sql_proxy_binary_path), **mysql_kwargs + ) +) + +# MySQL: connect directly via TCP (non-SSL) +os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=mysql&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" + "use_ssl=False".format(**mysql_kwargs) +) + +# MySQL: connect directly via TCP (SSL) and with fixed Cloud Sql Proxy binary path +os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=mysql&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" + "use_ssl=True&" + "sslcert={client_cert_file}&" + "sslkey={client_key_file}&" + "sslrootcert={server_ca_file}".format(**mysql_kwargs) +) + +# Special case: MySQL: connect directly via TCP (SSL) and with fixed Cloud Sql +# Proxy binary path AND with missing project_id + +os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL_NO_PROJECT_ID"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" 
+ "database_type=mysql&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" + "use_ssl=True&" + "sslcert={client_cert_file}&" + "sslkey={client_key_file}&" + "sslrootcert={server_ca_file}".format(**mysql_kwargs) +) + + +# [END howto_operator_cloudsql_query_connections] + +# [START howto_operator_cloudsql_query_operators] + +connection_names = [ + "proxy_postgres_tcp", + "proxy_postgres_socket", + "public_postgres_tcp", + "public_postgres_tcp_ssl", + "proxy_mysql_tcp", + "proxy_mysql_socket", + "public_mysql_tcp", + "public_mysql_tcp_ssl", + "public_mysql_tcp_ssl_no_project_id", +] + +tasks = [] + + +with models.DAG( + dag_id="example_gcp_sql_query", + schedule_interval=None, + start_date=days_ago(1), + tags=["example"], +) as dag: + prev_task = None + + for connection_name in connection_names: + task = CloudSQLExecuteQueryOperator( + gcp_cloudsql_conn_id=connection_name, + task_id="example_gcp_sql_task_" + connection_name, + sql=SQL, + ) + tasks.append(task) + if prev_task: + prev_task >> task + prev_task = task + +# [END howto_operator_cloudsql_query_operators] diff --git a/reference/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py b/reference/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py new file mode 100644 index 0000000..1f6421c --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py @@ -0,0 +1,202 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that demonstrates interactions with Google Cloud Transfer. + + +This DAG relies on the following OS environment variables + +* GCP_PROJECT_ID - Google Cloud Project to use for the Google Cloud Transfer Service. +* GCP_DESCRIPTION - Description of transfer job +* GCP_TRANSFER_SOURCE_AWS_BUCKET - Amazon Web Services Storage bucket from which files are copied. + .. warning:: + You need to provide a large enough set of data so that operations do not execute too quickly. + Otherwise, DAG will fail. 
+* GCP_TRANSFER_FIRST_TARGET_BUCKET - Google Cloud Storage bucket to which files are copied from AWS.
+* WAIT_FOR_OPERATION_POKE_INTERVAL - interval at which to check the status of the operation
+  A smaller value than the default accelerates the system test and ensures its correct execution with
+  smaller quantities of files in the source bucket
+  Look at documentation of :class:`~airflow.operators.sensors.BaseSensorOperator` for more information
+
+"""
+
+import os
+from datetime import datetime, timedelta
+
+from airflow import models
+from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
+    ALREADY_EXISTING_IN_SINK,
+    AWS_S3_DATA_SOURCE,
+    BUCKET_NAME,
+    DESCRIPTION,
+    FILTER_JOB_NAMES,
+    FILTER_PROJECT_ID,
+    GCS_DATA_SINK,
+    JOB_NAME,
+    PROJECT_ID,
+    SCHEDULE,
+    SCHEDULE_END_DATE,
+    SCHEDULE_START_DATE,
+    START_TIME_OF_DAY,
+    STATUS,
+    TRANSFER_OPTIONS,
+    TRANSFER_SPEC,
+    GcpTransferJobsStatus,
+    GcpTransferOperationStatus,
+)
+from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
+    CloudDataTransferServiceCancelOperationOperator,
+    CloudDataTransferServiceCreateJobOperator,
+    CloudDataTransferServiceDeleteJobOperator,
+    CloudDataTransferServiceGetOperationOperator,
+    CloudDataTransferServiceListOperationsOperator,
+    CloudDataTransferServicePauseOperationOperator,
+    CloudDataTransferServiceResumeOperationOperator,
+)
+from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import (
+    CloudDataTransferServiceJobStatusSensor,
+)
+from airflow.utils.dates import days_ago
+
+GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
+GCP_DESCRIPTION = os.environ.get("GCP_DESCRIPTION", "description")
+GCP_TRANSFER_TARGET_BUCKET = os.environ.get("GCP_TRANSFER_TARGET_BUCKET")
+WAIT_FOR_OPERATION_POKE_INTERVAL = int(
+    os.environ.get("WAIT_FOR_OPERATION_POKE_INTERVAL", 5)
+)
+
+GCP_TRANSFER_SOURCE_AWS_BUCKET = os.environ.get("GCP_TRANSFER_SOURCE_AWS_BUCKET")
+GCP_TRANSFER_FIRST_TARGET_BUCKET = os.environ.get(
+    "GCP_TRANSFER_FIRST_TARGET_BUCKET", "gcp-transfer-first-target"
+)
+
+GCP_TRANSFER_JOB_NAME = os.environ.get(
+    "GCP_TRANSFER_JOB_NAME", "transferJobs/sampleJob"
+)
+
+# [START howto_operator_gcp_transfer_create_job_body_aws]
+aws_to_gcs_transfer_body = {
+    DESCRIPTION: GCP_DESCRIPTION,
+    STATUS: GcpTransferJobsStatus.ENABLED,
+    PROJECT_ID: GCP_PROJECT_ID,
+    JOB_NAME: GCP_TRANSFER_JOB_NAME,
+    SCHEDULE: {
+        SCHEDULE_START_DATE: datetime(2015, 1, 1).date(),
+        SCHEDULE_END_DATE: datetime(2030, 1, 1).date(),
+        START_TIME_OF_DAY: (datetime.utcnow() + timedelta(minutes=2)).time(),
+    },
+    TRANSFER_SPEC: {
+        AWS_S3_DATA_SOURCE: {BUCKET_NAME: GCP_TRANSFER_SOURCE_AWS_BUCKET},
+        GCS_DATA_SINK: {BUCKET_NAME: GCP_TRANSFER_FIRST_TARGET_BUCKET},
+        TRANSFER_OPTIONS: {ALREADY_EXISTING_IN_SINK: True},
+    },
+}
+# [END howto_operator_gcp_transfer_create_job_body_aws]
+
+
+# [START howto_operator_gcp_transfer_default_args]
+default_args = {"owner": "airflow"}
+# [END howto_operator_gcp_transfer_default_args]
+
+with models.DAG(
+    "example_gcp_transfer_aws",
+    schedule_interval=None,  # Override to match your needs
+    start_date=days_ago(1),
+    tags=["example"],
+) as dag:
+
+    # [START howto_operator_gcp_transfer_create_job]
+    create_transfer_job_from_aws = CloudDataTransferServiceCreateJobOperator(
+        task_id="create_transfer_job_from_aws", body=aws_to_gcs_transfer_body
+    )
+    # [END howto_operator_gcp_transfer_create_job]
+
+    wait_for_operation_to_start = CloudDataTransferServiceJobStatusSensor(
task_id="wait_for_operation_to_start", + job_name="{{task_instance.xcom_pull('create_transfer_job_from_aws')['name']}}", + project_id=GCP_PROJECT_ID, + expected_statuses={GcpTransferOperationStatus.IN_PROGRESS}, + poke_interval=WAIT_FOR_OPERATION_POKE_INTERVAL, + ) + + # [START howto_operator_gcp_transfer_pause_operation] + pause_operation = CloudDataTransferServicePauseOperationOperator( + task_id="pause_operation", + operation_name="{{task_instance.xcom_pull('wait_for_operation_to_start', " + "key='sensed_operations')[0]['name']}}", + ) + # [END howto_operator_gcp_transfer_pause_operation] + + # [START howto_operator_gcp_transfer_list_operations] + list_operations = CloudDataTransferServiceListOperationsOperator( + task_id="list_operations", + request_filter={ + FILTER_PROJECT_ID: GCP_PROJECT_ID, + FILTER_JOB_NAMES: [ + "{{task_instance.xcom_pull('create_transfer_job_from_aws')['name']}}" + ], + }, + ) + # [END howto_operator_gcp_transfer_list_operations] + + # [START howto_operator_gcp_transfer_get_operation] + get_operation = CloudDataTransferServiceGetOperationOperator( + task_id="get_operation", + operation_name="{{task_instance.xcom_pull('list_operations')[0]['name']}}", + ) + # [END howto_operator_gcp_transfer_get_operation] + + # [START howto_operator_gcp_transfer_resume_operation] + resume_operation = CloudDataTransferServiceResumeOperationOperator( + task_id="resume_operation", + operation_name="{{task_instance.xcom_pull('get_operation')['name']}}", + ) + # [END howto_operator_gcp_transfer_resume_operation] + + # [START howto_operator_gcp_transfer_wait_operation] + wait_for_operation_to_end = CloudDataTransferServiceJobStatusSensor( + task_id="wait_for_operation_to_end", + job_name="{{task_instance.xcom_pull('create_transfer_job_from_aws')['name']}}", + project_id=GCP_PROJECT_ID, + expected_statuses={GcpTransferOperationStatus.SUCCESS}, + poke_interval=WAIT_FOR_OPERATION_POKE_INTERVAL, + ) + # [END howto_operator_gcp_transfer_wait_operation] + + # [START howto_operator_gcp_transfer_cancel_operation] + cancel_operation = CloudDataTransferServiceCancelOperationOperator( + task_id="cancel_operation", + operation_name="{{task_instance.xcom_pull(" + "'wait_for_second_operation_to_start', key='sensed_operations')[0]['name']}}", + ) + # [END howto_operator_gcp_transfer_cancel_operation] + + # [START howto_operator_gcp_transfer_delete_job] + delete_transfer_from_aws_job = CloudDataTransferServiceDeleteJobOperator( + task_id="delete_transfer_from_aws_job", + job_name="{{task_instance.xcom_pull('create_transfer_job_from_aws')['name']}}", + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_gcp_transfer_delete_job] + + # fmt: off + create_transfer_job_from_aws >> wait_for_operation_to_start >> pause_operation + pause_operation >> list_operations >> get_operation >> resume_operation + resume_operation >> wait_for_operation_to_end >> cancel_operation >> delete_transfer_from_aws_job + # fmt: on diff --git a/reference/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_gcp.py b/reference/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_gcp.py new file mode 100644 index 0000000..77fe5cc --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_gcp.py @@ -0,0 +1,150 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG that demonstrates interactions with Google Cloud Transfer.
+
+
+This DAG relies on the following OS environment variables
+
+* GCP_PROJECT_ID - Google Cloud Project to use for the Google Cloud Transfer Service.
+* GCP_TRANSFER_FIRST_TARGET_BUCKET - Google Cloud Storage bucket to which files are copied from AWS.
+  It is also a source bucket in next step
+* GCP_TRANSFER_SECOND_TARGET_BUCKET - Google Cloud Storage bucket to which files are copied
+"""
+
+import os
+from datetime import datetime, timedelta
+
+from airflow import models
+from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
+    ALREADY_EXISTING_IN_SINK,
+    BUCKET_NAME,
+    DESCRIPTION,
+    FILTER_JOB_NAMES,
+    FILTER_PROJECT_ID,
+    GCS_DATA_SINK,
+    GCS_DATA_SOURCE,
+    PROJECT_ID,
+    SCHEDULE,
+    SCHEDULE_END_DATE,
+    SCHEDULE_START_DATE,
+    START_TIME_OF_DAY,
+    STATUS,
+    TRANSFER_JOB,
+    TRANSFER_JOB_FIELD_MASK,
+    TRANSFER_OPTIONS,
+    TRANSFER_SPEC,
+    GcpTransferJobsStatus,
+    GcpTransferOperationStatus,
+)
+from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
+    CloudDataTransferServiceCreateJobOperator,
+    CloudDataTransferServiceDeleteJobOperator,
+    CloudDataTransferServiceGetOperationOperator,
+    CloudDataTransferServiceListOperationsOperator,
+    CloudDataTransferServiceUpdateJobOperator,
+)
+from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import (
+    CloudDataTransferServiceJobStatusSensor,
+)
+from airflow.utils.dates import days_ago
+
+GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
+GCP_TRANSFER_FIRST_TARGET_BUCKET = os.environ.get(
+    "GCP_TRANSFER_FIRST_TARGET_BUCKET", "gcp-transfer-first-target"
+)
+GCP_TRANSFER_SECOND_TARGET_BUCKET = os.environ.get(
+    "GCP_TRANSFER_SECOND_TARGET_BUCKET", "gcp-transfer-second-target"
+)
+
+# [START howto_operator_gcp_transfer_create_job_body_gcp]
+gcs_to_gcs_transfer_body = {
+    DESCRIPTION: "description",
+    STATUS: GcpTransferJobsStatus.ENABLED,
+    PROJECT_ID: GCP_PROJECT_ID,
+    SCHEDULE: {
+        SCHEDULE_START_DATE: datetime(2015, 1, 1).date(),
+        SCHEDULE_END_DATE: datetime(2030, 1, 1).date(),
+        START_TIME_OF_DAY: (datetime.utcnow() + timedelta(seconds=120)).time(),
+    },
+    TRANSFER_SPEC: {
+        GCS_DATA_SOURCE: {BUCKET_NAME: GCP_TRANSFER_FIRST_TARGET_BUCKET},
+        GCS_DATA_SINK: {BUCKET_NAME: GCP_TRANSFER_SECOND_TARGET_BUCKET},
+        TRANSFER_OPTIONS: {ALREADY_EXISTING_IN_SINK: True},
+    },
+}
+# [END howto_operator_gcp_transfer_create_job_body_gcp]
+
+# [START howto_operator_gcp_transfer_update_job_body]
+update_body = {
+    PROJECT_ID: GCP_PROJECT_ID,
+    TRANSFER_JOB: {DESCRIPTION: "description_updated"},
+    TRANSFER_JOB_FIELD_MASK: "description",
+}
+# [END howto_operator_gcp_transfer_update_job_body]
+
+with models.DAG(
+    "example_gcp_transfer",
+    schedule_interval=None,  # Override to match your needs
+    start_date=days_ago(1),
+    tags=["example"],
+) as dag:
+
+    create_transfer = CloudDataTransferServiceCreateJobOperator(
task_id="create_transfer", body=gcs_to_gcs_transfer_body + ) + + # [START howto_operator_gcp_transfer_update_job] + update_transfer = CloudDataTransferServiceUpdateJobOperator( + task_id="update_transfer", + job_name="{{task_instance.xcom_pull('create_transfer')['name']}}", + body=update_body, + ) + # [END howto_operator_gcp_transfer_update_job] + + wait_for_transfer = CloudDataTransferServiceJobStatusSensor( + task_id="wait_for_transfer", + job_name="{{task_instance.xcom_pull('create_transfer')['name']}}", + project_id=GCP_PROJECT_ID, + expected_statuses={GcpTransferOperationStatus.SUCCESS}, + ) + + list_operations = CloudDataTransferServiceListOperationsOperator( + task_id="list_operations", + request_filter={ + FILTER_PROJECT_ID: GCP_PROJECT_ID, + FILTER_JOB_NAMES: [ + "{{task_instance.xcom_pull('create_transfer')['name']}}" + ], + }, + ) + + get_operation = CloudDataTransferServiceGetOperationOperator( + task_id="get_operation", + operation_name="{{task_instance.xcom_pull('list_operations')[0]['name']}}", + ) + + delete_transfer = CloudDataTransferServiceDeleteJobOperator( + task_id="delete_transfer_from_gcp_job", + job_name="{{task_instance.xcom_pull('create_transfer')['name']}}", + project_id=GCP_PROJECT_ID, + ) + + create_transfer >> wait_for_transfer >> update_transfer >> list_operations >> get_operation + get_operation >> delete_transfer diff --git a/reference/providers/google/cloud/example_dags/example_compute.py b/reference/providers/google/cloud/example_dags/example_compute.py new file mode 100644 index 0000000..1ca5cb7 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_compute.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that starts, stops and sets the machine type of a Google Compute +Engine instance. + +This DAG relies on the following OS environment variables + +* GCP_PROJECT_ID - Google Cloud project where the Compute Engine instance exists. +* GCE_ZONE - Google Cloud zone where the instance exists. +* GCE_INSTANCE - Name of the Compute Engine instance. +* GCE_SHORT_MACHINE_TYPE_NAME - Machine type resource name to set, e.g. 'n1-standard-1'. 
+ See https://cloud.google.com/compute/docs/machine-types +""" + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.compute import ( + ComputeEngineSetMachineTypeOperator, + ComputeEngineStartInstanceOperator, + ComputeEngineStopInstanceOperator, +) +from airflow.utils.dates import days_ago + +# [START howto_operator_gce_args_common] +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCE_ZONE = os.environ.get("GCE_ZONE", "europe-west1-b") +GCE_INSTANCE = os.environ.get("GCE_INSTANCE", "testinstance") +# [END howto_operator_gce_args_common] + + +GCE_SHORT_MACHINE_TYPE_NAME = os.environ.get( + "GCE_SHORT_MACHINE_TYPE_NAME", "n1-standard-1" +) + + +with models.DAG( + "example_gcp_compute", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_gce_start] + gce_instance_start = ComputeEngineStartInstanceOperator( + project_id=GCP_PROJECT_ID, + zone=GCE_ZONE, + resource_id=GCE_INSTANCE, + task_id="gcp_compute_start_task", + ) + # [END howto_operator_gce_start] + # Duplicate start for idempotence testing + # [START howto_operator_gce_start_no_project_id] + gce_instance_start2 = ComputeEngineStartInstanceOperator( + zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id="gcp_compute_start_task2" + ) + # [END howto_operator_gce_start_no_project_id] + # [START howto_operator_gce_stop] + gce_instance_stop = ComputeEngineStopInstanceOperator( + project_id=GCP_PROJECT_ID, + zone=GCE_ZONE, + resource_id=GCE_INSTANCE, + task_id="gcp_compute_stop_task", + ) + # [END howto_operator_gce_stop] + # Duplicate stop for idempotence testing + # [START howto_operator_gce_stop_no_project_id] + gce_instance_stop2 = ComputeEngineStopInstanceOperator( + zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id="gcp_compute_stop_task2" + ) + # [END howto_operator_gce_stop_no_project_id] + # [START howto_operator_gce_set_machine_type] + gce_set_machine_type = ComputeEngineSetMachineTypeOperator( + project_id=GCP_PROJECT_ID, + zone=GCE_ZONE, + resource_id=GCE_INSTANCE, + body={ + "machineType": f"zones/{GCE_ZONE}/machineTypes/{GCE_SHORT_MACHINE_TYPE_NAME}" + }, + task_id="gcp_compute_set_machine_type", + ) + # [END howto_operator_gce_set_machine_type] + # Duplicate set machine type for idempotence testing + # [START howto_operator_gce_set_machine_type_no_project_id] + gce_set_machine_type2 = ComputeEngineSetMachineTypeOperator( + zone=GCE_ZONE, + resource_id=GCE_INSTANCE, + body={ + "machineType": f"zones/{GCE_ZONE}/machineTypes/{GCE_SHORT_MACHINE_TYPE_NAME}" + }, + task_id="gcp_compute_set_machine_type2", + ) + # [END howto_operator_gce_set_machine_type_no_project_id] + + gce_instance_start >> gce_instance_start2 >> gce_instance_stop >> gce_instance_stop2 + gce_instance_stop2 >> gce_set_machine_type >> gce_set_machine_type2 diff --git a/reference/providers/google/cloud/example_dags/example_compute_igm.py b/reference/providers/google/cloud/example_dags/example_compute_igm.py new file mode 100644 index 0000000..5b788d9 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_compute_igm.py @@ -0,0 +1,145 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses IGM-type compute operations: +* copy of Instance Template +* update template in Instance Group Manager + +This DAG relies on the following OS environment variables + +* GCP_PROJECT_ID - the Google Cloud project where the Compute Engine instance exists +* GCE_ZONE - the zone where the Compute Engine instance exists + +Variables for copy template operator: +* GCE_TEMPLATE_NAME - name of the template to copy +* GCE_NEW_TEMPLATE_NAME - name of the new template +* GCE_NEW_DESCRIPTION - description added to the template + +Variables for update template in Group Manager: + +* GCE_INSTANCE_GROUP_MANAGER_NAME - name of the Instance Group Manager +* SOURCE_TEMPLATE_URL - url of the template to replace in the Instance Group Manager +* DESTINATION_TEMPLATE_URL - url of the new template to set in the Instance Group Manager +""" + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.compute import ( + ComputeEngineCopyInstanceTemplateOperator, + ComputeEngineInstanceGroupUpdateManagerTemplateOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCE_ZONE = os.environ.get("GCE_ZONE", "europe-west1-b") + +# [START howto_operator_compute_template_copy_args] +GCE_TEMPLATE_NAME = os.environ.get("GCE_TEMPLATE_NAME", "instance-template-test") +GCE_NEW_TEMPLATE_NAME = os.environ.get( + "GCE_NEW_TEMPLATE_NAME", "instance-template-test-new" +) +GCE_NEW_DESCRIPTION = os.environ.get("GCE_NEW_DESCRIPTION", "Test new description") +GCE_INSTANCE_TEMPLATE_BODY_UPDATE = { + "name": GCE_NEW_TEMPLATE_NAME, + "description": GCE_NEW_DESCRIPTION, + "properties": {"machineType": "n1-standard-2"}, +} +# [END howto_operator_compute_template_copy_args] + +# [START howto_operator_compute_igm_update_template_args] +GCE_INSTANCE_GROUP_MANAGER_NAME = os.environ.get( + "GCE_INSTANCE_GROUP_MANAGER_NAME", "instance-group-test" +) + +SOURCE_TEMPLATE_URL = os.environ.get( + "SOURCE_TEMPLATE_URL", + "https://www.googleapis.com/compute/beta/projects/" + + GCP_PROJECT_ID + + "/global/instanceTemplates/instance-template-test", +) + +DESTINATION_TEMPLATE_URL = os.environ.get( + "DESTINATION_TEMPLATE_URL", + "https://www.googleapis.com/compute/beta/projects/" + + GCP_PROJECT_ID + + "/global/instanceTemplates/" + + GCE_NEW_TEMPLATE_NAME, +) + +UPDATE_POLICY = { + "type": "OPPORTUNISTIC", + "minimalAction": "RESTART", + "maxSurge": {"fixed": 1}, + "minReadySec": 1800, +} + +# [END howto_operator_compute_igm_update_template_args] + + +with models.DAG( + "example_gcp_compute_igm", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_gce_igm_copy_template] + gce_instance_template_copy = ComputeEngineCopyInstanceTemplateOperator( + project_id=GCP_PROJECT_ID, + resource_id=GCE_TEMPLATE_NAME, + body_patch=GCE_INSTANCE_TEMPLATE_BODY_UPDATE, + 
task_id="gcp_compute_igm_copy_template_task", + ) + # [END howto_operator_gce_igm_copy_template] + # Added to check for idempotence + # [START howto_operator_gce_igm_copy_template_no_project_id] + gce_instance_template_copy2 = ComputeEngineCopyInstanceTemplateOperator( + resource_id=GCE_TEMPLATE_NAME, + body_patch=GCE_INSTANCE_TEMPLATE_BODY_UPDATE, + task_id="gcp_compute_igm_copy_template_task_2", + ) + # [END howto_operator_gce_igm_copy_template_no_project_id] + # [START howto_operator_gce_igm_update_template] + gce_instance_group_manager_update_template = ( + ComputeEngineInstanceGroupUpdateManagerTemplateOperator( + project_id=GCP_PROJECT_ID, + resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, + zone=GCE_ZONE, + source_template=SOURCE_TEMPLATE_URL, + destination_template=DESTINATION_TEMPLATE_URL, + update_policy=UPDATE_POLICY, + task_id="gcp_compute_igm_group_manager_update_template", + ) + ) + # [END howto_operator_gce_igm_update_template] + # Added to check for idempotence (and without UPDATE_POLICY) + # [START howto_operator_gce_igm_update_template_no_project_id] + gce_instance_group_manager_update_template2 = ( + ComputeEngineInstanceGroupUpdateManagerTemplateOperator( + resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, + zone=GCE_ZONE, + source_template=SOURCE_TEMPLATE_URL, + destination_template=DESTINATION_TEMPLATE_URL, + task_id="gcp_compute_igm_group_manager_update_template_2", + ) + ) + # [END howto_operator_gce_igm_update_template_no_project_id] + + gce_instance_template_copy >> gce_instance_template_copy2 >> gce_instance_group_manager_update_template + gce_instance_group_manager_update_template >> gce_instance_group_manager_update_template2 diff --git a/reference/providers/google/cloud/example_dags/example_compute_ssh.py b/reference/providers/google/cloud/example_dags/example_compute_ssh.py new file mode 100644 index 0000000..5718bf2 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_compute_ssh.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
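+
+"""
+Example Airflow DAG that runs simple ``echo`` commands on a Compute Engine instance over SSH
+using ComputeEngineSSHHook, covering the four combinations of OS Login vs. instance metadata
+keys and direct connection vs. IAP tunnel.
+
+This DAG relies on the following OS environment variables
+
+* GCP_PROJECT_ID - Google Cloud project where the Compute Engine instance exists.
+* GCE_ZONE - Google Cloud zone where the instance exists.
+* GCE_INSTANCE - Name of the Compute Engine instance to connect to.
+"""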
+ +import os + +from airflow import models +from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook +from airflow.providers.ssh.operators.ssh import SSHOperator +from airflow.utils import dates + +# [START howto_operator_gce_args_common] +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCE_ZONE = os.environ.get("GCE_ZONE", "europe-west2-a") +GCE_INSTANCE = os.environ.get("GCE_INSTANCE", "target-instance") +# [END howto_operator_gce_args_common] + +with models.DAG( + "example_compute_ssh", + default_args=dict(start_date=dates.days_ago(1)), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag: + # # [START howto_execute_command_on_remote1] + os_login_without_iap_tunnel = SSHOperator( + task_id="os_login_without_iap_tunnel", + ssh_hook=ComputeEngineSSHHook( + instance_name=GCE_INSTANCE, + zone=GCE_ZONE, + project_id=GCP_PROJECT_ID, + use_oslogin=True, + use_iap_tunnel=False, + ), + command="echo os_login_without_iap_tunnel", + ) + # # [END howto_execute_command_on_remote1] + + # # [START howto_execute_command_on_remote2] + metadata_without_iap_tunnel = SSHOperator( + task_id="metadata_without_iap_tunnel", + ssh_hook=ComputeEngineSSHHook( + instance_name=GCE_INSTANCE, + zone=GCE_ZONE, + use_oslogin=False, + use_iap_tunnel=False, + ), + command="echo metadata_without_iap_tunnel", + ) + # # [END howto_execute_command_on_remote2] + + os_login_with_iap_tunnel = SSHOperator( + task_id="os_login_with_iap_tunnel", + ssh_hook=ComputeEngineSSHHook( + instance_name=GCE_INSTANCE, + zone=GCE_ZONE, + use_oslogin=True, + use_iap_tunnel=True, + ), + command="echo os_login_with_iap_tunnel", + ) + + metadata_with_iap_tunnel = SSHOperator( + task_id="metadata_with_iap_tunnel", + ssh_hook=ComputeEngineSSHHook( + instance_name=GCE_INSTANCE, + zone=GCE_ZONE, + use_oslogin=False, + use_iap_tunnel=True, + ), + command="echo metadata_with_iap_tunnel", + ) + + os_login_with_iap_tunnel >> os_login_without_iap_tunnel + metadata_with_iap_tunnel >> metadata_without_iap_tunnel + + os_login_without_iap_tunnel >> metadata_with_iap_tunnel diff --git a/reference/providers/google/cloud/example_dags/example_datacatalog.py b/reference/providers/google/cloud/example_dags/example_datacatalog.py new file mode 100644 index 0000000..8b05c1e --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_datacatalog.py @@ -0,0 +1,473 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
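+
+# Editor's note (comment added for clarity, not in the upstream example): unlike the other
+# example DAGs in this directory, this one hard-codes PROJECT_ID ("polidea-airflow"), LOCATION
+# and the entry/template identifiers below instead of reading them from OS environment
+# variables, so replace them with values from your own project before running it.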
+ +""" +Example Airflow DAG that interacts with Google Data Catalog service +""" +from airflow import models +from airflow.operators.bash_operator import BashOperator +from airflow.providers.google.cloud.operators.datacatalog import ( + CloudDataCatalogCreateEntryGroupOperator, + CloudDataCatalogCreateEntryOperator, + CloudDataCatalogCreateTagOperator, + CloudDataCatalogCreateTagTemplateFieldOperator, + CloudDataCatalogCreateTagTemplateOperator, + CloudDataCatalogDeleteEntryGroupOperator, + CloudDataCatalogDeleteEntryOperator, + CloudDataCatalogDeleteTagOperator, + CloudDataCatalogDeleteTagTemplateFieldOperator, + CloudDataCatalogDeleteTagTemplateOperator, + CloudDataCatalogGetEntryGroupOperator, + CloudDataCatalogGetEntryOperator, + CloudDataCatalogGetTagTemplateOperator, + CloudDataCatalogListTagsOperator, + CloudDataCatalogLookupEntryOperator, + CloudDataCatalogRenameTagTemplateFieldOperator, + CloudDataCatalogSearchCatalogOperator, + CloudDataCatalogUpdateEntryOperator, + CloudDataCatalogUpdateTagOperator, + CloudDataCatalogUpdateTagTemplateFieldOperator, + CloudDataCatalogUpdateTagTemplateOperator, +) +from airflow.utils.dates import days_ago +from airflow.utils.helpers import chain +from google.cloud.datacatalog_v1beta1 import FieldType, TagField, TagTemplateField + +PROJECT_ID = "polidea-airflow" +LOCATION = "us-central1" +ENTRY_GROUP_ID = "important_data_jan_2019" +ENTRY_ID = "python_files" +TEMPLATE_ID = "template_id" +FIELD_NAME_1 = "first" +FIELD_NAME_2 = "second" +FIELD_NAME_3 = "first-rename" + +with models.DAG( + "example_gcp_datacatalog", start_date=days_ago(1), schedule_interval=None +) as dag: + # Create + # [START howto_operator_gcp_datacatalog_create_entry_group] + create_entry_group = CloudDataCatalogCreateEntryGroupOperator( + task_id="create_entry_group", + location=LOCATION, + entry_group_id=ENTRY_GROUP_ID, + entry_group={"display_name": "analytics data - jan 2011"}, + ) + # [END howto_operator_gcp_datacatalog_create_entry_group] + + # [START howto_operator_gcp_datacatalog_create_entry_group_result] + create_entry_group_result = BashOperator( + task_id="create_entry_group_result", + bash_command="echo \"{{ task_instance.xcom_pull('create_entry_group', key='entry_group_id') }}\"", + ) + # [END howto_operator_gcp_datacatalog_create_entry_group_result] + + # [START howto_operator_gcp_datacatalog_create_entry_group_result2] + create_entry_group_result2 = BashOperator( + task_id="create_entry_group_result2", + bash_command="echo \"{{ task_instance.xcom_pull('create_entry_group') }}\"", + ) + # [END howto_operator_gcp_datacatalog_create_entry_group_result2] + + # [START howto_operator_gcp_datacatalog_create_entry_gcs] + create_entry_gcs = CloudDataCatalogCreateEntryOperator( + task_id="create_entry_gcs", + location=LOCATION, + entry_group=ENTRY_GROUP_ID, + entry_id=ENTRY_ID, + entry={ + "display_name": "Wizard", + "type_": "FILESET", + "gcs_fileset_spec": {"file_patterns": ["gs://test-datacatalog/**"]}, + }, + ) + # [END howto_operator_gcp_datacatalog_create_entry_gcs] + + # [START howto_operator_gcp_datacatalog_create_entry_gcs_result] + create_entry_gcs_result = BashOperator( + task_id="create_entry_gcs_result", + bash_command="echo \"{{ task_instance.xcom_pull('create_entry_gcs', key='entry_id') }}\"", + ) + # [END howto_operator_gcp_datacatalog_create_entry_gcs_result] + + # [START howto_operator_gcp_datacatalog_create_entry_gcs_result2] + create_entry_gcs_result2 = BashOperator( + task_id="create_entry_gcs_result2", + bash_command="echo \"{{ 
task_instance.xcom_pull('create_entry_gcs') }}\"", + ) + # [END howto_operator_gcp_datacatalog_create_entry_gcs_result2] + + # [START howto_operator_gcp_datacatalog_create_tag] + create_tag = CloudDataCatalogCreateTagOperator( + task_id="create_tag", + location=LOCATION, + entry_group=ENTRY_GROUP_ID, + entry=ENTRY_ID, + template_id=TEMPLATE_ID, + tag={"fields": {FIELD_NAME_1: TagField(string_value="example-value-string")}}, + ) + # [END howto_operator_gcp_datacatalog_create_tag] + + # [START howto_operator_gcp_datacatalog_create_tag_result] + create_tag_result = BashOperator( + task_id="create_tag_result", + bash_command="echo \"{{ task_instance.xcom_pull('create_tag', key='tag_id') }}\"", + ) + # [END howto_operator_gcp_datacatalog_create_tag_result] + + # [START howto_operator_gcp_datacatalog_create_tag_result2] + create_tag_result2 = BashOperator( + task_id="create_tag_result2", + bash_command="echo \"{{ task_instance.xcom_pull('create_tag') }}\"", + ) + # [END howto_operator_gcp_datacatalog_create_tag_result2] + + # [START howto_operator_gcp_datacatalog_create_tag_template] + create_tag_template = CloudDataCatalogCreateTagTemplateOperator( + task_id="create_tag_template", + location=LOCATION, + tag_template_id=TEMPLATE_ID, + tag_template={ + "display_name": "Awesome Tag Template", + "fields": { + FIELD_NAME_1: TagTemplateField( + display_name="first-field", type_=dict(primitive_type="STRING") + ) + }, + }, + ) + # [END howto_operator_gcp_datacatalog_create_tag_template] + + # [START howto_operator_gcp_datacatalog_create_tag_template_result] + create_tag_template_result = BashOperator( + task_id="create_tag_template_result", + bash_command="echo \"{{ task_instance.xcom_pull('create_tag_template', key='tag_template_id') }}\"", + ) + # [END howto_operator_gcp_datacatalog_create_tag_template_result] + + # [START howto_operator_gcp_datacatalog_create_tag_template_result2] + create_tag_template_result2 = BashOperator( + task_id="create_tag_template_result2", + bash_command="echo \"{{ task_instance.xcom_pull('create_tag_template') }}\"", + ) + # [END howto_operator_gcp_datacatalog_create_tag_template_result2] + + # [START howto_operator_gcp_datacatalog_create_tag_template_field] + create_tag_template_field = CloudDataCatalogCreateTagTemplateFieldOperator( + task_id="create_tag_template_field", + location=LOCATION, + tag_template=TEMPLATE_ID, + tag_template_field_id=FIELD_NAME_2, + tag_template_field=TagTemplateField( + display_name="second-field", type_=FieldType(primitive_type="STRING") + ), + ) + # [END howto_operator_gcp_datacatalog_create_tag_template_field] + + # [START howto_operator_gcp_datacatalog_create_tag_template_field_result] + create_tag_template_field_result = BashOperator( + task_id="create_tag_template_field_result", + bash_command=( + "echo \"{{ task_instance.xcom_pull('create_tag_template_field'," + + " key='tag_template_field_id') }}\"" + ), + ) + # [END howto_operator_gcp_datacatalog_create_tag_template_field_result] + + # [START howto_operator_gcp_datacatalog_create_tag_template_field_result2] + create_tag_template_field_result2 = BashOperator( + task_id="create_tag_template_field_result2", + bash_command="echo \"{{ task_instance.xcom_pull('create_tag_template_field') }}\"", + ) + # [END howto_operator_gcp_datacatalog_create_tag_template_field_result2] + + # Delete + # [START howto_operator_gcp_datacatalog_delete_entry] + delete_entry = CloudDataCatalogDeleteEntryOperator( + task_id="delete_entry", + location=LOCATION, + entry_group=ENTRY_GROUP_ID, + entry=ENTRY_ID, + ) 
+ # [END howto_operator_gcp_datacatalog_delete_entry] + + # [START howto_operator_gcp_datacatalog_delete_entry_group] + delete_entry_group = CloudDataCatalogDeleteEntryGroupOperator( + task_id="delete_entry_group", location=LOCATION, entry_group=ENTRY_GROUP_ID + ) + # [END howto_operator_gcp_datacatalog_delete_entry_group] + + # [START howto_operator_gcp_datacatalog_delete_tag] + delete_tag = CloudDataCatalogDeleteTagOperator( + task_id="delete_tag", + location=LOCATION, + entry_group=ENTRY_GROUP_ID, + entry=ENTRY_ID, + tag="{{ task_instance.xcom_pull('create_tag', key='tag_id') }}", + ) + # [END howto_operator_gcp_datacatalog_delete_tag] + + # [START howto_operator_gcp_datacatalog_delete_tag_template_field] + delete_tag_template_field = CloudDataCatalogDeleteTagTemplateFieldOperator( + task_id="delete_tag_template_field", + location=LOCATION, + tag_template=TEMPLATE_ID, + field=FIELD_NAME_2, + force=True, + ) + # [END howto_operator_gcp_datacatalog_delete_tag_template_field] + + # [START howto_operator_gcp_datacatalog_delete_tag_template] + delete_tag_template = CloudDataCatalogDeleteTagTemplateOperator( + task_id="delete_tag_template", + location=LOCATION, + tag_template=TEMPLATE_ID, + force=True, + ) + # [END howto_operator_gcp_datacatalog_delete_tag_template] + + # Get + # [START howto_operator_gcp_datacatalog_get_entry_group] + get_entry_group = CloudDataCatalogGetEntryGroupOperator( + task_id="get_entry_group", + location=LOCATION, + entry_group=ENTRY_GROUP_ID, + read_mask={"paths": ["name", "display_name"]}, + ) + # [END howto_operator_gcp_datacatalog_get_entry_group] + + # [START howto_operator_gcp_datacatalog_get_entry_group_result] + get_entry_group_result = BashOperator( + task_id="get_entry_group_result", + bash_command="echo \"{{ task_instance.xcom_pull('get_entry_group') }}\"", + ) + # [END howto_operator_gcp_datacatalog_get_entry_group_result] + + # [START howto_operator_gcp_datacatalog_get_entry] + get_entry = CloudDataCatalogGetEntryOperator( + task_id="get_entry", + location=LOCATION, + entry_group=ENTRY_GROUP_ID, + entry=ENTRY_ID, + ) + # [END howto_operator_gcp_datacatalog_get_entry] + + # [START howto_operator_gcp_datacatalog_get_entry_result] + get_entry_result = BashOperator( + task_id="get_entry_result", + bash_command="echo \"{{ task_instance.xcom_pull('get_entry') }}\"", + ) + # [END howto_operator_gcp_datacatalog_get_entry_result] + + # [START howto_operator_gcp_datacatalog_get_tag_template] + get_tag_template = CloudDataCatalogGetTagTemplateOperator( + task_id="get_tag_template", location=LOCATION, tag_template=TEMPLATE_ID + ) + # [END howto_operator_gcp_datacatalog_get_tag_template] + + # [START howto_operator_gcp_datacatalog_get_tag_template_result] + get_tag_template_result = BashOperator( + task_id="get_tag_template_result", + bash_command="echo \"{{ task_instance.xcom_pull('get_tag_template') }}\"", + ) + # [END howto_operator_gcp_datacatalog_get_tag_template_result] + + # List + # [START howto_operator_gcp_datacatalog_list_tags] + list_tags = CloudDataCatalogListTagsOperator( + task_id="list_tags", + location=LOCATION, + entry_group=ENTRY_GROUP_ID, + entry=ENTRY_ID, + ) + # [END howto_operator_gcp_datacatalog_list_tags] + + # [START howto_operator_gcp_datacatalog_list_tags_result] + list_tags_result = BashOperator( + task_id="list_tags_result", + bash_command="echo \"{{ task_instance.xcom_pull('list_tags') }}\"", + ) + # [END howto_operator_gcp_datacatalog_list_tags_result] + + # Lookup + # [START 
howto_operator_gcp_datacatalog_lookup_entry_linked_resource] + current_entry_template = ( + "//datacatalog.googleapis.com/projects/{project_id}/locations/{location}/" + "entryGroups/{entry_group}/entries/{entry}" + ) + lookup_entry_linked_resource = CloudDataCatalogLookupEntryOperator( + task_id="lookup_entry", + linked_resource=current_entry_template.format( + project_id=PROJECT_ID, + location=LOCATION, + entry_group=ENTRY_GROUP_ID, + entry=ENTRY_ID, + ), + ) + # [END howto_operator_gcp_datacatalog_lookup_entry_linked_resource] + + # [START howto_operator_gcp_datacatalog_lookup_entry_result] + lookup_entry_result = BashOperator( + task_id="lookup_entry_result", + bash_command="echo \"{{ task_instance.xcom_pull('lookup_entry')['display_name'] }}\"", + ) + # [END howto_operator_gcp_datacatalog_lookup_entry_result] + + # Rename + # [START howto_operator_gcp_datacatalog_rename_tag_template_field] + rename_tag_template_field = CloudDataCatalogRenameTagTemplateFieldOperator( + task_id="rename_tag_template_field", + location=LOCATION, + tag_template=TEMPLATE_ID, + field=FIELD_NAME_1, + new_tag_template_field_id=FIELD_NAME_3, + ) + # [END howto_operator_gcp_datacatalog_rename_tag_template_field] + + # Search + # [START howto_operator_gcp_datacatalog_search_catalog] + search_catalog = CloudDataCatalogSearchCatalogOperator( + task_id="search_catalog", + scope={"include_project_ids": [PROJECT_ID]}, + query=f"projectid:{PROJECT_ID}", + ) + # [END howto_operator_gcp_datacatalog_search_catalog] + + # [START howto_operator_gcp_datacatalog_search_catalog_result] + search_catalog_result = BashOperator( + task_id="search_catalog_result", + bash_command="echo \"{{ task_instance.xcom_pull('search_catalog') }}\"", + ) + # [END howto_operator_gcp_datacatalog_search_catalog_result] + + # Update + # [START howto_operator_gcp_datacatalog_update_entry] + update_entry = CloudDataCatalogUpdateEntryOperator( + task_id="update_entry", + entry={"display_name": "New Wizard"}, + update_mask={"paths": ["display_name"]}, + location=LOCATION, + entry_group=ENTRY_GROUP_ID, + entry_id=ENTRY_ID, + ) + # [END howto_operator_gcp_datacatalog_update_entry] + + # [START howto_operator_gcp_datacatalog_update_tag] + update_tag = CloudDataCatalogUpdateTagOperator( + task_id="update_tag", + tag={"fields": {FIELD_NAME_1: TagField(string_value="new-value-string")}}, + update_mask={"paths": ["fields"]}, + location=LOCATION, + entry_group=ENTRY_GROUP_ID, + entry=ENTRY_ID, + tag_id="{{ task_instance.xcom_pull('create_tag', key='tag_id') }}", + ) + # [END howto_operator_gcp_datacatalog_update_tag] + + # [START howto_operator_gcp_datacatalog_update_tag_template] + update_tag_template = CloudDataCatalogUpdateTagTemplateOperator( + task_id="update_tag_template", + tag_template={"display_name": "Awesome Tag Template"}, + update_mask={"paths": ["display_name"]}, + location=LOCATION, + tag_template_id=TEMPLATE_ID, + ) + # [END howto_operator_gcp_datacatalog_update_tag_template] + + # [START howto_operator_gcp_datacatalog_update_tag_template_field] + update_tag_template_field = CloudDataCatalogUpdateTagTemplateFieldOperator( + task_id="update_tag_template_field", + tag_template_field={"display_name": "Updated template field"}, + update_mask={"paths": ["display_name"]}, + location=LOCATION, + tag_template=TEMPLATE_ID, + tag_template_field_id=FIELD_NAME_1, + ) + # [END howto_operator_gcp_datacatalog_update_tag_template_field] + + # Create + create_tasks = [ + create_entry_group, + create_entry_gcs, + create_tag_template, + create_tag_template_field, 
+ create_tag, + ] + chain(*create_tasks) + + create_entry_group >> delete_entry_group + create_entry_group >> create_entry_group_result + create_entry_group >> create_entry_group_result2 + + create_entry_gcs >> delete_entry + create_entry_gcs >> create_entry_gcs_result + create_entry_gcs >> create_entry_gcs_result2 + + create_tag_template >> delete_tag_template_field + create_tag_template >> create_tag_template_result + create_tag_template >> create_tag_template_result2 + + create_tag_template_field >> delete_tag_template_field + create_tag_template_field >> create_tag_template_field_result + create_tag_template_field >> create_tag_template_field_result2 + + create_tag >> delete_tag + create_tag >> create_tag_result + create_tag >> create_tag_result2 + + # Delete + delete_tasks = [ + delete_tag, + delete_tag_template_field, + delete_tag_template, + delete_entry, + delete_entry_group, + ] + chain(*delete_tasks) + + # Get + create_tag_template >> get_tag_template >> delete_tag_template + get_tag_template >> get_tag_template_result + + create_entry_gcs >> get_entry >> delete_entry + get_entry >> get_entry_result + + create_entry_group >> get_entry_group >> delete_entry_group + get_entry_group >> get_entry_group_result + + # List + create_tag >> list_tags >> delete_tag + list_tags >> list_tags_result + + # Lookup + create_entry_gcs >> lookup_entry_linked_resource >> delete_entry + lookup_entry_linked_resource >> lookup_entry_result + + # Rename + update_tag >> rename_tag_template_field + create_tag_template_field >> rename_tag_template_field >> delete_tag_template_field + + # Search + chain(create_tasks, search_catalog, delete_tasks) + search_catalog >> search_catalog_result + + # Update + create_entry_gcs >> update_entry >> delete_entry + create_tag >> update_tag >> delete_tag + create_tag_template >> update_tag_template >> delete_tag_template + create_tag_template_field >> update_tag_template_field >> rename_tag_template_field diff --git a/reference/providers/google/cloud/example_dags/example_dataflow.py b/reference/providers/google/cloud/example_dags/example_dataflow.py new file mode 100644 index 0000000..7aae506 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_dataflow.py @@ -0,0 +1,272 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Example Airflow DAG for Google Cloud Dataflow service +""" +import os +from typing import Callable, Dict, List +from urllib.parse import urlparse + +from airflow import models +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus +from airflow.providers.google.cloud.operators.dataflow import ( + CheckJobRunning, + DataflowCreateJavaJobOperator, + DataflowCreatePythonJobOperator, + DataflowTemplatedJobStartOperator, +) +from airflow.providers.google.cloud.sensors.dataflow import ( + DataflowJobAutoScalingEventsSensor, + DataflowJobMessagesSensor, + DataflowJobMetricsSensor, + DataflowJobStatusSensor, +) +from airflow.providers.google.cloud.transfers.gcs_to_local import ( + GCSToLocalFilesystemOperator, +) +from airflow.utils.dates import days_ago + +GCS_TMP = os.environ.get("GCP_DATAFLOW_GCS_TMP", "gs://test-dataflow-example/temp/") +GCS_STAGING = os.environ.get( + "GCP_DATAFLOW_GCS_STAGING", "gs://test-dataflow-example/staging/" +) +GCS_OUTPUT = os.environ.get( + "GCP_DATAFLOW_GCS_OUTPUT", "gs://test-dataflow-example/output" +) +GCS_JAR = os.environ.get( + "GCP_DATAFLOW_JAR", "gs://test-dataflow-example/word-count-beam-bundled-0.1.jar" +) +GCS_PYTHON = os.environ.get( + "GCP_DATAFLOW_PYTHON", "gs://test-dataflow-example/wordcount_debugging.py" +) + +GCS_JAR_PARTS = urlparse(GCS_JAR) +GCS_JAR_BUCKET_NAME = GCS_JAR_PARTS.netloc +GCS_JAR_OBJECT_NAME = GCS_JAR_PARTS.path[1:] + +default_args = { + "dataflow_default_options": { + "tempLocation": GCS_TMP, + "stagingLocation": GCS_STAGING, + } +} + +with models.DAG( + "example_gcp_dataflow_native_java", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag_native_java: + + # [START howto_operator_start_java_job_jar_on_gcs] + start_java_job = DataflowCreateJavaJobOperator( + task_id="start-java-job", + jar=GCS_JAR, + job_name="{{task.task_id}}", + options={ + "output": GCS_OUTPUT, + }, + poll_sleep=10, + job_class="org.apache.beam.examples.WordCount", + check_if_running=CheckJobRunning.IgnoreJob, + location="europe-west3", + ) + # [END howto_operator_start_java_job_jar_on_gcs] + + # [START howto_operator_start_java_job_local_jar] + jar_to_local = GCSToLocalFilesystemOperator( + task_id="jar-to-local", + bucket=GCS_JAR_BUCKET_NAME, + object_name=GCS_JAR_OBJECT_NAME, + filename="/tmp/dataflow-{{ ds_nodash }}.jar", + ) + + start_java_job_local = DataflowCreateJavaJobOperator( + task_id="start-java-job-local", + jar="/tmp/dataflow-{{ ds_nodash }}.jar", + job_name="{{task.task_id}}", + options={ + "output": GCS_OUTPUT, + }, + poll_sleep=10, + job_class="org.apache.beam.examples.WordCount", + check_if_running=CheckJobRunning.WaitForRun, + ) + jar_to_local >> start_java_job_local + # [END howto_operator_start_java_job_local_jar] + +with models.DAG( + "example_gcp_dataflow_native_python", + default_args=default_args, + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag_native_python: + + # [START howto_operator_start_python_job] + start_python_job = DataflowCreatePythonJobOperator( + task_id="start-python-job", + py_file=GCS_PYTHON, + py_options=[], + job_name="{{task.task_id}}", + options={ + "output": GCS_OUTPUT, + }, + py_requirements=["apache-beam[gcp]==2.21.0"], + py_interpreter="python3", + py_system_site_packages=False, + location="europe-west3", + ) + # [END howto_operator_start_python_job] + + start_python_job_local = DataflowCreatePythonJobOperator( + 
task_id="start-python-job-local", + py_file="apache_beam.examples.wordcount", + py_options=["-m"], + job_name="{{task.task_id}}", + options={ + "output": GCS_OUTPUT, + }, + py_requirements=["apache-beam[gcp]==2.14.0"], + py_interpreter="python3", + py_system_site_packages=False, + ) + +with models.DAG( + "example_gcp_dataflow_native_python_async", + default_args=default_args, + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag_native_python_async: + # [START howto_operator_start_python_job_async] + start_python_job_async = DataflowCreatePythonJobOperator( + task_id="start-python-job-async", + py_file=GCS_PYTHON, + py_options=[], + job_name="{{task.task_id}}", + options={ + "output": GCS_OUTPUT, + }, + py_requirements=["apache-beam[gcp]==2.25.0"], + py_interpreter="python3", + py_system_site_packages=False, + location="europe-west3", + wait_until_finished=False, + ) + # [END howto_operator_start_python_job_async] + + # [START howto_sensor_wait_for_job_status] + wait_for_python_job_async_done = DataflowJobStatusSensor( + task_id="wait-for-python-job-async-done", + job_id="{{task_instance.xcom_pull('start-python-job-async')['job_id']}}", + expected_statuses={DataflowJobStatus.JOB_STATE_DONE}, + location="europe-west3", + ) + # [END howto_sensor_wait_for_job_status] + + # [START howto_sensor_wait_for_job_metric] + def check_metric_scalar_gte(metric_name: str, value: int) -> Callable: + """Check is metric greater than equals to given value.""" + + def callback(metrics: List[Dict]) -> bool: + dag_native_python_async.log.info( + "Looking for '%s' >= %d", metric_name, value + ) + for metric in metrics: + context = metric.get("name", {}).get("context", {}) + original_name = context.get("original_name", "") + tentative = context.get("tentative", "") + if original_name == "Service-cpu_num_seconds" and not tentative: + return metric["scalar"] >= value + raise AirflowException(f"Metric '{metric_name}' not found in metrics") + + return callback + + wait_for_python_job_async_metric = DataflowJobMetricsSensor( + task_id="wait-for-python-job-async-metric", + job_id="{{task_instance.xcom_pull('start-python-job-async')['job_id']}}", + location="europe-west3", + callback=check_metric_scalar_gte( + metric_name="Service-cpu_num_seconds", value=100 + ), + ) + # [END howto_sensor_wait_for_job_metric] + + # [START howto_sensor_wait_for_job_message] + def check_message(messages: List[dict]) -> bool: + """Check message""" + for message in messages: + if "Adding workflow start and stop steps." in message.get( + "messageText", "" + ): + return True + return False + + wait_for_python_job_async_message = DataflowJobMessagesSensor( + task_id="wait-for-python-job-async-message", + job_id="{{task_instance.xcom_pull('start-python-job-async')['job_id']}}", + location="europe-west3", + callback=check_message, + ) + # [END howto_sensor_wait_for_job_message] + + # [START howto_sensor_wait_for_job_autoscaling_event] + def check_autoscaling_event(autoscaling_events: List[dict]) -> bool: + """Check autoscaling event""" + for autoscaling_event in autoscaling_events: + if "Worker pool started." 
in autoscaling_event.get("description", {}).get( + "messageText", "" + ): + return True + return False + + wait_for_python_job_async_autoscaling_event = DataflowJobAutoScalingEventsSensor( + task_id="wait-for-python-job-async-autoscaling-event", + job_id="{{task_instance.xcom_pull('start-python-job-async')['job_id']}}", + location="europe-west3", + callback=check_autoscaling_event, + ) + # [END howto_sensor_wait_for_job_autoscaling_event] + + start_python_job_async >> wait_for_python_job_async_done + start_python_job_async >> wait_for_python_job_async_metric + start_python_job_async >> wait_for_python_job_async_message + start_python_job_async >> wait_for_python_job_async_autoscaling_event + + +with models.DAG( + "example_gcp_dataflow_template", + default_args=default_args, + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag_template: + # [START howto_operator_start_template_job] + start_template_job = DataflowTemplatedJobStartOperator( + task_id="start-template-job", + template="gs://dataflow-templates/latest/Word_Count", + parameters={ + "inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", + "output": GCS_OUTPUT, + }, + location="europe-west3", + ) + # [END howto_operator_start_template_job] diff --git a/reference/providers/google/cloud/example_dags/example_dataflow_flex_template.py b/reference/providers/google/cloud/example_dags/example_dataflow_flex_template.py new file mode 100644 index 0000000..e7e78fa --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_dataflow_flex_template.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google Cloud Dataflow service +""" +import os + +from airflow import models +from airflow.providers.google.cloud.operators.dataflow import ( + DataflowStartFlexTemplateOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") + +DATAFLOW_FLEX_TEMPLATE_JOB_NAME = os.environ.get( + "GCP_DATAFLOW_FLEX_TEMPLATE_JOB_NAME", "dataflow-flex-template" +) + +# For simplicity we use the same topic name as the subscription name. 
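+# The flex template launched below ("streaming-beam-sql") reads messages from this +# Pub/Sub subscription and writes them to the BigQuery table given in "outputTable". +# The topic and subscription are assumed to exist already; nothing in this DAG creates them.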
+PUBSUB_FLEX_TEMPLATE_TOPIC = os.environ.get( + "GCP_DATAFLOW_PUBSUB_FLEX_TEMPLATE_TOPIC", "dataflow-flex-template" +) +PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION = PUBSUB_FLEX_TEMPLATE_TOPIC +GCS_FLEX_TEMPLATE_TEMPLATE_PATH = os.environ.get( + "GCP_DATAFLOW_GCS_FLEX_TEMPLATE_TEMPLATE_PATH", + "gs://test-airflow-dataflow-flex-template/samples/dataflow/templates/streaming-beam-sql.json", +) +BQ_FLEX_TEMPLATE_DATASET = os.environ.get( + "GCP_DATAFLOW_BQ_FLEX_TEMPLATE_DATASET", "airflow_dataflow_samples" +) +BQ_FLEX_TEMPLATE_LOCATION = os.environ.get( + "GCP_DATAFLOW_BQ_FLEX_TEMPLATE_LOCATION", "us-west1" +) + +with models.DAG( + dag_id="example_gcp_dataflow_flex_template_java", + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs +) as dag_flex_template: + # [START howto_operator_start_template_job] + start_flex_template = DataflowStartFlexTemplateOperator( + task_id="start_flex_template_streaming_beam_sql", + body={ + "launchParameter": { + "containerSpecGcsPath": GCS_FLEX_TEMPLATE_TEMPLATE_PATH, + "jobName": DATAFLOW_FLEX_TEMPLATE_JOB_NAME, + "parameters": { + "inputSubscription": PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION, + "outputTable": f"{GCP_PROJECT_ID}:{BQ_FLEX_TEMPLATE_DATASET}.streaming_beam_sql", + }, + } + }, + do_xcom_push=True, + location=BQ_FLEX_TEMPLATE_LOCATION, + ) + # [END howto_operator_start_template_job] diff --git a/reference/providers/google/cloud/example_dags/example_dataflow_sql.py b/reference/providers/google/cloud/example_dags/example_dataflow_sql.py new file mode 100644 index 0000000..f4f0cfc --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_dataflow_sql.py @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +""" +Example Airflow DAG for Google Cloud Dataflow service +""" +import os + +from airflow import models +from airflow.providers.google.cloud.operators.dataflow import ( + DataflowStartSqlJobOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") + +BQ_SQL_DATASET = os.environ.get( + "GCP_DATAFLOW_BQ_SQL_DATASET", "airflow_dataflow_samples" +) +BQ_SQL_TABLE_INPUT = os.environ.get("GCP_DATAFLOW_BQ_SQL_TABLE_INPUT", "beam_input") +BQ_SQL_TABLE_OUTPUT = os.environ.get("GCP_DATAFLOW_BQ_SQL_TABLE_OUTPUT", "beam_output") +DATAFLOW_SQL_JOB_NAME = os.environ.get("GCP_DATAFLOW_SQL_JOB_NAME", "dataflow-sql") +DATAFLOW_SQL_LOCATION = os.environ.get("GCP_DATAFLOW_SQL_LOCATION", "us-west1") + + +with models.DAG( + dag_id="example_gcp_dataflow_sql", + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag_sql: + # [START howto_operator_start_sql_job] + start_sql = DataflowStartSqlJobOperator( + task_id="start_sql_query", + job_name=DATAFLOW_SQL_JOB_NAME, + query=f""" + SELECT + sales_region as sales_region, + count(state_id) as count_state + FROM + bigquery.table.`{GCP_PROJECT_ID}`.`{BQ_SQL_DATASET}`.`{BQ_SQL_TABLE_INPUT}` + WHERE state_id >= @state_id_min + GROUP BY sales_region; + """, + options={ + "bigquery-project": GCP_PROJECT_ID, + "bigquery-dataset": BQ_SQL_DATASET, + "bigquery-table": BQ_SQL_TABLE_OUTPUT, + "bigquery-write-disposition": "write-truncate", + "parameter": "state_id_min:INT64:2", + }, + location=DATAFLOW_SQL_LOCATION, + do_xcom_push=True, + ) + # [END howto_operator_start_sql_job] diff --git a/reference/providers/google/cloud/example_dags/example_datafusion.py b/reference/providers/google/cloud/example_dags/example_datafusion.py new file mode 100644 index 0000000..93f238a --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_datafusion.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that shows how to use DataFusion. 
+""" +import os + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.datafusion import ( + CloudDataFusionCreateInstanceOperator, + CloudDataFusionCreatePipelineOperator, + CloudDataFusionDeleteInstanceOperator, + CloudDataFusionDeletePipelineOperator, + CloudDataFusionGetInstanceOperator, + CloudDataFusionListPipelinesOperator, + CloudDataFusionRestartInstanceOperator, + CloudDataFusionStartPipelineOperator, + CloudDataFusionStopPipelineOperator, + CloudDataFusionUpdateInstanceOperator, +) +from airflow.utils import dates +from airflow.utils.state import State + +# [START howto_data_fusion_env_variables] +LOCATION = "europe-north1" +INSTANCE_NAME = "airflow-test-instance" +INSTANCE = {"type": "BASIC", "displayName": INSTANCE_NAME} + +BUCKET_1 = os.environ.get("GCP_DATAFUSION_BUCKET_1", "test-datafusion-bucket-1") +BUCKET_2 = os.environ.get("GCP_DATAFUSION_BUCKET_2", "test-datafusion-bucket-2") + +BUCKET_1_URI = f"gs//{BUCKET_1}" +BUCKET_2_URI = f"gs//{BUCKET_2}" + +PIPELINE_NAME = os.environ.get("GCP_DATAFUSION_PIPELINE_NAME", "airflow_test") +PIPELINE = { + "name": "test-pipe", + "description": "Data Pipeline Application", + "artifact": {"name": "cdap-data-pipeline", "version": "6.1.2", "scope": "SYSTEM"}, + "config": { + "resources": {"memoryMB": 2048, "virtualCores": 1}, + "driverResources": {"memoryMB": 2048, "virtualCores": 1}, + "connections": [{"from": "GCS", "to": "GCS2"}], + "comments": [], + "postActions": [], + "properties": {}, + "processTimingEnabled": True, + "stageLoggingEnabled": False, + "stages": [ + { + "name": "GCS", + "plugin": { + "name": "GCSFile", + "type": "batchsource", + "label": "GCS", + "artifact": { + "name": "google-cloud", + "version": "0.14.2", + "scope": "SYSTEM", + }, + "properties": { + "project": "auto-detect", + "format": "text", + "skipHeader": "false", + "serviceFilePath": "auto-detect", + "filenameOnly": "false", + "recursive": "false", + "encrypted": "false", + "schema": '{"type":"record","name":"etlSchemaBody","fields":' + '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', + "path": BUCKET_1_URI, + "referenceName": "foo_bucket", + }, + }, + "outputSchema": [ + { + "name": "etlSchemaBody", + "schema": '{"type":"record","name":"etlSchemaBody","fields":' + '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', + } + ], + }, + { + "name": "GCS2", + "plugin": { + "name": "GCS", + "type": "batchsink", + "label": "GCS2", + "artifact": { + "name": "google-cloud", + "version": "0.14.2", + "scope": "SYSTEM", + }, + "properties": { + "project": "auto-detect", + "suffix": "yyyy-MM-dd-HH-mm", + "format": "json", + "serviceFilePath": "auto-detect", + "location": "us", + "schema": '{"type":"record","name":"etlSchemaBody","fields":' + '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', + "referenceName": "bar", + "path": BUCKET_2_URI, + }, + }, + "outputSchema": [ + { + "name": "etlSchemaBody", + "schema": '{"type":"record","name":"etlSchemaBody","fields":' + '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', + } + ], + "inputSchema": [ + { + "name": "GCS", + "schema": '{"type":"record","name":"etlSchemaBody","fields":' + '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', + } + ], + }, + ], + "schedule": "0 * * * *", + "engine": "spark", + "numOfRecordsPreview": 100, + "maxConcurrentRuns": 1, + }, +} +# [END howto_data_fusion_env_variables] + + +with models.DAG( + "example_data_fusion", + 
schedule_interval=None, # Override to match your needs + start_date=dates.days_ago(1), +) as dag: + # [START howto_cloud_data_fusion_create_instance_operator] + create_instance = CloudDataFusionCreateInstanceOperator( + location=LOCATION, + instance_name=INSTANCE_NAME, + instance=INSTANCE, + task_id="create_instance", + ) + # [END howto_cloud_data_fusion_create_instance_operator] + + # [START howto_cloud_data_fusion_get_instance_operator] + get_instance = CloudDataFusionGetInstanceOperator( + location=LOCATION, instance_name=INSTANCE_NAME, task_id="get_instance" + ) + # [END howto_cloud_data_fusion_get_instance_operator] + + # [START howto_cloud_data_fusion_restart_instance_operator] + restart_instance = CloudDataFusionRestartInstanceOperator( + location=LOCATION, instance_name=INSTANCE_NAME, task_id="restart_instance" + ) + # [END howto_cloud_data_fusion_restart_instance_operator] + + # [START howto_cloud_data_fusion_update_instance_operator] + update_instance = CloudDataFusionUpdateInstanceOperator( + location=LOCATION, + instance_name=INSTANCE_NAME, + instance=INSTANCE, + update_mask="instance.displayName", + task_id="update_instance", + ) + # [END howto_cloud_data_fusion_update_instance_operator] + + # [START howto_cloud_data_fusion_create_pipeline] + create_pipeline = CloudDataFusionCreatePipelineOperator( + location=LOCATION, + pipeline_name=PIPELINE_NAME, + pipeline=PIPELINE, + instance_name=INSTANCE_NAME, + task_id="create_pipeline", + ) + # [END howto_cloud_data_fusion_create_pipeline] + + # [START howto_cloud_data_fusion_list_pipelines] + list_pipelines = CloudDataFusionListPipelinesOperator( + location=LOCATION, instance_name=INSTANCE_NAME, task_id="list_pipelines" + ) + # [END howto_cloud_data_fusion_list_pipelines] + + # [START howto_cloud_data_fusion_start_pipeline] + start_pipeline = CloudDataFusionStartPipelineOperator( + location=LOCATION, + pipeline_name=PIPELINE_NAME, + instance_name=INSTANCE_NAME, + task_id="start_pipeline", + ) + # [END howto_cloud_data_fusion_start_pipeline] + + # [START howto_cloud_data_fusion_stop_pipeline] + stop_pipeline = CloudDataFusionStopPipelineOperator( + location=LOCATION, + pipeline_name=PIPELINE_NAME, + instance_name=INSTANCE_NAME, + task_id="stop_pipeline", + ) + # [END howto_cloud_data_fusion_stop_pipeline] + + # [START howto_cloud_data_fusion_delete_pipeline] + delete_pipeline = CloudDataFusionDeletePipelineOperator( + location=LOCATION, + pipeline_name=PIPELINE_NAME, + instance_name=INSTANCE_NAME, + task_id="delete_pipeline", + ) + # [END howto_cloud_data_fusion_delete_pipeline] + + # [START howto_cloud_data_fusion_delete_instance_operator] + delete_instance = CloudDataFusionDeleteInstanceOperator( + location=LOCATION, instance_name=INSTANCE_NAME, task_id="delete_instance" + ) + # [END howto_cloud_data_fusion_delete_instance_operator] + + # Add sleep before creating pipeline + sleep = BashOperator(task_id="sleep", bash_command="sleep 60") + + create_instance >> get_instance >> restart_instance >> update_instance >> sleep + sleep >> create_pipeline >> list_pipelines >> start_pipeline >> stop_pipeline >> delete_pipeline + delete_pipeline >> delete_instance + +if __name__ == "__main__": + dag.clear(dag_run_state=State.NONE) + dag.run() diff --git a/reference/providers/google/cloud/example_dags/example_dataprep.py b/reference/providers/google/cloud/example_dags/example_dataprep.py new file mode 100644 index 0000000..f81ad11 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_dataprep.py @@ -0,0 +1,79 @@ +# Licensed 
to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that shows how to use Google Dataprep. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.operators.dataprep import ( + DataprepGetJobGroupOperator, + DataprepGetJobsForJobGroupOperator, + DataprepRunJobGroupOperator, +) +from airflow.utils import dates + +DATAPREP_JOB_ID = int(os.environ.get("DATAPREP_JOB_ID", 12345677)) +DATAPREP_JOB_RECIPE_ID = int(os.environ.get("DATAPREP_JOB_RECIPE_ID", 12345677)) +DATAPREP_BUCKET = os.environ.get("DATAPREP_BUCKET", "gs://afl-sql/name@email.com") + +DATA = { + "wrangledDataset": {"id": DATAPREP_JOB_RECIPE_ID}, + "overrides": { + "execution": "dataflow", + "profiler": False, + "writesettings": [ + { + "path": DATAPREP_BUCKET, + "action": "create", + "format": "csv", + "compression": "none", + "header": False, + "asSingleFile": False, + } + ], + }, +} + + +with models.DAG( + "example_dataprep", + schedule_interval=None, + start_date=dates.days_ago(1), # Override to match your needs +) as dag: + # [START how_to_dataprep_run_job_group_operator] + run_job_group = DataprepRunJobGroupOperator( + task_id="run_job_group", body_request=DATA + ) + # [END how_to_dataprep_run_job_group_operator] + + # [START how_to_dataprep_get_jobs_for_job_group_operator] + get_jobs_for_job_group = DataprepGetJobsForJobGroupOperator( + task_id="get_jobs_for_job_group", job_id=DATAPREP_JOB_ID + ) + # [END how_to_dataprep_get_jobs_for_job_group_operator] + + # [START how_to_dataprep_get_job_group_operator] + get_job_group = DataprepGetJobGroupOperator( + task_id="get_job_group", + job_group_id=DATAPREP_JOB_ID, + embed="", + include_deleted=False, + ) + # [END how_to_dataprep_get_job_group_operator] + + run_job_group >> [get_jobs_for_job_group, get_job_group] diff --git a/reference/providers/google/cloud/example_dags/example_dataproc.py b/reference/providers/google/cloud/example_dags/example_dataproc.py new file mode 100644 index 0000000..e809675 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_dataproc.py @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that shows how to use various Dataproc +operators to manage a cluster and submit jobs. +""" + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.dataproc import ( + DataprocCreateClusterOperator, + DataprocCreateWorkflowTemplateOperator, + DataprocDeleteClusterOperator, + DataprocInstantiateWorkflowTemplateOperator, + DataprocSubmitJobOperator, + DataprocUpdateClusterOperator, +) +from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id") +CLUSTER_NAME = os.environ.get("GCP_DATAPROC_CLUSTER_NAME", "example-cluster") +REGION = os.environ.get("GCP_LOCATION", "europe-west1") +ZONE = os.environ.get("GCP_REGION", "europe-west1-b") +BUCKET = os.environ.get("GCP_DATAPROC_BUCKET", "dataproc-system-tests") +OUTPUT_FOLDER = "wordcount" +OUTPUT_PATH = f"gs://{BUCKET}/{OUTPUT_FOLDER}/" +PYSPARK_MAIN = os.environ.get("PYSPARK_MAIN", "hello_world.py") +PYSPARK_URI = f"gs://{BUCKET}/{PYSPARK_MAIN}" +SPARKR_MAIN = os.environ.get("SPARKR_MAIN", "hello_world.R") +SPARKR_URI = f"gs://{BUCKET}/{SPARKR_MAIN}" + +# Cluster definition +# [START how_to_cloud_dataproc_create_cluster] + +CLUSTER_CONFIG = { + "master_config": { + "num_instances": 1, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024}, + }, + "worker_config": { + "num_instances": 2, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024}, + }, +} + +# [END how_to_cloud_dataproc_create_cluster] + +# Update options +# [START how_to_cloud_dataproc_updatemask_cluster_operator] +CLUSTER_UPDATE = { + "config": { + "worker_config": {"num_instances": 3}, + "secondary_worker_config": {"num_instances": 3}, + } +} +UPDATE_MASK = { + "paths": [ + "config.worker_config.num_instances", + "config.secondary_worker_config.num_instances", + ] +} +# [END how_to_cloud_dataproc_updatemask_cluster_operator] + +TIMEOUT = {"seconds": 1 * 24 * 60 * 60} + +# Jobs definitions +# [START how_to_cloud_dataproc_pig_config] +PIG_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "pig_job": {"query_list": {"queries": ["define sin HiveUDF('sin');"]}}, +} +# [END how_to_cloud_dataproc_pig_config] + +# [START how_to_cloud_dataproc_sparksql_config] +SPARK_SQL_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "spark_sql_job": {"query_list": {"queries": ["SHOW DATABASES;"]}}, +} +# [END how_to_cloud_dataproc_sparksql_config] + +# [START how_to_cloud_dataproc_spark_config] +SPARK_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "spark_job": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, +} +# [END how_to_cloud_dataproc_spark_config] + +# [START how_to_cloud_dataproc_pyspark_config] +PYSPARK_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "pyspark_job": {"main_python_file_uri": PYSPARK_URI}, +} +# [END how_to_cloud_dataproc_pyspark_config] + +# [START how_to_cloud_dataproc_sparkr_config] +SPARKR_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, +
"spark_r_job": {"main_r_file_uri": SPARKR_URI}, +} +# [END how_to_cloud_dataproc_sparkr_config] + +# [START how_to_cloud_dataproc_hive_config] +HIVE_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "hive_job": {"query_list": {"queries": ["SHOW DATABASES;"]}}, +} +# [END how_to_cloud_dataproc_hive_config] + +# [START how_to_cloud_dataproc_hadoop_config] +HADOOP_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "hadoop_job": { + "main_jar_file_uri": "file:///usr/lib/hadoop-mapreduce/hadoop-mapreduce-examples.jar", + "args": ["wordcount", "gs://pub/shakespeare/rose.txt", OUTPUT_PATH], + }, +} +# [END how_to_cloud_dataproc_hadoop_config] +WORKFLOW_NAME = "airflow-dataproc-test" +WORKFLOW_TEMPLATE = { + "id": WORKFLOW_NAME, + "placement": { + "managed_cluster": { + "cluster_name": CLUSTER_NAME, + "config": CLUSTER_CONFIG, + } + }, + "jobs": [{"step_id": "pig_job_1", "pig_job": PIG_JOB["pig_job"]}], +} + + +with models.DAG( + "example_gcp_dataproc", start_date=days_ago(1), schedule_interval=None +) as dag: + # [START how_to_cloud_dataproc_create_cluster_operator] + create_cluster = DataprocCreateClusterOperator( + task_id="create_cluster", + project_id=PROJECT_ID, + cluster_config=CLUSTER_CONFIG, + region=REGION, + cluster_name=CLUSTER_NAME, + ) + # [END how_to_cloud_dataproc_create_cluster_operator] + + # [START how_to_cloud_dataproc_update_cluster_operator] + scale_cluster = DataprocUpdateClusterOperator( + task_id="scale_cluster", + cluster_name=CLUSTER_NAME, + cluster=CLUSTER_UPDATE, + update_mask=UPDATE_MASK, + graceful_decommission_timeout=TIMEOUT, + project_id=PROJECT_ID, + location=REGION, + ) + # [END how_to_cloud_dataproc_update_cluster_operator] + + # [START how_to_cloud_dataproc_create_workflow_template] + create_workflow_template = DataprocCreateWorkflowTemplateOperator( + task_id="create_workflow_template", + template=WORKFLOW_TEMPLATE, + project_id=PROJECT_ID, + location=REGION, + ) + # [END how_to_cloud_dataproc_create_workflow_template] + + # [START how_to_cloud_dataproc_trigger_workflow_template] + trigger_workflow = DataprocInstantiateWorkflowTemplateOperator( + task_id="trigger_workflow", + region=REGION, + project_id=PROJECT_ID, + template_id=WORKFLOW_NAME, + ) + # [END how_to_cloud_dataproc_trigger_workflow_template] + + pig_task = DataprocSubmitJobOperator( + task_id="pig_task", job=PIG_JOB, location=REGION, project_id=PROJECT_ID + ) + spark_sql_task = DataprocSubmitJobOperator( + task_id="spark_sql_task", + job=SPARK_SQL_JOB, + location=REGION, + project_id=PROJECT_ID, + ) + + spark_task = DataprocSubmitJobOperator( + task_id="spark_task", job=SPARK_JOB, location=REGION, project_id=PROJECT_ID + ) + + # [START cloud_dataproc_async_submit_sensor] + spark_task_async = DataprocSubmitJobOperator( + task_id="spark_task_async", + job=SPARK_JOB, + location=REGION, + project_id=PROJECT_ID, + asynchronous=True, + ) + + spark_task_async_sensor = DataprocJobSensor( + task_id="spark_task_async_sensor_task", + location=REGION, + project_id=PROJECT_ID, + dataproc_job_id="{{task_instance.xcom_pull(task_ids='spark_task_async')}}", + poke_interval=10, + ) + # [END cloud_dataproc_async_submit_sensor] + + # [START how_to_cloud_dataproc_submit_job_to_cluster_operator] + pyspark_task = DataprocSubmitJobOperator( + task_id="pyspark_task", job=PYSPARK_JOB, location=REGION, project_id=PROJECT_ID + ) + # [END how_to_cloud_dataproc_submit_job_to_cluster_operator] + + sparkr_task = 
DataprocSubmitJobOperator( + task_id="sparkr_task", job=SPARKR_JOB, location=REGION, project_id=PROJECT_ID + ) + + hive_task = DataprocSubmitJobOperator( + task_id="hive_task", job=HIVE_JOB, location=REGION, project_id=PROJECT_ID + ) + + hadoop_task = DataprocSubmitJobOperator( + task_id="hadoop_task", job=HADOOP_JOB, location=REGION, project_id=PROJECT_ID + ) + + # [START how_to_cloud_dataproc_delete_cluster_operator] + delete_cluster = DataprocDeleteClusterOperator( + task_id="delete_cluster", + project_id=PROJECT_ID, + cluster_name=CLUSTER_NAME, + region=REGION, + ) + # [END how_to_cloud_dataproc_delete_cluster_operator] + + create_cluster >> scale_cluster + scale_cluster >> create_workflow_template >> trigger_workflow >> delete_cluster + scale_cluster >> hive_task >> delete_cluster + scale_cluster >> pig_task >> delete_cluster + scale_cluster >> spark_sql_task >> delete_cluster + scale_cluster >> spark_task >> delete_cluster + scale_cluster >> spark_task_async >> spark_task_async_sensor >> delete_cluster + scale_cluster >> pyspark_task >> delete_cluster + scale_cluster >> sparkr_task >> delete_cluster + scale_cluster >> hadoop_task >> delete_cluster diff --git a/reference/providers/google/cloud/example_dags/example_datastore.py b/reference/providers/google/cloud/example_dags/example_datastore.py new file mode 100644 index 0000000..507580f --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_datastore.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that shows how to use Datastore operators. + +This example requires that your project contains a Datastore instance.
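+ +The first DAG exports entities to a Cloud Storage bucket and imports them back; +the second demonstrates ID allocation, transactions, commit, query and rollback operations.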
+""" + +import os +from typing import Any, Dict + +from airflow import models +from airflow.providers.google.cloud.operators.datastore import ( + CloudDatastoreAllocateIdsOperator, + CloudDatastoreBeginTransactionOperator, + CloudDatastoreCommitOperator, + CloudDatastoreExportEntitiesOperator, + CloudDatastoreImportEntitiesOperator, + CloudDatastoreRollbackOperator, + CloudDatastoreRunQueryOperator, +) +from airflow.utils import dates + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +BUCKET = os.environ.get("GCP_DATASTORE_BUCKET", "datastore-system-test") + +with models.DAG( + "example_gcp_datastore", + schedule_interval=None, # Override to match your needs + start_date=dates.days_ago(1), + tags=["example"], +) as dag: + # [START how_to_export_task] + export_task = CloudDatastoreExportEntitiesOperator( + task_id="export_task", + bucket=BUCKET, + project_id=GCP_PROJECT_ID, + overwrite_existing=True, + ) + # [END how_to_export_task] + + # [START how_to_import_task] + import_task = CloudDatastoreImportEntitiesOperator( + task_id="import_task", + bucket="{{ task_instance.xcom_pull('export_task')['response']['outputUrl'].split('/')[2] }}", + file="{{ '/'.join(task_instance.xcom_pull('export_task')['response']['outputUrl'].split('/')[3:]) }}", + project_id=GCP_PROJECT_ID, + ) + # [END how_to_import_task] + + export_task >> import_task + +# [START how_to_keys_def] +KEYS = [ + { + "partitionId": {"projectId": GCP_PROJECT_ID, "namespaceId": ""}, + "path": {"kind": "airflow"}, + } +] +# [END how_to_keys_def] + +# [START how_to_transaction_def] +TRANSACTION_OPTIONS: Dict[str, Any] = {"readWrite": {}} +# [END how_to_transaction_def] + +# [START how_to_commit_def] +COMMIT_BODY = { + "mode": "TRANSACTIONAL", + "mutations": [ + { + "insert": { + "key": KEYS[0], + "properties": {"string": {"stringValue": "airflow is awesome!"}}, + } + } + ], + "transaction": "{{ task_instance.xcom_pull('begin_transaction_commit') }}", +} +# [END how_to_commit_def] + +# [START how_to_query_def] +QUERY = { + "partitionId": {"projectId": GCP_PROJECT_ID, "namespaceId": ""}, + "readOptions": { + "transaction": "{{ task_instance.xcom_pull('begin_transaction_query') }}" + }, + "query": {}, +} +# [END how_to_query_def] + +with models.DAG( + "example_gcp_datastore_operations", + start_date=dates.days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag2: + # [START how_to_allocate_ids] + allocate_ids = CloudDatastoreAllocateIdsOperator( + task_id="allocate_ids", partial_keys=KEYS, project_id=GCP_PROJECT_ID + ) + # [END how_to_allocate_ids] + + # [START how_to_begin_transaction] + begin_transaction_commit = CloudDatastoreBeginTransactionOperator( + task_id="begin_transaction_commit", + transaction_options=TRANSACTION_OPTIONS, + project_id=GCP_PROJECT_ID, + ) + # [END how_to_begin_transaction] + + # [START how_to_commit_task] + commit_task = CloudDatastoreCommitOperator( + task_id="commit_task", body=COMMIT_BODY, project_id=GCP_PROJECT_ID + ) + # [END how_to_commit_task] + + allocate_ids >> begin_transaction_commit >> commit_task + + begin_transaction_query = CloudDatastoreBeginTransactionOperator( + task_id="begin_transaction_query", + transaction_options=TRANSACTION_OPTIONS, + project_id=GCP_PROJECT_ID, + ) + + # [START how_to_run_query] + run_query = CloudDatastoreRunQueryOperator( + task_id="run_query", body=QUERY, project_id=GCP_PROJECT_ID + ) + # [END how_to_run_query] + + allocate_ids >> begin_transaction_query >> run_query + + begin_transaction_to_rollback = 
CloudDatastoreBeginTransactionOperator( + task_id="begin_transaction_to_rollback", + transaction_options=TRANSACTION_OPTIONS, + project_id=GCP_PROJECT_ID, + ) + + # [START how_to_rollback_transaction] + rollback_transaction = CloudDatastoreRollbackOperator( + task_id="rollback_transaction", + transaction="{{ task_instance.xcom_pull('begin_transaction_to_rollback') }}", + ) + begin_transaction_to_rollback >> rollback_transaction + # [END how_to_rollback_transaction] diff --git a/reference/providers/google/cloud/example_dags/example_dlp.py b/reference/providers/google/cloud/example_dags/example_dlp.py new file mode 100644 index 0000000..ac9951d --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_dlp.py @@ -0,0 +1,223 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that executes the following tasks using +the Cloud DLP service in Google Cloud: +1) Creating a content inspect template; +2) Using the created template to inspect content; +3) Deleting the template from Google Cloud.
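+ +Further DAGs in this module demonstrate stored info-type management, job triggers and +content de-identification with the same service.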
+""" + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.dlp import ( + CloudDLPCreateInspectTemplateOperator, + CloudDLPCreateJobTriggerOperator, + CloudDLPCreateStoredInfoTypeOperator, + CloudDLPDeidentifyContentOperator, + CloudDLPDeleteInspectTemplateOperator, + CloudDLPDeleteJobTriggerOperator, + CloudDLPDeleteStoredInfoTypeOperator, + CloudDLPInspectContentOperator, + CloudDLPUpdateJobTriggerOperator, + CloudDLPUpdateStoredInfoTypeOperator, +) +from airflow.utils.dates import days_ago +from google.cloud.dlp_v2.types import ContentItem, InspectConfig, InspectTemplate + +GCP_PROJECT = os.environ.get("GCP_PROJECT_ID", "example-project") +TEMPLATE_ID = "dlp-inspect-838746" +ITEM = ContentItem( + table={ + "headers": [{"name": "column1"}], + "rows": [{"values": [{"string_value": "My phone number is (206) 555-0123"}]}], + } +) +INSPECT_CONFIG = InspectConfig( + info_types=[{"name": "PHONE_NUMBER"}, {"name": "US_TOLLFREE_PHONE_NUMBER"}] +) +INSPECT_TEMPLATE = InspectTemplate(inspect_config=INSPECT_CONFIG) +OUTPUT_BUCKET = os.environ.get("DLP_OUTPUT_BUCKET", "gs://test-dlp-airflow") +OUTPUT_FILENAME = "test.txt" + +OBJECT_GCS_URI = os.path.join(OUTPUT_BUCKET, "tmp") +OBJECT_GCS_OUTPUT_URI = os.path.join(OUTPUT_BUCKET, "tmp", OUTPUT_FILENAME) + +with models.DAG( + "example_gcp_dlp", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag1: + # [START howto_operator_dlp_create_inspect_template] + create_template = CloudDLPCreateInspectTemplateOperator( + project_id=GCP_PROJECT, + inspect_template=INSPECT_TEMPLATE, + template_id=TEMPLATE_ID, + task_id="create_template", + do_xcom_push=True, + ) + # [END howto_operator_dlp_create_inspect_template] + + # [START howto_operator_dlp_use_inspect_template] + inspect_content = CloudDLPInspectContentOperator( + task_id="inspect_content", + project_id=GCP_PROJECT, + item=ITEM, + inspect_template_name="{{ task_instance.xcom_pull('create_template', key='return_value')['name'] }}", + ) + # [END howto_operator_dlp_use_inspect_template] + + # [START howto_operator_dlp_delete_inspect_template] + delete_template = CloudDLPDeleteInspectTemplateOperator( + task_id="delete_template", + template_id=TEMPLATE_ID, + project_id=GCP_PROJECT, + ) + # [END howto_operator_dlp_delete_inspect_template] + + create_template >> inspect_content >> delete_template + +CUSTOM_INFO_TYPE_ID = "custom_info_type" +CUSTOM_INFO_TYPES = { + "large_custom_dictionary": { + "output_path": {"path": OBJECT_GCS_OUTPUT_URI}, + "cloud_storage_file_set": {"url": OBJECT_GCS_URI + "/"}, + } +} +UPDATE_CUSTOM_INFO_TYPE = { + "large_custom_dictionary": { + "output_path": {"path": OBJECT_GCS_OUTPUT_URI}, + "cloud_storage_file_set": {"url": OBJECT_GCS_URI + "/"}, + } +} + +with models.DAG( + "example_gcp_dlp_info_types", + schedule_interval=None, + start_date=days_ago(1), + tags=["example", "dlp", "info-types"], +) as dag2: + # [START howto_operator_dlp_create_info_type] + create_info_type = CloudDLPCreateStoredInfoTypeOperator( + project_id=GCP_PROJECT, + config=CUSTOM_INFO_TYPES, + stored_info_type_id=CUSTOM_INFO_TYPE_ID, + task_id="create_info_type", + ) + # [END howto_operator_dlp_create_info_type] + # [START howto_operator_dlp_update_info_type] + update_info_type = CloudDLPUpdateStoredInfoTypeOperator( + project_id=GCP_PROJECT, + stored_info_type_id=CUSTOM_INFO_TYPE_ID, + config=UPDATE_CUSTOM_INFO_TYPE, + task_id="update_info_type", + ) + # [END howto_operator_dlp_update_info_type] + # [START 
howto_operator_dlp_delete_info_type] + delete_info_type = CloudDLPDeleteStoredInfoTypeOperator( + project_id=GCP_PROJECT, + stored_info_type_id=CUSTOM_INFO_TYPE_ID, + task_id="delete_info_type", + ) + # [END howto_operator_dlp_delete_info_type] + create_info_type >> update_info_type >> delete_info_type + +JOB_TRIGGER = { + "inspect_job": { + "storage_config": { + "datastore_options": { + "partition_id": {"project_id": GCP_PROJECT}, + "kind": {"name": "test"}, + } + } + }, + "triggers": [ + {"schedule": {"recurrence_period_duration": {"seconds": 60 * 60 * 24}}} + ], + "status": "HEALTHY", +} + +TRIGGER_ID = "example_trigger" + +with models.DAG( + "example_gcp_dlp_job", + schedule_interval=None, + start_date=days_ago(1), + tags=["example", "dlp_job"], +) as dag3: # [START howto_operator_dlp_create_job_trigger] + create_trigger = CloudDLPCreateJobTriggerOperator( + project_id=GCP_PROJECT, + job_trigger=JOB_TRIGGER, + trigger_id=TRIGGER_ID, + task_id="create_trigger", + ) + # [END howto_operator_dlp_create_job_trigger] + + JOB_TRIGGER["triggers"] = [ + {"schedule": {"recurrence_period_duration": {"seconds": 2 * 60 * 60 * 24}}} + ] + + # [START howto_operator_dlp_update_job_trigger] + update_trigger = CloudDLPUpdateJobTriggerOperator( + project_id=GCP_PROJECT, + job_trigger_id=TRIGGER_ID, + job_trigger=JOB_TRIGGER, + task_id="update_info_type", + ) + # [END howto_operator_dlp_update_job_trigger] + # [START howto_operator_dlp_delete_job_trigger] + delete_trigger = CloudDLPDeleteJobTriggerOperator( + project_id=GCP_PROJECT, job_trigger_id=TRIGGER_ID, task_id="delete_info_type" + ) + # [END howto_operator_dlp_delete_job_trigger] + create_trigger >> update_trigger >> delete_trigger + +# [START dlp_deidentify_config_example] +DEIDENTIFY_CONFIG = { + "info_type_transformations": { + "transformations": [ + { + "primitive_transformation": { + "replace_config": { + "new_value": {"string_value": "[deidentified_number]"} + } + } + } + ] + } +} +# [END dlp_deidentify_config_example] + +with models.DAG( + "example_gcp_dlp_deidentify_content", + schedule_interval=None, + start_date=days_ago(1), + tags=["example", "dlp", "deidentify"], +) as dag4: + # [START _howto_operator_dlp_deidentify_content] + deidentify_content = CloudDLPDeidentifyContentOperator( + project_id=GCP_PROJECT, + item=ITEM, + deidentify_config=DEIDENTIFY_CONFIG, + inspect_config=INSPECT_CONFIG, + task_id="deidentify_content", + ) + # [END _howto_operator_dlp_deidentify_content] diff --git a/reference/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py b/reference/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py new file mode 100644 index 0000000..e656062 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py @@ -0,0 +1,133 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that shows how to use FacebookAdsReportToGcsOperator. +""" +import os + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryExecuteQueryOperator, +) +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, +) +from airflow.providers.google.cloud.transfers.facebook_ads_to_gcs import ( + FacebookAdsReportToGcsOperator, +) +from airflow.providers.google.cloud.transfers.gcs_to_bigquery import ( + GCSToBigQueryOperator, +) +from airflow.utils.dates import days_ago +from facebook_business.adobjects.adsinsights import AdsInsights + +# [START howto_GCS_env_variables] +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "free-tier-1997") +GCS_BUCKET = os.environ.get("GCS_BUCKET", "airflow_bucket_fb") +GCS_OBJ_PATH = os.environ.get("GCS_OBJ_PATH", "Temp/this_is_my_report_csv.csv") +GCS_CONN_ID = os.environ.get("GCS_CONN_ID", "google_cloud_default") +DATASET_NAME = os.environ.get("DATASET_NAME", "airflow_test_dataset") +TABLE_NAME = os.environ.get("FB_TABLE_NAME", "airflow_test_datatable") +# [END howto_GCS_env_variables] + +# [START howto_FB_ADS_variables] +FIELDS = [ + AdsInsights.Field.campaign_name, + AdsInsights.Field.campaign_id, + AdsInsights.Field.ad_id, + AdsInsights.Field.clicks, + AdsInsights.Field.impressions, +] +PARAMS = {"level": "ad", "date_preset": "yesterday"} +# [END howto_FB_ADS_variables] + +with models.DAG( + "example_facebook_ads_to_gcs", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), +) as dag: + + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=GCS_BUCKET, + project_id=GCP_PROJECT_ID, + ) + + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create_dataset", + dataset_id=DATASET_NAME, + ) + + create_table = BigQueryCreateEmptyTableOperator( + task_id="create_table", + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + schema_fields=[ + {"name": "campaign_name", "type": "STRING", "mode": "NULLABLE"}, + {"name": "campaign_id", "type": "STRING", "mode": "NULLABLE"}, + {"name": "ad_id", "type": "STRING", "mode": "NULLABLE"}, + {"name": "clicks", "type": "STRING", "mode": "NULLABLE"}, + {"name": "impressions", "type": "STRING", "mode": "NULLABLE"}, + ], + ) + + # [START howto_operator_facebook_ads_to_gcs] + run_operator = FacebookAdsReportToGcsOperator( + task_id="run_fetch_data", + start_date=days_ago(2), + owner="airflow", + bucket_name=GCS_BUCKET, + params=PARAMS, + fields=FIELDS, + gcp_conn_id=GCS_CONN_ID, + object_name=GCS_OBJ_PATH, + ) + # [END howto_operator_facebook_ads_to_gcs] + + load_csv = GCSToBigQueryOperator( + task_id="gcs_to_bq_example", + bucket=GCS_BUCKET, + source_objects=[GCS_OBJ_PATH], + destination_project_dataset_table=f"{DATASET_NAME}.{TABLE_NAME}", + write_disposition="WRITE_TRUNCATE", + ) + + read_data_from_gcs_many_chunks = BigQueryExecuteQueryOperator( + task_id="read_data_from_gcs_many_chunks", + sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{TABLE_NAME}`", + use_legacy_sql=False, + ) + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=GCS_BUCKET, + ) + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", + project_id=GCP_PROJECT_ID, + dataset_id=DATASET_NAME, + 
delete_contents=True, + ) + + create_bucket >> create_dataset >> create_table >> run_operator >> load_csv + load_csv >> read_data_from_gcs_many_chunks >> delete_bucket >> delete_dataset diff --git a/reference/providers/google/cloud/example_dags/example_functions.py b/reference/providers/google/cloud/example_dags/example_functions.py new file mode 100644 index 0000000..894f928 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_functions.py @@ -0,0 +1,138 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that displays interactions with Google Cloud Functions. +It creates a function and then deletes it. + +This DAG relies on the following OS environment variables +https://airflow.apache.org/concepts.html#variables + +* GCP_PROJECT_ID - Google Cloud Project to use for the Cloud Function. +* GCP_LOCATION - Google Cloud Functions region where the function should be + created. +* GCF_ENTRYPOINT - Name of the executable function in the source code. +* and one of the below: + + * GCF_SOURCE_ARCHIVE_URL - Path to the zipped source in Google Cloud Storage + + * GCF_SOURCE_UPLOAD_URL - Generated upload URL for the zipped source and GCF_ZIP_PATH - Local path to + the zipped source archive + + * GCF_SOURCE_REPOSITORY - The URL pointing to the hosted repository where the function + is defined in a supported Cloud Source Repository URL format + https://cloud.google.com/functions/docs/reference/rest/v1/projects.locations.functions#SourceRepository + +""" + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.functions import ( + CloudFunctionDeleteFunctionOperator, + CloudFunctionDeployFunctionOperator, + CloudFunctionInvokeFunctionOperator, +) +from airflow.utils import dates + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCP_LOCATION = os.environ.get("GCP_LOCATION", "europe-west1") +GCF_SHORT_FUNCTION_NAME = os.environ.get("GCF_SHORT_FUNCTION_NAME", "hello").replace( + "-", "_" +) # make sure there are no dashes in function name (!) 
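+# Exactly one function source must resolve below: GCF_SOURCE_ARCHIVE_URL,
+# GCF_SOURCE_REPOSITORY, GCF_ZIP_PATH (paired with an empty sourceUploadUrl),
+# or GCF_SOURCE_UPLOAD_URL. If none of them is set, this module raises at
+# import time (see the howto_operator_gcf_deploy_variants block further down).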
+FUNCTION_NAME = "projects/{}/locations/{}/functions/{}".format( + GCP_PROJECT_ID, GCP_LOCATION, GCF_SHORT_FUNCTION_NAME +) +GCF_SOURCE_ARCHIVE_URL = os.environ.get("GCF_SOURCE_ARCHIVE_URL", "") +GCF_SOURCE_UPLOAD_URL = os.environ.get("GCF_SOURCE_UPLOAD_URL", "") +GCF_SOURCE_REPOSITORY = os.environ.get( + "GCF_SOURCE_REPOSITORY", + "https://source.developers.google.com/" + "projects/{}/repos/hello-world/moveable-aliases/master".format(GCP_PROJECT_ID), +) +GCF_ZIP_PATH = os.environ.get("GCF_ZIP_PATH", "") +GCF_ENTRYPOINT = os.environ.get("GCF_ENTRYPOINT", "helloWorld") +GCF_RUNTIME = "nodejs6" +GCP_VALIDATE_BODY = os.environ.get("GCP_VALIDATE_BODY", "True") == "True" + +# [START howto_operator_gcf_deploy_body] +body = { + "name": FUNCTION_NAME, + "entryPoint": GCF_ENTRYPOINT, + "runtime": GCF_RUNTIME, + "httpsTrigger": {}, +} +# [END howto_operator_gcf_deploy_body] + +# [START howto_operator_gcf_default_args] +default_args = {"owner": "airflow"} +# [END howto_operator_gcf_default_args] + +# [START howto_operator_gcf_deploy_variants] +if GCF_SOURCE_ARCHIVE_URL: + body["sourceArchiveUrl"] = GCF_SOURCE_ARCHIVE_URL +elif GCF_SOURCE_REPOSITORY: + body["sourceRepository"] = {"url": GCF_SOURCE_REPOSITORY} +elif GCF_ZIP_PATH: + body["sourceUploadUrl"] = "" + default_args["zip_path"] = GCF_ZIP_PATH +elif GCF_SOURCE_UPLOAD_URL: + body["sourceUploadUrl"] = GCF_SOURCE_UPLOAD_URL +else: + raise Exception("Please provide one of the source_code parameters") +# [END howto_operator_gcf_deploy_variants] + + +with models.DAG( + "example_gcp_function", + schedule_interval=None, # Override to match your needs + start_date=dates.days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_gcf_deploy] + deploy_task = CloudFunctionDeployFunctionOperator( + task_id="gcf_deploy_task", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + body=body, + validate_body=GCP_VALIDATE_BODY, + ) + # [END howto_operator_gcf_deploy] + # [START howto_operator_gcf_deploy_no_project_id] + deploy2_task = CloudFunctionDeployFunctionOperator( + task_id="gcf_deploy2_task", + location=GCP_LOCATION, + body=body, + validate_body=GCP_VALIDATE_BODY, + ) + # [END howto_operator_gcf_deploy_no_project_id] + # [START howto_operator_gcf_invoke_function] + invoke_task = CloudFunctionInvokeFunctionOperator( + task_id="invoke_task", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + input_data={}, + function_id=GCF_SHORT_FUNCTION_NAME, + ) + # [END howto_operator_gcf_invoke_function] + # [START howto_operator_gcf_delete] + delete_task = CloudFunctionDeleteFunctionOperator( + task_id="gcf_delete_task", name=FUNCTION_NAME + ) + # [END howto_operator_gcf_delete] + deploy_task >> deploy2_task >> invoke_task >> delete_task diff --git a/reference/providers/google/cloud/example_dags/example_gcs.py b/reference/providers/google/cloud/example_dags/example_gcs.py new file mode 100644 index 0000000..06faa96 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_gcs.py @@ -0,0 +1,216 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google Cloud Storage operators. +""" + +import os + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.gcs import ( + GCSBucketCreateAclEntryOperator, + GCSCreateBucketOperator, + GCSDeleteBucketOperator, + GCSDeleteObjectsOperator, + GCSFileTransformOperator, + GCSListObjectsOperator, + GCSObjectCreateAclEntryOperator, +) +from airflow.providers.google.cloud.sensors.gcs import ( + GCSObjectExistenceSensor, + GCSObjectsWithPrefixExistenceSensor, +) +from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator +from airflow.providers.google.cloud.transfers.gcs_to_local import ( + GCSToLocalFilesystemOperator, +) +from airflow.providers.google.cloud.transfers.local_to_gcs import ( + LocalFilesystemToGCSOperator, +) +from airflow.utils.dates import days_ago +from airflow.utils.state import State + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-id") +BUCKET_1 = os.environ.get("GCP_GCS_BUCKET_1", "test-gcs-example-bucket") +GCS_ACL_ENTITY = os.environ.get("GCS_ACL_ENTITY", "allUsers") +GCS_ACL_BUCKET_ROLE = "OWNER" +GCS_ACL_OBJECT_ROLE = "OWNER" + +BUCKET_2 = os.environ.get("GCP_GCS_BUCKET_2", "test-gcs-example-bucket-2") + +PATH_TO_TRANSFORM_SCRIPT = os.environ.get("GCP_GCS_PATH_TO_TRANSFORM_SCRIPT", "test.py") +PATH_TO_UPLOAD_FILE = os.environ.get( + "GCP_GCS_PATH_TO_UPLOAD_FILE", "test-gcs-example.txt" +) +PATH_TO_UPLOAD_FILE_PREFIX = os.environ.get( + "GCP_GCS_PATH_TO_UPLOAD_FILE_PREFIX", "test-gcs-" +) +PATH_TO_SAVED_FILE = os.environ.get( + "GCP_GCS_PATH_TO_SAVED_FILE", "test-gcs-example-download.txt" +) + +BUCKET_FILE_LOCATION = PATH_TO_UPLOAD_FILE.rpartition("/")[-1] + +with models.DAG( + "example_gcs", + start_date=days_ago(1), + schedule_interval=None, + tags=["example"], +) as dag: + create_bucket1 = GCSCreateBucketOperator( + task_id="create_bucket1", bucket_name=BUCKET_1, project_id=PROJECT_ID + ) + + create_bucket2 = GCSCreateBucketOperator( + task_id="create_bucket2", bucket_name=BUCKET_2, project_id=PROJECT_ID + ) + + list_buckets = GCSListObjectsOperator(task_id="list_buckets", bucket=BUCKET_1) + + list_buckets_result = BashOperator( + task_id="list_buckets_result", + bash_command="echo \"{{ task_instance.xcom_pull('list_buckets') }}\"", + ) + + upload_file = LocalFilesystemToGCSOperator( + task_id="upload_file", + src=PATH_TO_UPLOAD_FILE, + dst=BUCKET_FILE_LOCATION, + bucket=BUCKET_1, + ) + + transform_file = GCSFileTransformOperator( + task_id="transform_file", + source_bucket=BUCKET_1, + source_object=BUCKET_FILE_LOCATION, + transform_script=["python", PATH_TO_TRANSFORM_SCRIPT], + ) + # [START howto_operator_gcs_bucket_create_acl_entry_task] + gcs_bucket_create_acl_entry_task = GCSBucketCreateAclEntryOperator( + bucket=BUCKET_1, + entity=GCS_ACL_ENTITY, + role=GCS_ACL_BUCKET_ROLE, + task_id="gcs_bucket_create_acl_entry_task", + ) + # [END howto_operator_gcs_bucket_create_acl_entry_task] + + # [START howto_operator_gcs_object_create_acl_entry_task] + gcs_object_create_acl_entry_task = GCSObjectCreateAclEntryOperator( + bucket=BUCKET_1, + 
object_name=BUCKET_FILE_LOCATION, + entity=GCS_ACL_ENTITY, + role=GCS_ACL_OBJECT_ROLE, + task_id="gcs_object_create_acl_entry_task", + ) + # [END howto_operator_gcs_object_create_acl_entry_task] + + # [START howto_operator_gcs_download_file_task] + download_file = GCSToLocalFilesystemOperator( + task_id="download_file", + object_name=BUCKET_FILE_LOCATION, + bucket=BUCKET_1, + filename=PATH_TO_SAVED_FILE, + ) + # [END howto_operator_gcs_download_file_task] + + copy_file = GCSToGCSOperator( + task_id="copy_file", + source_bucket=BUCKET_1, + source_object=BUCKET_FILE_LOCATION, + destination_bucket=BUCKET_2, + destination_object=BUCKET_FILE_LOCATION, + ) + + delete_files = GCSDeleteObjectsOperator( + task_id="delete_files", bucket_name=BUCKET_1, objects=[BUCKET_FILE_LOCATION] + ) + + # [START howto_operator_gcs_delete_bucket] + delete_bucket_1 = GCSDeleteBucketOperator( + task_id="delete_bucket_1", bucket_name=BUCKET_1 + ) + delete_bucket_2 = GCSDeleteBucketOperator( + task_id="delete_bucket_2", bucket_name=BUCKET_2 + ) + # [END howto_operator_gcs_delete_bucket] + + [create_bucket1, create_bucket2] >> list_buckets >> list_buckets_result + [create_bucket1, create_bucket2] >> upload_file + upload_file >> [download_file, copy_file] + upload_file >> gcs_bucket_create_acl_entry_task >> gcs_object_create_acl_entry_task >> delete_files + + create_bucket1 >> delete_bucket_1 + create_bucket2 >> delete_bucket_2 + create_bucket2 >> copy_file + create_bucket1 >> copy_file + list_buckets >> delete_bucket_1 + upload_file >> delete_bucket_1 + create_bucket1 >> upload_file >> delete_bucket_1 + upload_file >> transform_file >> delete_bucket_1 + gcs_bucket_create_acl_entry_task >> delete_bucket_1 + gcs_object_create_acl_entry_task >> delete_bucket_1 + download_file >> delete_bucket_1 + copy_file >> delete_bucket_1 + copy_file >> delete_bucket_2 + delete_files >> delete_bucket_1 + +with models.DAG( + "example_gcs_sensors", + start_date=days_ago(1), + schedule_interval=None, + tags=["example"], +) as dag2: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", bucket_name=BUCKET_1, project_id=PROJECT_ID + ) + upload_file = LocalFilesystemToGCSOperator( + task_id="upload_file", + src=PATH_TO_UPLOAD_FILE, + dst=BUCKET_FILE_LOCATION, + bucket=BUCKET_1, + ) + # [START howto_sensor_object_exists_task] + gcs_object_exists = GCSObjectExistenceSensor( + bucket=BUCKET_1, + object=PATH_TO_UPLOAD_FILE, + mode="poke", + task_id="gcs_object_exists_task", + ) + # [END howto_sensor_object_exists_task] + # [START howto_sensor_object_with_prefix_exists_task] + gcs_object_with_prefix_exists = GCSObjectsWithPrefixExistenceSensor( + bucket=BUCKET_1, + prefix=PATH_TO_UPLOAD_FILE_PREFIX, + mode="poke", + task_id="gcs_object_with_prefix_exists_task", + ) + # [END howto_sensor_object_with_prefix_exists_task] + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", bucket_name=BUCKET_1 + ) + + create_bucket >> upload_file >> [ + gcs_object_exists, + gcs_object_with_prefix_exists, + ] >> delete_bucket + + +if __name__ == "__main__": + dag.clear(dag_run_state=State.NONE) + dag.run() diff --git a/reference/providers/google/cloud/example_dags/example_gcs_timespan_file_transform.py b/reference/providers/google/cloud/example_dags/example_gcs_timespan_file_transform.py new file mode 100644 index 0000000..5663395 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_gcs_timespan_file_transform.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google Cloud Storage time-span file transform operator. +""" + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.gcs import ( + GCSTimeSpanFileTransformOperator, +) +from airflow.utils.dates import days_ago +from airflow.utils.state import State + +SOURCE_BUCKET = os.environ.get("GCP_GCS_BUCKET_1", "test-gcs-example-bucket") +SOURCE_PREFIX = "gcs_timespan_file_transform_source" +SOURCE_GCP_CONN_ID = "google_cloud_default" +DESTINATION_BUCKET = SOURCE_BUCKET +DESTINATION_PREFIX = "gcs_timespan_file_transform_destination" +DESTINATION_GCP_CONN_ID = "google_cloud_default" + +PATH_TO_TRANSFORM_SCRIPT = os.environ.get( + "GCP_GCS_PATH_TO_TRANSFORM_SCRIPT", "test_gcs_timespan_transform_script.py" +) + + +with models.DAG( + "example_gcs_timespan_file_transform", + start_date=days_ago(1), + schedule_interval=None, + tags=["example"], +) as dag: + + # [START howto_operator_gcs_timespan_file_transform_operator_Task] + gcs_timespan_transform_files_task = GCSTimeSpanFileTransformOperator( + task_id="gcs_timespan_transform_files", + source_bucket=SOURCE_BUCKET, + source_prefix=SOURCE_PREFIX, + source_gcp_conn_id=SOURCE_GCP_CONN_ID, + destination_bucket=DESTINATION_BUCKET, + destination_prefix=DESTINATION_PREFIX, + destination_gcp_conn_id=DESTINATION_GCP_CONN_ID, + transform_script=["python", PATH_TO_TRANSFORM_SCRIPT], + ) + # [END howto_operator_gcs_timespan_file_transform_operator_Task] + + +if __name__ == "__main__": + dag.clear(dag_run_state=State.NONE) + dag.run() diff --git a/reference/providers/google/cloud/example_dags/example_gcs_to_bigquery.py b/reference/providers/google/cloud/example_dags/example_gcs_to_bigquery.py new file mode 100644 index 0000000..6a567fa --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_gcs_to_bigquery.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example DAG using GCSToBigQueryOperator. 
+""" + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryDeleteDatasetOperator, +) +from airflow.providers.google.cloud.transfers.gcs_to_bigquery import ( + GCSToBigQueryOperator, +) +from airflow.utils.dates import days_ago + +DATASET_NAME = os.environ.get("GCP_DATASET_NAME", "airflow_test") +TABLE_NAME = os.environ.get("GCP_TABLE_NAME", "gcs_to_bq_table") + +dag = models.DAG( + dag_id="example_gcs_to_bigquery_operator", + start_date=days_ago(2), + schedule_interval=None, + tags=["example"], +) + +create_test_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create_airflow_test_dataset", dataset_id=DATASET_NAME, dag=dag +) + +# [START howto_operator_gcs_to_bigquery] +load_csv = GCSToBigQueryOperator( + task_id="gcs_to_bigquery_example", + bucket="cloud-samples-data", + source_objects=["bigquery/us-states/us-states.csv"], + destination_project_dataset_table=f"{DATASET_NAME}.{TABLE_NAME}", + schema_fields=[ + {"name": "name", "type": "STRING", "mode": "NULLABLE"}, + {"name": "post_abbr", "type": "STRING", "mode": "NULLABLE"}, + ], + write_disposition="WRITE_TRUNCATE", + dag=dag, +) +# [END howto_operator_gcs_to_bigquery] + +delete_test_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_airflow_test_dataset", + dataset_id=DATASET_NAME, + delete_contents=True, + dag=dag, +) + +create_test_dataset >> load_csv >> delete_test_dataset diff --git a/reference/providers/google/cloud/example_dags/example_gcs_to_gcs.py b/reference/providers/google/cloud/example_dags/example_gcs_to_gcs.py new file mode 100644 index 0000000..e94bbed --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_gcs_to_gcs.py @@ -0,0 +1,157 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google Cloud Storage to Google Cloud Storage transfer operators. 
+""" + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.gcs import GCSSynchronizeBucketsOperator +from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator +from airflow.utils.dates import days_ago + +BUCKET_1_SRC = os.environ.get("GCP_GCS_BUCKET_1_SRC", "test-gcs-sync-1-src") +BUCKET_1_DST = os.environ.get("GCP_GCS_BUCKET_1_DST", "test-gcs-sync-1-dst") + +BUCKET_2_SRC = os.environ.get("GCP_GCS_BUCKET_2_SRC", "test-gcs-sync-2-src") +BUCKET_2_DST = os.environ.get("GCP_GCS_BUCKET_2_DST", "test-gcs-sync-2-dst") + +BUCKET_3_SRC = os.environ.get("GCP_GCS_BUCKET_3_SRC", "test-gcs-sync-3-src") +BUCKET_3_DST = os.environ.get("GCP_GCS_BUCKET_3_DST", "test-gcs-sync-3-dst") + +OBJECT_1 = os.environ.get("GCP_GCS_OBJECT_1", "test-gcs-to-gcs-1") +OBJECT_2 = os.environ.get("GCP_GCS_OBJECT_2", "test-gcs-to-gcs-2") + +with models.DAG( + "example_gcs_to_gcs", + start_date=days_ago(1), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_synch_bucket] + sync_bucket = GCSSynchronizeBucketsOperator( + task_id="sync_bucket", + source_bucket=BUCKET_1_SRC, + destination_bucket=BUCKET_1_DST, + ) + # [END howto_synch_bucket] + + # [START howto_synch_full_bucket] + sync_full_bucket = GCSSynchronizeBucketsOperator( + task_id="sync_full_bucket", + source_bucket=BUCKET_1_SRC, + destination_bucket=BUCKET_1_DST, + delete_extra_files=True, + allow_overwrite=True, + ) + # [END howto_synch_full_bucket] + + # [START howto_synch_to_subdir] + sync_to_subdirectory = GCSSynchronizeBucketsOperator( + task_id="sync_to_subdirectory", + source_bucket=BUCKET_1_SRC, + destination_bucket=BUCKET_1_DST, + destination_object="subdir/", + ) + # [END howto_synch_to_subdir] + + # [START howto_sync_from_subdir] + sync_from_subdirectory = GCSSynchronizeBucketsOperator( + task_id="sync_from_subdirectory", + source_bucket=BUCKET_1_SRC, + source_object="subdir/", + destination_bucket=BUCKET_1_DST, + ) + # [END howto_sync_from_subdir] + + # [START howto_operator_gcs_to_gcs_single_file] + copy_single_file = GCSToGCSOperator( + task_id="copy_single_gcs_file", + source_bucket=BUCKET_1_SRC, + source_object=OBJECT_1, + destination_bucket=BUCKET_1_DST, # If not supplied the source_bucket value will be used + destination_object="backup_" + + OBJECT_1, # If not supplied the source_object value will be used + ) + # [END howto_operator_gcs_to_gcs_single_file] + + # [START howto_operator_gcs_to_gcs_wildcard] + copy_files_with_wildcard = GCSToGCSOperator( + task_id="copy_files_with_wildcard", + source_bucket=BUCKET_1_SRC, + source_object="data/*.txt", + destination_bucket=BUCKET_1_DST, + destination_object="backup/", + ) + # [END howto_operator_gcs_to_gcs_wildcard] + + # [START howto_operator_gcs_to_gcs_without_wildcard] + copy_files_without_wildcard = GCSToGCSOperator( + task_id="copy_files_without_wildcard", + source_bucket=BUCKET_1_SRC, + source_object="subdir/", + destination_bucket=BUCKET_1_DST, + destination_object="backup/", + ) + # [END howto_operator_gcs_to_gcs_without_wildcard] + + # [START howto_operator_gcs_to_gcs_delimiter] + copy_files_with_delimiter = GCSToGCSOperator( + task_id="copy_files_with_delimiter", + source_bucket=BUCKET_1_SRC, + source_object="data/", + destination_bucket=BUCKET_1_DST, + destination_object="backup/", + delimiter=".txt", + ) + # [END howto_operator_gcs_to_gcs_delimiter] + + # [START howto_operator_gcs_to_gcs_list] + copy_files_with_list = GCSToGCSOperator( + task_id="copy_files_with_list", + source_bucket=BUCKET_1_SRC, + source_objects=[ 
+ OBJECT_1, + OBJECT_2, + ], # Instead of files each element could be a wildcard expression + destination_bucket=BUCKET_1_DST, + destination_object="backup/", + ) + # [END howto_operator_gcs_to_gcs_list] + + # [START howto_operator_gcs_to_gcs_single_file_move] + move_single_file = GCSToGCSOperator( + task_id="move_single_file", + source_bucket=BUCKET_1_SRC, + source_object=OBJECT_1, + destination_bucket=BUCKET_1_DST, + destination_object="backup_" + OBJECT_1, + move_object=True, + ) + # [END howto_operator_gcs_to_gcs_single_file_move] + + # [START howto_operator_gcs_to_gcs_list_move] + move_files_with_list = GCSToGCSOperator( + task_id="move_files_with_list", + source_bucket=BUCKET_1_SRC, + source_objects=[OBJECT_1, OBJECT_2], + destination_bucket=BUCKET_1_DST, + destination_object="backup/", + ) + # [END howto_operator_gcs_to_gcs_list_move] diff --git a/reference/providers/google/cloud/example_dags/example_gcs_to_local.py b/reference/providers/google/cloud/example_dags/example_gcs_to_local.py new file mode 100644 index 0000000..cc558e0 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_gcs_to_local.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow import models +from airflow.providers.google.cloud.transfers.gcs_to_local import ( + GCSToLocalFilesystemOperator, +) +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-id") +BUCKET = os.environ.get("GCP_GCS_BUCKET", "test-gcs-example-bucket") + +PATH_TO_REMOTE_FILE = os.environ.get( + "GCP_GCS_PATH_TO_UPLOAD_FILE", "test-gcs-example-remote.txt" +) +PATH_TO_LOCAL_FILE = os.environ.get( + "GCP_GCS_PATH_TO_SAVED_FILE", "test-gcs-example-local.txt" +) + +with models.DAG( + "example_gcs_to_local", + start_date=days_ago(1), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_operator_gcs_download_file_task] + download_file = GCSToLocalFilesystemOperator( + task_id="download_file", + object_name=PATH_TO_REMOTE_FILE, + bucket=BUCKET, + filename=PATH_TO_LOCAL_FILE, + ) + # [END howto_operator_gcs_download_file_task] diff --git a/reference/providers/google/cloud/example_dags/example_gcs_to_sftp.py b/reference/providers/google/cloud/example_dags/example_gcs_to_sftp.py new file mode 100644 index 0000000..e353214 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_gcs_to_sftp.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google Cloud Storage to SFTP transfer operators. +""" + +import os + +from airflow import models +from airflow.providers.google.cloud.transfers.gcs_to_sftp import GCSToSFTPOperator +from airflow.providers.sftp.sensors.sftp import SFTPSensor +from airflow.utils.dates import days_ago + +SFTP_CONN_ID = "ssh_default" +BUCKET_SRC = os.environ.get("GCP_GCS_BUCKET_1_SRC", "test-gcs-sftp") +OBJECT_SRC_1 = "parent-1.bin" +OBJECT_SRC_2 = "dir-1/parent-2.bin" +OBJECT_SRC_3 = "dir-2/*" +DESTINATION_PATH_1 = "/tmp/single-file/" +DESTINATION_PATH_2 = "/tmp/dest-dir-1/" +DESTINATION_PATH_3 = "/tmp/dest-dir-2/" + + +with models.DAG( + "example_gcs_to_sftp", + start_date=days_ago(1), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_operator_gcs_to_sftp_copy_single_file] + copy_file_from_gcs_to_sftp = GCSToSFTPOperator( + task_id="file-copy-gsc-to-sftp", + sftp_conn_id=SFTP_CONN_ID, + source_bucket=BUCKET_SRC, + source_object=OBJECT_SRC_1, + destination_path=DESTINATION_PATH_1, + ) + # [END howto_operator_gcs_to_sftp_copy_single_file] + + check_copy_file_from_gcs_to_sftp = SFTPSensor( + task_id="check-file-copy-gsc-to-sftp", + sftp_conn_id=SFTP_CONN_ID, + timeout=60, + path=os.path.join(DESTINATION_PATH_1, OBJECT_SRC_1), + ) + + # [START howto_operator_gcs_to_sftp_move_single_file_destination] + move_file_from_gcs_to_sftp = GCSToSFTPOperator( + task_id="file-move-gsc-to-sftp", + sftp_conn_id=SFTP_CONN_ID, + source_bucket=BUCKET_SRC, + source_object=OBJECT_SRC_2, + destination_path=DESTINATION_PATH_1, + move_object=True, + ) + # [END howto_operator_gcs_to_sftp_move_single_file_destination] + + check_move_file_from_gcs_to_sftp = SFTPSensor( + task_id="check-file-move-gsc-to-sftp", + sftp_conn_id=SFTP_CONN_ID, + timeout=60, + path=os.path.join(DESTINATION_PATH_1, OBJECT_SRC_2), + ) + + # [START howto_operator_gcs_to_sftp_copy_directory] + copy_dir_from_gcs_to_sftp = GCSToSFTPOperator( + task_id="dir-copy-gsc-to-sftp", + sftp_conn_id=SFTP_CONN_ID, + source_bucket=BUCKET_SRC, + source_object=OBJECT_SRC_3, + destination_path=DESTINATION_PATH_2, + ) + # [END howto_operator_gcs_to_sftp_copy_directory] + + check_copy_dir_from_gcs_to_sftp = SFTPSensor( + task_id="check-dir-copy-gsc-to-sftp", + sftp_conn_id=SFTP_CONN_ID, + timeout=60, + path=os.path.join(DESTINATION_PATH_2, "dir-2", OBJECT_SRC_1), + ) + + # [START howto_operator_gcs_to_sftp_move_specific_files] + move_dir_from_gcs_to_sftp = GCSToSFTPOperator( + task_id="dir-move-gsc-to-sftp", + sftp_conn_id=SFTP_CONN_ID, + source_bucket=BUCKET_SRC, + source_object=OBJECT_SRC_3, + destination_path=DESTINATION_PATH_3, + keep_directory_structure=False, + ) + # [END howto_operator_gcs_to_sftp_move_specific_files] + + check_move_dir_from_gcs_to_sftp = SFTPSensor( + task_id="check-dir-move-gsc-to-sftp", + sftp_conn_id=SFTP_CONN_ID, + timeout=60, + path=os.path.join(DESTINATION_PATH_3, OBJECT_SRC_1), + ) + + move_file_from_gcs_to_sftp >> 
check_move_file_from_gcs_to_sftp + copy_dir_from_gcs_to_sftp >> check_copy_file_from_gcs_to_sftp + + copy_dir_from_gcs_to_sftp >> move_dir_from_gcs_to_sftp + copy_dir_from_gcs_to_sftp >> check_copy_dir_from_gcs_to_sftp + move_dir_from_gcs_to_sftp >> check_move_dir_from_gcs_to_sftp diff --git a/reference/providers/google/cloud/example_dags/example_gdrive_to_gcs.py b/reference/providers/google/cloud/example_dags/example_gdrive_to_gcs.py new file mode 100644 index 0000000..f970318 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_gdrive_to_gcs.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow import models +from airflow.providers.google.cloud.transfers.gdrive_to_gcs import ( + GoogleDriveToGCSOperator, +) +from airflow.providers.google.suite.sensors.drive import GoogleDriveFileExistenceSensor +from airflow.utils.dates import days_ago + +BUCKET = os.environ.get("GCP_GCS_BUCKET", "test28397yeo") +OBJECT = os.environ.get("GCP_GCS_OBJECT", "abc123xyz") +FOLDER_ID = os.environ.get("FILE_ID", "1234567890qwerty") +FILE_NAME = os.environ.get("FILE_NAME", "file.pdf") + +with models.DAG( + "example_gdrive_to_gcs_with_gdrive_sensor", + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag: + # [START detect_file] + detect_file = GoogleDriveFileExistenceSensor( + task_id="detect_file", folder_id=FOLDER_ID, file_name=FILE_NAME + ) + # [END detect_file] + # [START upload_gdrive_to_gcs] + upload_gdrive_to_gcs = GoogleDriveToGCSOperator( + task_id="upload_gdrive_object_to_gcs", + folder_id=FOLDER_ID, + file_name=FILE_NAME, + destination_bucket=BUCKET, + destination_object=OBJECT, + ) + # [END upload_gdrive_to_gcs] + detect_file >> upload_gdrive_to_gcs diff --git a/reference/providers/google/cloud/example_dags/example_gdrive_to_local.py b/reference/providers/google/cloud/example_dags/example_gdrive_to_local.py new file mode 100644 index 0000000..f27e53d --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_gdrive_to_local.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow import models +from airflow.providers.google.cloud.transfers.gdrive_to_local import ( + GoogleDriveToLocalOperator, +) +from airflow.providers.google.suite.sensors.drive import GoogleDriveFileExistenceSensor +from airflow.utils.dates import days_ago + +FOLDER_ID = os.environ.get("FILE_ID", "1234567890qwerty") +FILE_NAME = os.environ.get("FILE_NAME", "file.pdf") +OUTPUT_FILE = os.environ.get("OUTPUT_FILE", "out_file.pdf") + +with models.DAG( + "example_gdrive_to_local_with_gdrive_sensor", + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag: + # [START detect_file] + detect_file = GoogleDriveFileExistenceSensor( + task_id="detect_file", folder_id=FOLDER_ID, file_name=FILE_NAME + ) + # [END detect_file] + # [START download_from_gdrive_to_local] + download_from_gdrive_to_local = GoogleDriveToLocalOperator( + task_id="download_from_gdrive_to_local", + folder_id=FOLDER_ID, + file_name=FILE_NAME, + output_file=OUTPUT_FILE, + ) + # [END download_from_gdrive_to_local] + detect_file >> download_from_gdrive_to_local diff --git a/reference/providers/google/cloud/example_dags/example_kubernetes_engine.py b/reference/providers/google/cloud/example_dags/example_kubernetes_engine.py new file mode 100644 index 0000000..4c873fb --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_kubernetes_engine.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google Kubernetes Engine. 
+""" + +import os + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.kubernetes_engine import ( + GKECreateClusterOperator, + GKEDeleteClusterOperator, + GKEStartPodOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCP_LOCATION = os.environ.get("GCP_GKE_LOCATION", "europe-north1-a") +CLUSTER_NAME = os.environ.get("GCP_GKE_CLUSTER_NAME", "cluster-name") + +# [START howto_operator_gcp_gke_create_cluster_definition] +CLUSTER = {"name": CLUSTER_NAME, "initial_node_count": 1} +# [END howto_operator_gcp_gke_create_cluster_definition] + +with models.DAG( + "example_gcp_gke", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_gke_create_cluster] + create_cluster = GKECreateClusterOperator( + task_id="create_cluster", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + body=CLUSTER, + ) + # [END howto_operator_gke_create_cluster] + + pod_task = GKEStartPodOperator( + task_id="pod_task", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + namespace="default", + image="perl", + name="test-pod", + ) + + # [START howto_operator_gke_start_pod_xcom] + pod_task_xcom = GKEStartPodOperator( + task_id="pod_task_xcom", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + do_xcom_push=True, + namespace="default", + image="alpine", + cmds=[ + "sh", + "-c", + "mkdir -p /airflow/xcom/;echo '[1,2,3,4]' > /airflow/xcom/return.json", + ], + name="test-pod-xcom", + ) + # [END howto_operator_gke_start_pod_xcom] + + # [START howto_operator_gke_xcom_result] + pod_task_xcom_result = BashOperator( + bash_command="echo \"{{ task_instance.xcom_pull('pod_task_xcom')[0] }}\"", + task_id="pod_task_xcom_result", + ) + # [END howto_operator_gke_xcom_result] + + # [START howto_operator_gke_delete_cluster] + delete_cluster = GKEDeleteClusterOperator( + task_id="delete_cluster", + name=CLUSTER_NAME, + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + ) + # [END howto_operator_gke_delete_cluster] + + create_cluster >> pod_task >> delete_cluster + create_cluster >> pod_task_xcom >> delete_cluster + pod_task_xcom >> pod_task_xcom_result diff --git a/reference/providers/google/cloud/example_dags/example_life_sciences.py b/reference/providers/google/cloud/example_dags/example_life_sciences.py new file mode 100644 index 0000000..3f76d75 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_life_sciences.py @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os + +from airflow import models +from airflow.providers.google.cloud.operators.life_sciences import ( + LifeSciencesRunPipelineOperator, +) +from airflow.utils import dates + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project-id") +BUCKET = os.environ.get("GCP_GCS_LIFE_SCIENCES_BUCKET", "example-life-sciences-bucket") +FILENAME = os.environ.get("GCP_GCS_LIFE_SCIENCES_FILENAME", "input.in") +LOCATION = os.environ.get("GCP_LIFE_SCIENCES_LOCATION", "us-central1") + + +# [START howto_configure_simple_action_pipeline] +SIMPLE_ACTION_PIPELINE = { + "pipeline": { + "actions": [ + {"imageUri": "bash", "commands": ["-c", "echo Hello, world"]}, + ], + "resources": { + "regions": [f"{LOCATION}"], + "virtualMachine": { + "machineType": "n1-standard-1", + }, + }, + }, +} +# [END howto_configure_simple_action_pipeline] + +# [START howto_configure_multiple_action_pipeline] +MULTI_ACTION_PIPELINE = { + "pipeline": { + "actions": [ + { + "imageUri": "google/cloud-sdk", + "commands": ["gsutil", "cp", f"gs://{BUCKET}/{FILENAME}", "/tmp"], + }, + {"imageUri": "bash", "commands": ["-c", "echo Hello, world"]}, + { + "imageUri": "google/cloud-sdk", + "commands": [ + "gsutil", + "cp", + f"gs://{BUCKET}/{FILENAME}", + f"gs://{BUCKET}/output.in", + ], + }, + ], + "resources": { + "regions": [f"{LOCATION}"], + "virtualMachine": { + "machineType": "n1-standard-1", + }, + }, + } +} +# [END howto_configure_multiple_action_pipeline] + +with models.DAG( + "example_gcp_life_sciences", + default_args=dict(start_date=dates.days_ago(1)), + schedule_interval=None, + tags=["example"], +) as dag: + + # [START howto_run_pipeline] + simple_life_science_action_pipeline = LifeSciencesRunPipelineOperator( + task_id="simple-action-pipeline", + body=SIMPLE_ACTION_PIPELINE, + project_id=PROJECT_ID, + location=LOCATION, + ) + # [END howto_run_pipeline] + + multiple_life_science_action_pipeline = LifeSciencesRunPipelineOperator( + task_id="multi-action-pipeline", + body=MULTI_ACTION_PIPELINE, + project_id=PROJECT_ID, + location=LOCATION, + ) + + simple_life_science_action_pipeline >> multiple_life_science_action_pipeline diff --git a/reference/providers/google/cloud/example_dags/example_local_to_gcs.py b/reference/providers/google/cloud/example_dags/example_local_to_gcs.py new file mode 100644 index 0000000..93b953b --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_local_to_gcs.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
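+"""
+Example Airflow DAG using LocalFilesystemToGCSOperator to upload a local file
+to a Google Cloud Storage bucket.
+"""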
+ +import os + +from airflow import models +from airflow.providers.google.cloud.transfers.local_to_gcs import ( + LocalFilesystemToGCSOperator, +) +from airflow.utils import dates + +# [START howto_gcs_environment_variables] +BUCKET_NAME = os.environ.get("GCP_GCS_BUCKET", "example-bucket-name") +PATH_TO_UPLOAD_FILE = os.environ.get("GCP_GCS_PATH_TO_UPLOAD_FILE", "example-text.txt") +DESTINATION_FILE_LOCATION = os.environ.get( + "GCP_GCS_DESTINATION_FILE_LOCATION", "example-text.txt" +) +# [END howto_gcs_environment_variables] + +with models.DAG( + "example_local_to_gcs", + default_args=dict(start_date=dates.days_ago(1)), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_operator_local_filesystem_to_gcs] + upload_file = LocalFilesystemToGCSOperator( + task_id="upload_file", + src=PATH_TO_UPLOAD_FILE, + dst=DESTINATION_FILE_LOCATION, + bucket=BUCKET_NAME, + ) + # [END howto_operator_local_filesystem_to_gcs] diff --git a/reference/providers/google/cloud/example_dags/example_mlengine.py b/reference/providers/google/cloud/example_dags/example_mlengine.py new file mode 100644 index 0000000..a86ee31 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_mlengine.py @@ -0,0 +1,277 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google ML Engine service. 
+""" +import os +from typing import Dict + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.mlengine import ( + MLEngineCreateModelOperator, + MLEngineCreateVersionOperator, + MLEngineDeleteModelOperator, + MLEngineDeleteVersionOperator, + MLEngineGetModelOperator, + MLEngineListVersionsOperator, + MLEngineSetDefaultVersionOperator, + MLEngineStartBatchPredictionJobOperator, + MLEngineStartTrainingJobOperator, +) +from airflow.providers.google.cloud.utils import mlengine_operator_utils +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") + +MODEL_NAME = os.environ.get("GCP_MLENGINE_MODEL_NAME", "model_name") + +SAVED_MODEL_PATH = os.environ.get( + "GCP_MLENGINE_SAVED_MODEL_PATH", "gs://test-airflow-mlengine/saved-model/" +) +JOB_DIR = os.environ.get( + "GCP_MLENGINE_JOB_DIR", "gs://test-airflow-mlengine/keras-job-dir" +) +PREDICTION_INPUT = os.environ.get( + "GCP_MLENGINE_PREDICTION_INPUT", "gs://test-airflow-mlengine/prediction_input.json" +) +PREDICTION_OUTPUT = os.environ.get( + "GCP_MLENGINE_PREDICTION_OUTPUT", "gs://test-airflow-mlengine/prediction_output" +) +TRAINER_URI = os.environ.get( + "GCP_MLENGINE_TRAINER_URI", "gs://test-airflow-mlengine/trainer.tar.gz" +) +TRAINER_PY_MODULE = os.environ.get( + "GCP_MLENGINE_TRAINER_TRAINER_PY_MODULE", "trainer.task" +) + +SUMMARY_TMP = os.environ.get( + "GCP_MLENGINE_DATAFLOW_TMP", "gs://test-airflow-mlengine/tmp/" +) +SUMMARY_STAGING = os.environ.get( + "GCP_MLENGINE_DATAFLOW_STAGING", "gs://test-airflow-mlengine/staging/" +) + +default_args = {"params": {"model_name": MODEL_NAME}} + +with models.DAG( + "example_gcp_mlengine", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_gcp_mlengine_training] + training = MLEngineStartTrainingJobOperator( + task_id="training", + project_id=PROJECT_ID, + region="us-central1", + job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}", + runtime_version="1.15", + python_version="3.7", + job_dir=JOB_DIR, + package_uris=[TRAINER_URI], + training_python_module=TRAINER_PY_MODULE, + training_args=[], + labels={"job_type": "training"}, + ) + # [END howto_operator_gcp_mlengine_training] + + # [START howto_operator_gcp_mlengine_create_model] + create_model = MLEngineCreateModelOperator( + task_id="create-model", + project_id=PROJECT_ID, + model={ + "name": MODEL_NAME, + }, + ) + # [END howto_operator_gcp_mlengine_create_model] + + # [START howto_operator_gcp_mlengine_get_model] + get_model = MLEngineGetModelOperator( + task_id="get-model", + project_id=PROJECT_ID, + model_name=MODEL_NAME, + ) + # [END howto_operator_gcp_mlengine_get_model] + + # [START howto_operator_gcp_mlengine_print_model] + get_model_result = BashOperator( + bash_command="echo \"{{ task_instance.xcom_pull('get-model') }}\"", + task_id="get-model-result", + ) + # [END howto_operator_gcp_mlengine_print_model] + + # [START howto_operator_gcp_mlengine_create_version1] + create_version = MLEngineCreateVersionOperator( + task_id="create-version", + project_id=PROJECT_ID, + model_name=MODEL_NAME, + version={ + "name": "v1", + "description": "First-version", + "deployment_uri": f"{JOB_DIR}/keras_export/", + "runtime_version": "1.15", + "machineType": "mls1-c1-m2", + "framework": "TENSORFLOW", + "pythonVersion": "3.7", + }, + ) + # [END howto_operator_gcp_mlengine_create_version1] + + # [START 
howto_operator_gcp_mlengine_create_version2] + create_version_2 = MLEngineCreateVersionOperator( + task_id="create-version-2", + project_id=PROJECT_ID, + model_name=MODEL_NAME, + version={ + "name": "v2", + "description": "Second version", + "deployment_uri": SAVED_MODEL_PATH, + "runtime_version": "1.15", + "machineType": "mls1-c1-m2", + "framework": "TENSORFLOW", + "pythonVersion": "3.7", + }, + ) + # [END howto_operator_gcp_mlengine_create_version2] + + # [START howto_operator_gcp_mlengine_default_version] + set_defaults_version = MLEngineSetDefaultVersionOperator( + task_id="set-default-version", + project_id=PROJECT_ID, + model_name=MODEL_NAME, + version_name="v2", + ) + # [END howto_operator_gcp_mlengine_default_version] + + # [START howto_operator_gcp_mlengine_list_versions] + list_version = MLEngineListVersionsOperator( + task_id="list-version", + project_id=PROJECT_ID, + model_name=MODEL_NAME, + ) + # [END howto_operator_gcp_mlengine_list_versions] + + # [START howto_operator_gcp_mlengine_print_versions] + list_version_result = BashOperator( + bash_command="echo \"{{ task_instance.xcom_pull('list-version') }}\"", + task_id="list-version-result", + ) + # [END howto_operator_gcp_mlengine_print_versions] + + # [START howto_operator_gcp_mlengine_get_prediction] + prediction = MLEngineStartBatchPredictionJobOperator( + task_id="prediction", + project_id=PROJECT_ID, + job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}", + region="us-central1", + model_name=MODEL_NAME, + data_format="TEXT", + input_paths=[PREDICTION_INPUT], + output_path=PREDICTION_OUTPUT, + labels={"job_type": "prediction"}, + ) + # [END howto_operator_gcp_mlengine_get_prediction] + + # [START howto_operator_gcp_mlengine_delete_version] + delete_version = MLEngineDeleteVersionOperator( + task_id="delete-version", + project_id=PROJECT_ID, + model_name=MODEL_NAME, + version_name="v1", + ) + # [END howto_operator_gcp_mlengine_delete_version] + + # [START howto_operator_gcp_mlengine_delete_model] + delete_model = MLEngineDeleteModelOperator( + task_id="delete-model", + project_id=PROJECT_ID, + model_name=MODEL_NAME, + delete_contents=True, + ) + # [END howto_operator_gcp_mlengine_delete_model] + + training >> create_version + training >> create_version_2 + create_model >> get_model >> [get_model_result, delete_model] + create_model >> create_version >> create_version_2 >> set_defaults_version >> list_version + create_version >> prediction + create_version_2 >> prediction + prediction >> delete_version + list_version >> list_version_result + list_version >> delete_version + delete_version >> delete_model + + # [START howto_operator_gcp_mlengine_get_metric] + def get_metric_fn_and_keys(): + """ + Gets metric function and keys used to generate summary + """ + + def normalize_value(inst: Dict): + val = float(inst["dense_4"][0]) + return tuple([val]) # returns a tuple. + + return normalize_value, ["val"] # key order must match. 
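+        # normalize_value is invoked on each prediction instance (a dict such
+        # as {"dense_4": [0.7]}) and returns the metric tuple; "val" is the key
+        # under which that metric is expected to appear in the summary that
+        # validate_err_and_count below checks, together with summary["count"].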
+ + # [END howto_operator_gcp_mlengine_get_metric] + + # [START howto_operator_gcp_mlengine_validate_error] + def validate_err_and_count(summary: Dict) -> Dict: + """ + Validate summary result + """ + if summary["val"] > 1: + raise ValueError(f"Too high val>1; summary={summary}") + if summary["val"] < 0: + raise ValueError(f"Too low val<0; summary={summary}") + if summary["count"] != 20: + raise ValueError(f"Invalid value val != 20; summary={summary}") + return summary + + # [END howto_operator_gcp_mlengine_validate_error] + + # [START howto_operator_gcp_mlengine_evaluate] + ( + evaluate_prediction, + evaluate_summary, + evaluate_validation, + ) = mlengine_operator_utils.create_evaluate_ops( + task_prefix="evaluate-ops", + data_format="TEXT", + input_paths=[PREDICTION_INPUT], + prediction_path=PREDICTION_OUTPUT, + metric_fn_and_keys=get_metric_fn_and_keys(), + validate_fn=validate_err_and_count, + batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}", + project_id=PROJECT_ID, + region="us-central1", + dataflow_options={ + "project": PROJECT_ID, + "tempLocation": SUMMARY_TMP, + "stagingLocation": SUMMARY_STAGING, + }, + model_name=MODEL_NAME, + version_name="v1", + py_interpreter="python3", + ) + # [END howto_operator_gcp_mlengine_evaluate] + + create_model >> create_version >> evaluate_prediction + evaluate_validation >> delete_version diff --git a/reference/providers/google/cloud/example_dags/example_mysql_to_gcs.py b/reference/providers/google/cloud/example_dags/example_mysql_to_gcs.py new file mode 100644 index 0000000..e765c6d --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_mysql_to_gcs.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
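+"""
+Example Airflow DAG using MySQLToGCSOperator to export the result of a MySQL
+query to a CSV file in a Google Cloud Storage bucket.
+"""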
+ +import os + +from airflow import models +from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator +from airflow.utils import dates + +GCS_BUCKET = os.environ.get("GCP_GCS_BUCKET", "example-airflow-mysql-gcs") +FILENAME = "test_file" + +SQL_QUERY = "SELECT * from test_table" + +with models.DAG( + "example_mysql_to_gcs", + default_args=dict(start_date=dates.days_ago(1)), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_operator_mysql_to_gcs] + upload = MySQLToGCSOperator( + task_id="mysql_to_gcs", + sql=SQL_QUERY, + bucket=GCS_BUCKET, + filename=FILENAME, + export_format="csv", + ) + # [END howto_operator_mysql_to_gcs] diff --git a/reference/providers/google/cloud/example_dags/example_natural_language.py b/reference/providers/google/cloud/example_dags/example_natural_language.py new file mode 100644 index 0000000..3765bb4 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_natural_language.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google Cloud Natural Language service +""" + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.natural_language import ( + CloudNaturalLanguageAnalyzeEntitiesOperator, + CloudNaturalLanguageAnalyzeEntitySentimentOperator, + CloudNaturalLanguageAnalyzeSentimentOperator, + CloudNaturalLanguageClassifyTextOperator, +) +from airflow.utils.dates import days_ago +from google.cloud.language_v1.proto.language_service_pb2 import Document + +# [START howto_operator_gcp_natural_language_document_text] +TEXT = """Airflow is a platform to programmatically author, schedule and monitor workflows. + +Use Airflow to author workflows as Directed Acyclic Graphs (DAGs) of tasks. The Airflow scheduler executes + your tasks on an array of workers while following the specified dependencies. Rich command line utilities + make performing complex surgeries on DAGs a snap. The rich user interface makes it easy to visualize + pipelines running in production, monitor progress, and troubleshoot issues when needed. 
+""" +document = Document(content=TEXT, type="PLAIN_TEXT") +# [END howto_operator_gcp_natural_language_document_text] + +# [START howto_operator_gcp_natural_language_document_gcs] +GCS_CONTENT_URI = "gs://my-text-bucket/sentiment-me.txt" +document_gcs = Document(gcs_content_uri=GCS_CONTENT_URI, type="PLAIN_TEXT") +# [END howto_operator_gcp_natural_language_document_gcs] + + +with models.DAG( + "example_gcp_natural_language", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), +) as dag: + + # [START howto_operator_gcp_natural_language_analyze_entities] + analyze_entities = CloudNaturalLanguageAnalyzeEntitiesOperator( + document=document, task_id="analyze_entities" + ) + # [END howto_operator_gcp_natural_language_analyze_entities] + + # [START howto_operator_gcp_natural_language_analyze_entities_result] + analyze_entities_result = BashOperator( + bash_command="echo \"{{ task_instance.xcom_pull('analyze_entities') }}\"", + task_id="analyze_entities_result", + ) + # [END howto_operator_gcp_natural_language_analyze_entities_result] + + # [START howto_operator_gcp_natural_language_analyze_entity_sentiment] + analyze_entity_sentiment = CloudNaturalLanguageAnalyzeEntitySentimentOperator( + document=document, task_id="analyze_entity_sentiment" + ) + # [END howto_operator_gcp_natural_language_analyze_entity_sentiment] + + # [START howto_operator_gcp_natural_language_analyze_entity_sentiment_result] + analyze_entity_sentiment_result = BashOperator( + bash_command="echo \"{{ task_instance.xcom_pull('analyze_entity_sentiment') }}\"", + task_id="analyze_entity_sentiment_result", + ) + # [END howto_operator_gcp_natural_language_analyze_entity_sentiment_result] + + # [START howto_operator_gcp_natural_language_analyze_sentiment] + analyze_sentiment = CloudNaturalLanguageAnalyzeSentimentOperator( + document=document, task_id="analyze_sentiment" + ) + # [END howto_operator_gcp_natural_language_analyze_sentiment] + + # [START howto_operator_gcp_natural_language_analyze_sentiment_result] + analyze_sentiment_result = BashOperator( + bash_command="echo \"{{ task_instance.xcom_pull('analyze_sentiment') }}\"", + task_id="analyze_sentiment_result", + ) + # [END howto_operator_gcp_natural_language_analyze_sentiment_result] + + # [START howto_operator_gcp_natural_language_analyze_classify_text] + analyze_classify_text = CloudNaturalLanguageClassifyTextOperator( + document=document, task_id="analyze_classify_text" + ) + # [END howto_operator_gcp_natural_language_analyze_classify_text] + + # [START howto_operator_gcp_natural_language_analyze_classify_text_result] + analyze_classify_text_result = BashOperator( + bash_command="echo \"{{ task_instance.xcom_pull('analyze_classify_text') }}\"", + task_id="analyze_classify_text_result", + ) + # [END howto_operator_gcp_natural_language_analyze_classify_text_result] + + analyze_entities >> analyze_entities_result + analyze_entity_sentiment >> analyze_entity_sentiment_result + analyze_sentiment >> analyze_sentiment_result + analyze_classify_text >> analyze_classify_text_result diff --git a/reference/providers/google/cloud/example_dags/example_oracle_to_gcs.py b/reference/providers/google/cloud/example_dags/example_oracle_to_gcs.py new file mode 100644 index 0000000..e37573e --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_oracle_to_gcs.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow import models +from airflow.providers.google.cloud.transfers.oracle_to_gcs import OracleToGCSOperator +from airflow.utils import dates + +GCS_BUCKET = os.environ.get("GCP_GCS_BUCKET", "example-airflow-oracle-gcs") +FILENAME = "test_file" + +SQL_QUERY = "SELECT * from test_table" + +with models.DAG( + "example_oracle_to_gcs", + default_args=dict(start_date=dates.days_ago(1)), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_operator_oracle_to_gcs] + upload = OracleToGCSOperator( + task_id="oracle_to_gcs", + sql=SQL_QUERY, + bucket=GCS_BUCKET, + filename=FILENAME, + export_format="csv", + ) + # [END howto_operator_oracle_to_gcs] diff --git a/reference/providers/google/cloud/example_dags/example_postgres_to_gcs.py b/reference/providers/google/cloud/example_dags/example_postgres_to_gcs.py new file mode 100644 index 0000000..8be40ce --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_postgres_to_gcs.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example DAG using PostgresToGoogleCloudStorageOperator. 
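+
+Editor's note (summary of the tasks below, not upstream text): the operator is imported
+under its current name, PostgresToGCSOperator. The DAG assumes a reachable Postgres
+connection (the operator defaults to "postgres_default" when no postgres_conn_id is
+supplied) and an existing destination bucket. The second task repeats the first with
+use_server_side_cursor=True, which is intended to stream rows through a server-side
+cursor instead of fetching the whole result set into memory at once.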
+""" +import os + +from airflow import models +from airflow.providers.google.cloud.transfers.postgres_to_gcs import ( + PostgresToGCSOperator, +) +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCS_BUCKET = os.environ.get("GCP_GCS_BUCKET_NAME", "postgres_to_gcs_example") +FILENAME = "test_file" +SQL_QUERY = "select * from test_table;" + +with models.DAG( + dag_id="example_postgres_to_gcs", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + upload_data = PostgresToGCSOperator( + task_id="get_data", + sql=SQL_QUERY, + bucket=GCS_BUCKET, + filename=FILENAME, + gzip=False, + ) + + upload_data_server_side_cursor = PostgresToGCSOperator( + task_id="get_data_with_server_side_cursor", + sql=SQL_QUERY, + bucket=GCS_BUCKET, + filename=FILENAME, + gzip=False, + use_server_side_cursor=True, + ) diff --git a/reference/providers/google/cloud/example_dags/example_presto_to_gcs.py b/reference/providers/google/cloud/example_dags/example_presto_to_gcs.py new file mode 100644 index 0000000..2064fda --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_presto_to_gcs.py @@ -0,0 +1,156 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example DAG using PrestoToGCSOperator. 
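+
+Editor's note (summary of the tasks below, not upstream text): the DAG copies the
+Presto tables memory.default.test_multiple_types and tpch.sf1.customer to GCS as
+newline-delimited JSON (plus a generated BigQuery schema file), exposes the exported
+objects as BigQuery external tables, and verifies them with a COUNT(*) query before
+dropping the dataset. The "{}" in each filename is replaced with a chunk index when an
+export is split across multiple files. A working "presto_default" connection is assumed.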
+""" +import os +import re + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateExternalTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryExecuteQueryOperator, +) +from airflow.providers.google.cloud.transfers.presto_to_gcs import PrestoToGCSOperator +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCS_BUCKET = os.environ.get( + "GCP_PRESTO_TO_GCS_BUCKET_NAME", "test-presto-to-gcs-bucket" +) +DATASET_NAME = os.environ.get( + "GCP_PRESTO_TO_GCS_DATASET_NAME", "test_presto_to_gcs_dataset" +) + +SOURCE_MULTIPLE_TYPES = "memory.default.test_multiple_types" +SOURCE_CUSTOMER_TABLE = "tpch.sf1.customer" + + +def safe_name(s: str) -> str: + """ + Remove invalid characters for filename + """ + return re.sub("[^0-9a-zA-Z_]+", "_", s) + + +with models.DAG( + dag_id="example_presto_to_gcs", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create-dataset", dataset_id=DATASET_NAME + ) + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True + ) + + # [START howto_operator_presto_to_gcs_basic] + presto_to_gcs_basic = PrestoToGCSOperator( + task_id="presto_to_gcs_basic", + sql=f"select * from {SOURCE_MULTIPLE_TYPES}", + bucket=GCS_BUCKET, + filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.json", + ) + # [END howto_operator_presto_to_gcs_basic] + + # [START howto_operator_presto_to_gcs_multiple_types] + presto_to_gcs_multiple_types = PrestoToGCSOperator( + task_id="presto_to_gcs_multiple_types", + sql=f"select * from {SOURCE_MULTIPLE_TYPES}", + bucket=GCS_BUCKET, + filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.json", + schema_filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", + gzip=False, + ) + # [END howto_operator_presto_to_gcs_multiple_types] + + # [START howto_operator_create_external_table_multiple_types] + create_external_table_multiple_types = BigQueryCreateExternalTableOperator( + task_id="create_external_table_multiple_types", + bucket=GCS_BUCKET, + source_objects=[f"{safe_name(SOURCE_MULTIPLE_TYPES)}.*.json"], + source_format="NEWLINE_DELIMITED_JSON", + destination_project_dataset_table=f"{DATASET_NAME}.{safe_name(SOURCE_MULTIPLE_TYPES)}", + schema_object=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", + ) + # [END howto_operator_create_external_table_multiple_types] + + read_data_from_gcs_multiple_types = BigQueryExecuteQueryOperator( + task_id="read_data_from_gcs_multiple_types", + sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{safe_name(SOURCE_MULTIPLE_TYPES)}`", + use_legacy_sql=False, + ) + + # [START howto_operator_presto_to_gcs_many_chunks] + presto_to_gcs_many_chunks = PrestoToGCSOperator( + task_id="presto_to_gcs_many_chunks", + sql=f"select * from {SOURCE_CUSTOMER_TABLE}", + bucket=GCS_BUCKET, + filename=f"{safe_name(SOURCE_CUSTOMER_TABLE)}.{{}}.json", + schema_filename=f"{safe_name(SOURCE_CUSTOMER_TABLE)}-schema.json", + approx_max_file_size_bytes=10_000_000, + gzip=False, + ) + # [END howto_operator_presto_to_gcs_many_chunks] + + create_external_table_many_chunks = BigQueryCreateExternalTableOperator( + task_id="create_external_table_many_chunks", + bucket=GCS_BUCKET, + source_objects=[f"{safe_name(SOURCE_CUSTOMER_TABLE)}.*.json"], + source_format="NEWLINE_DELIMITED_JSON", + 
destination_project_dataset_table=f"{DATASET_NAME}.{safe_name(SOURCE_CUSTOMER_TABLE)}", + schema_object=f"{safe_name(SOURCE_CUSTOMER_TABLE)}-schema.json", + ) + + # [START howto_operator_read_data_from_gcs_many_chunks] + read_data_from_gcs_many_chunks = BigQueryExecuteQueryOperator( + task_id="read_data_from_gcs_many_chunks", + sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{safe_name(SOURCE_CUSTOMER_TABLE)}`", + use_legacy_sql=False, + ) + # [END howto_operator_read_data_from_gcs_many_chunks] + + # [START howto_operator_presto_to_gcs_csv] + presto_to_gcs_csv = PrestoToGCSOperator( + task_id="presto_to_gcs_csv", + sql=f"select * from {SOURCE_MULTIPLE_TYPES}", + bucket=GCS_BUCKET, + filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.csv", + schema_filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", + export_format="csv", + ) + # [END howto_operator_presto_to_gcs_csv] + + create_dataset >> presto_to_gcs_basic + create_dataset >> presto_to_gcs_multiple_types + create_dataset >> presto_to_gcs_many_chunks + create_dataset >> presto_to_gcs_csv + + presto_to_gcs_multiple_types >> create_external_table_multiple_types >> read_data_from_gcs_multiple_types + presto_to_gcs_many_chunks >> create_external_table_many_chunks >> read_data_from_gcs_many_chunks + + presto_to_gcs_basic >> delete_dataset + presto_to_gcs_csv >> delete_dataset + read_data_from_gcs_multiple_types >> delete_dataset + read_data_from_gcs_many_chunks >> delete_dataset diff --git a/reference/providers/google/cloud/example_dags/example_pubsub.py b/reference/providers/google/cloud/example_dags/example_pubsub.py new file mode 100644 index 0000000..027ce72 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_pubsub.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that uses Google PubSub services. 
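+
+Editor's note (illustrative, not upstream text): the file defines two DAGs, a
+sensor-based one (PubSubPullSensor) and an operator-based one (PubSubPullOperator).
+Pulled messages are passed through XCom and echoed by the templated bash command
+below; each pulled entry is expected to resemble
+
+    {"ackId": "...", "message": {"data": "<base64>", "attributes": {...}}}
+
+which is why the template reads m.get('ackId') and m.get('message').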
+""" +import os + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.pubsub import ( + PubSubCreateSubscriptionOperator, + PubSubCreateTopicOperator, + PubSubDeleteSubscriptionOperator, + PubSubDeleteTopicOperator, + PubSubPublishMessageOperator, + PubSubPullOperator, +) +from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") +TOPIC_FOR_SENSOR_DAG = os.environ.get( + "GCP_PUBSUB_SENSOR_TOPIC", "PubSubSensorTestTopic" +) +TOPIC_FOR_OPERATOR_DAG = os.environ.get( + "GCP_PUBSUB_OPERATOR_TOPIC", "PubSubOperatorTestTopic" +) +MESSAGE = { + "data": b"Tool", + "attributes": {"name": "wrench", "mass": "1.3kg", "count": "3"}, +} + +# [START howto_operator_gcp_pubsub_pull_messages_result_cmd] +echo_cmd = """ +{% for m in task_instance.xcom_pull('pull_messages') %} + echo "AckID: {{ m.get('ackId') }}, Base64-Encoded: {{ m.get('message') }}" +{% endfor %} +""" +# [END howto_operator_gcp_pubsub_pull_messages_result_cmd] + +with models.DAG( + "example_gcp_pubsub_sensor", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), +) as example_sensor_dag: + # [START howto_operator_gcp_pubsub_create_topic] + create_topic = PubSubCreateTopicOperator( + task_id="create_topic", + topic=TOPIC_FOR_SENSOR_DAG, + project_id=GCP_PROJECT_ID, + fail_if_exists=False, + ) + # [END howto_operator_gcp_pubsub_create_topic] + + # [START howto_operator_gcp_pubsub_create_subscription] + subscribe_task = PubSubCreateSubscriptionOperator( + task_id="subscribe_task", project_id=GCP_PROJECT_ID, topic=TOPIC_FOR_SENSOR_DAG + ) + # [END howto_operator_gcp_pubsub_create_subscription] + + # [START howto_operator_gcp_pubsub_pull_message_with_sensor] + subscription = "{{ task_instance.xcom_pull('subscribe_task') }}" + + pull_messages = PubSubPullSensor( + task_id="pull_messages", + ack_messages=True, + project_id=GCP_PROJECT_ID, + subscription=subscription, + ) + # [END howto_operator_gcp_pubsub_pull_message_with_sensor] + + # [START howto_operator_gcp_pubsub_pull_messages_result] + pull_messages_result = BashOperator( + task_id="pull_messages_result", bash_command=echo_cmd + ) + # [END howto_operator_gcp_pubsub_pull_messages_result] + + # [START howto_operator_gcp_pubsub_publish] + publish_task = PubSubPublishMessageOperator( + task_id="publish_task", + project_id=GCP_PROJECT_ID, + topic=TOPIC_FOR_SENSOR_DAG, + messages=[MESSAGE] * 10, + ) + # [END howto_operator_gcp_pubsub_publish] + + # [START howto_operator_gcp_pubsub_unsubscribe] + unsubscribe_task = PubSubDeleteSubscriptionOperator( + task_id="unsubscribe_task", + project_id=GCP_PROJECT_ID, + subscription="{{ task_instance.xcom_pull('subscribe_task') }}", + ) + # [END howto_operator_gcp_pubsub_unsubscribe] + + # [START howto_operator_gcp_pubsub_delete_topic] + delete_topic = PubSubDeleteTopicOperator( + task_id="delete_topic", topic=TOPIC_FOR_SENSOR_DAG, project_id=GCP_PROJECT_ID + ) + # [END howto_operator_gcp_pubsub_delete_topic] + + create_topic >> subscribe_task >> [publish_task, pull_messages] + pull_messages >> pull_messages_result >> unsubscribe_task >> delete_topic + + +with models.DAG( + "example_gcp_pubsub_operator", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), +) as example_operator_dag: + # [START howto_operator_gcp_pubsub_create_topic] + create_topic = PubSubCreateTopicOperator( + 
task_id="create_topic", topic=TOPIC_FOR_OPERATOR_DAG, project_id=GCP_PROJECT_ID + ) + # [END howto_operator_gcp_pubsub_create_topic] + + # [START howto_operator_gcp_pubsub_create_subscription] + subscribe_task = PubSubCreateSubscriptionOperator( + task_id="subscribe_task", + project_id=GCP_PROJECT_ID, + topic=TOPIC_FOR_OPERATOR_DAG, + ) + # [END howto_operator_gcp_pubsub_create_subscription] + + # [START howto_operator_gcp_pubsub_pull_message_with_operator] + subscription = "{{ task_instance.xcom_pull('subscribe_task') }}" + + pull_messages_operator = PubSubPullOperator( + task_id="pull_messages", + ack_messages=True, + project_id=GCP_PROJECT_ID, + subscription=subscription, + ) + # [END howto_operator_gcp_pubsub_pull_message_with_operator] + + # [START howto_operator_gcp_pubsub_pull_messages_result] + pull_messages_result = BashOperator( + task_id="pull_messages_result", bash_command=echo_cmd + ) + # [END howto_operator_gcp_pubsub_pull_messages_result] + + # [START howto_operator_gcp_pubsub_publish] + publish_task = PubSubPublishMessageOperator( + task_id="publish_task", + project_id=GCP_PROJECT_ID, + topic=TOPIC_FOR_OPERATOR_DAG, + messages=[MESSAGE, MESSAGE, MESSAGE], + ) + # [END howto_operator_gcp_pubsub_publish] + + # [START howto_operator_gcp_pubsub_unsubscribe] + unsubscribe_task = PubSubDeleteSubscriptionOperator( + task_id="unsubscribe_task", + project_id=GCP_PROJECT_ID, + subscription="{{ task_instance.xcom_pull('subscribe_task') }}", + ) + # [END howto_operator_gcp_pubsub_unsubscribe] + + # [START howto_operator_gcp_pubsub_delete_topic] + delete_topic = PubSubDeleteTopicOperator( + task_id="delete_topic", topic=TOPIC_FOR_OPERATOR_DAG, project_id=GCP_PROJECT_ID + ) + # [END howto_operator_gcp_pubsub_delete_topic] + + ( + create_topic + >> subscribe_task + >> publish_task + >> pull_messages_operator + >> pull_messages_result + >> unsubscribe_task + >> delete_topic + ) diff --git a/reference/providers/google/cloud/example_dags/example_s3_to_gcs.py b/reference/providers/google/cloud/example_dags/example_s3_to_gcs.py new file mode 100644 index 0000000..9a12029 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_s3_to_gcs.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os + +from airflow import models +from airflow.operators.python import PythonOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.operators.s3_bucket import ( + S3CreateBucketOperator, + S3DeleteBucketOperator, +) +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, +) +from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "gcp-project-id") +S3BUCKET_NAME = os.environ.get("S3BUCKET_NAME", "example-s3bucket-name") +GCS_BUCKET = os.environ.get("GCP_GCS_BUCKET", "example-gcsbucket-name") +UPLOAD_FILE = "/tmp/example-file.txt" +PREFIX = "TESTS" + + +def upload_file(): + """A callable to upload file to AWS bucket""" + s3_hook = S3Hook() + s3_hook.load_file(filename=UPLOAD_FILE, key=PREFIX, bucket_name=S3BUCKET_NAME) + + +with models.DAG( + "example_s3_to_gcs", + schedule_interval=None, + start_date=days_ago(2), + tags=["example"], +) as dag: + create_s3_bucket = S3CreateBucketOperator( + task_id="create_s3_bucket", bucket_name=S3BUCKET_NAME, region_name="us-east-1" + ) + + upload_to_s3 = PythonOperator( + task_id="upload_file_to_s3", python_callable=upload_file + ) + + create_gcs_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=GCS_BUCKET, + project_id=GCP_PROJECT_ID, + ) + # [START howto_transfer_s3togcs_operator] + transfer_to_gcs = S3ToGCSOperator( + task_id="s3_to_gcs_task", + bucket=S3BUCKET_NAME, + prefix=PREFIX, + dest_gcs="gs://" + GCS_BUCKET, + ) + # [END howto_transfer_s3togcs_operator] + + delete_s3_bucket = S3DeleteBucketOperator( + task_id="delete_s3_bucket", bucket_name=S3BUCKET_NAME, force_delete=True + ) + + delete_gcs_bucket = GCSDeleteBucketOperator( + task_id="delete_gcs_bucket", bucket_name=GCS_BUCKET + ) + + ( + create_s3_bucket + >> upload_to_s3 + >> create_gcs_bucket + >> transfer_to_gcs + >> delete_s3_bucket + >> delete_gcs_bucket + ) diff --git a/reference/providers/google/cloud/example_dags/example_salesforce_to_gcs.py b/reference/providers/google/cloud/example_dags/example_salesforce_to_gcs.py new file mode 100644 index 0000000..b0d27d3 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_salesforce_to_gcs.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that shows how to use SalesforceToGcsOperator. 
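+
+Editor's note (summary of the pipeline below, not upstream text): a SOQL query on the
+Lead object is exported to GCS as CSV, loaded into a BigQuery table, and checked with
+a COUNT(*) query before the bucket and dataset are deleted. The DAG assumes working
+"salesforce_default" and "google_cloud_default" connections, both overridable through
+the SALESFORCE_CONN_ID and GCS_CONN_ID environment variables read below.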
+""" +import os + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryExecuteQueryOperator, +) +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, +) +from airflow.providers.google.cloud.transfers.gcs_to_bigquery import ( + GCSToBigQueryOperator, +) +from airflow.providers.google.cloud.transfers.salesforce_to_gcs import ( + SalesforceToGcsOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCS_BUCKET = os.environ.get("GCS_BUCKET", "airflow-salesforce-bucket") +DATASET_NAME = os.environ.get("SALESFORCE_DATASET_NAME", "salesforce_test_dataset") +TABLE_NAME = os.environ.get("SALESFORCE_TABLE_NAME", "salesforce_test_datatable") +GCS_OBJ_PATH = os.environ.get("GCS_OBJ_PATH", "results.csv") +QUERY = "SELECT Id, Name, Company, Phone, Email, CreatedDate, LastModifiedDate, IsDeleted FROM Lead" +GCS_CONN_ID = os.environ.get("GCS_CONN_ID", "google_cloud_default") +SALESFORCE_CONN_ID = os.environ.get("SALESFORCE_CONN_ID", "salesforce_default") + + +with models.DAG( + "example_salesforce_to_gcs", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), +) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=GCS_BUCKET, + project_id=GCP_PROJECT_ID, + gcp_conn_id=GCS_CONN_ID, + ) + + # [START howto_operator_salesforce_to_gcs] + gcs_upload_task = SalesforceToGcsOperator( + query=QUERY, + include_deleted=True, + bucket_name=GCS_BUCKET, + object_name=GCS_OBJ_PATH, + salesforce_conn_id=SALESFORCE_CONN_ID, + export_format="csv", + coerce_to_timestamp=False, + record_time_added=False, + gcp_conn_id=GCS_CONN_ID, + task_id="upload_to_gcs", + dag=dag, + ) + # [END howto_operator_salesforce_to_gcs] + + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create_dataset", + dataset_id=DATASET_NAME, + project_id=GCP_PROJECT_ID, + gcp_conn_id=GCS_CONN_ID, + ) + + create_table = BigQueryCreateEmptyTableOperator( + task_id="create_table", + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + schema_fields=[ + {"name": "id", "type": "STRING", "mode": "NULLABLE"}, + {"name": "name", "type": "STRING", "mode": "NULLABLE"}, + {"name": "company", "type": "STRING", "mode": "NULLABLE"}, + {"name": "phone", "type": "STRING", "mode": "NULLABLE"}, + {"name": "email", "type": "STRING", "mode": "NULLABLE"}, + {"name": "createddate", "type": "STRING", "mode": "NULLABLE"}, + {"name": "lastmodifieddate", "type": "STRING", "mode": "NULLABLE"}, + {"name": "isdeleted", "type": "BOOL", "mode": "NULLABLE"}, + ], + ) + + load_csv = GCSToBigQueryOperator( + task_id="gcs_to_bq", + bucket=GCS_BUCKET, + source_objects=[GCS_OBJ_PATH], + destination_project_dataset_table=f"{DATASET_NAME}.{TABLE_NAME}", + write_disposition="WRITE_TRUNCATE", + ) + + read_data_from_gcs = BigQueryExecuteQueryOperator( + task_id="read_data_from_gcs", + sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{TABLE_NAME}`", + use_legacy_sql=False, + ) + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=GCS_BUCKET, + ) + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", + project_id=GCP_PROJECT_ID, + dataset_id=DATASET_NAME, + delete_contents=True, + ) + + create_bucket >> gcs_upload_task >> load_csv + create_dataset >> create_table >> 
load_csv
+    load_csv >> read_data_from_gcs
+    read_data_from_gcs >> delete_bucket
+    read_data_from_gcs >> delete_dataset
diff --git a/reference/providers/google/cloud/example_dags/example_sftp_to_gcs.py b/reference/providers/google/cloud/example_dags/example_sftp_to_gcs.py
new file mode 100644
index 0000000..218fa31
--- /dev/null
+++ b/reference/providers/google/cloud/example_dags/example_sftp_to_gcs.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for SFTP to Google Cloud Storage transfer operators.
+"""
+
+import os
+
+from airflow import models
+from airflow.providers.google.cloud.transfers.sftp_to_gcs import SFTPToGCSOperator
+from airflow.utils.dates import days_ago
+
+BUCKET_SRC = os.environ.get("GCP_GCS_BUCKET_1_SRC", "test-sftp-gcs")
+
+TMP_PATH = "/tmp"
+DIR = "tests_sftp_hook_dir"
+SUBDIR = "subdir"
+
+OBJECT_SRC_1 = "parent-1.bin"
+OBJECT_SRC_2 = "parent-2.bin"
+OBJECT_SRC_3 = "parent-3.txt"
+
+
+with models.DAG(
+    "example_sftp_to_gcs", start_date=days_ago(1), schedule_interval=None
+) as dag:
+    # [START howto_operator_sftp_to_gcs_copy_single_file]
+    copy_file_from_sftp_to_gcs = SFTPToGCSOperator(
+        task_id="file-copy-sftp-to-gcs",
+        source_path=os.path.join(TMP_PATH, DIR, OBJECT_SRC_1),
+        destination_bucket=BUCKET_SRC,
+    )
+    # [END howto_operator_sftp_to_gcs_copy_single_file]
+
+    # [START howto_operator_sftp_to_gcs_move_single_file_destination]
+    move_file_from_sftp_to_gcs_destination = SFTPToGCSOperator(
+        task_id="file-move-sftp-to-gcs-destination",
+        source_path=os.path.join(TMP_PATH, DIR, OBJECT_SRC_2),
+        destination_bucket=BUCKET_SRC,
+        destination_path="destination_dir/destination_filename.bin",
+        move_object=True,
+    )
+    # [END howto_operator_sftp_to_gcs_move_single_file_destination]
+
+    # [START howto_operator_sftp_to_gcs_copy_directory]
+    copy_directory_from_sftp_to_gcs = SFTPToGCSOperator(
+        task_id="dir-copy-sftp-to-gcs",
+        source_path=os.path.join(TMP_PATH, DIR, SUBDIR, "*"),
+        destination_bucket=BUCKET_SRC,
+    )
+    # [END howto_operator_sftp_to_gcs_copy_directory]
+
+    # [START howto_operator_sftp_to_gcs_move_specific_files]
+    move_specific_files_from_sftp_to_gcs = SFTPToGCSOperator(
+        task_id="dir-move-specific-files-sftp-to-gcs",
+        source_path=os.path.join(TMP_PATH, DIR, SUBDIR, "*.bin"),
+        destination_bucket=BUCKET_SRC,
+        destination_path="specific_files/",
+        move_object=True,
+    )
+    # [END howto_operator_sftp_to_gcs_move_specific_files]
diff --git a/reference/providers/google/cloud/example_dags/example_sheets_to_gcs.py b/reference/providers/google/cloud/example_dags/example_sheets_to_gcs.py
new file mode 100644
index 0000000..c5984fb
--- /dev/null
+++ b/reference/providers/google/cloud/example_dags/example_sheets_to_gcs.py
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation
(ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow import models +from airflow.providers.google.cloud.transfers.sheets_to_gcs import ( + GoogleSheetsToGCSOperator, +) +from airflow.utils.dates import days_ago + +BUCKET = os.environ.get("GCP_GCS_BUCKET", "test28397yeo") +SPREADSHEET_ID = os.environ.get("SPREADSHEET_ID", "1234567890qwerty") + +with models.DAG( + "example_sheets_to_gcs", + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag: + # [START upload_sheet_to_gcs] + upload_sheet_to_gcs = GoogleSheetsToGCSOperator( + task_id="upload_sheet_to_gcs", + destination_bucket=BUCKET, + spreadsheet_id=SPREADSHEET_ID, + ) + # [END upload_sheet_to_gcs] diff --git a/reference/providers/google/cloud/example_dags/example_spanner.py b/reference/providers/google/cloud/example_dags/example_spanner.py new file mode 100644 index 0000000..8ddcf47 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_spanner.py @@ -0,0 +1,196 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that creates, updates, queries and deletes a Cloud Spanner instance. + +This DAG relies on the following environment variables +* GCP_PROJECT_ID - Google Cloud project for the Cloud Spanner instance. +* GCP_SPANNER_INSTANCE_ID - Cloud Spanner instance ID. +* GCP_SPANNER_DATABASE_ID - Cloud Spanner database ID. +* GCP_SPANNER_CONFIG_NAME - The name of the instance's configuration. Values are of the + form ``projects//instanceConfigs/``. See also: + https://cloud.google.com/spanner/docs/reference/rest/v1/projects.instanceConfigs#InstanceConfig + https://cloud.google.com/spanner/docs/reference/rest/v1/projects.instanceConfigs/list#google.spanner.admin.instance.v1.InstanceAdmin.ListInstanceConfigs +* GCP_SPANNER_NODE_COUNT - Number of nodes allocated to the instance. +* GCP_SPANNER_DISPLAY_NAME - The descriptive name for this instance as it appears in UIs. + Must be unique per project and between 4 and 30 characters in length. 
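+
+For a local run these might be exported as, for example (illustrative values only,
+matching the fallback defaults read below):
+
+    GCP_PROJECT_ID=example-project
+    GCP_SPANNER_INSTANCE_ID=testinstance
+    GCP_SPANNER_DATABASE_ID=testdatabase
+    GCP_SPANNER_CONFIG_NAME=projects/example-project/instanceConfigs/regional-europe-west3
+    GCP_SPANNER_NODE_COUNT=1
+    GCP_SPANNER_DISPLAY_NAME="Test Instance"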
+""" + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.spanner import ( + SpannerDeleteDatabaseInstanceOperator, + SpannerDeleteInstanceOperator, + SpannerDeployDatabaseInstanceOperator, + SpannerDeployInstanceOperator, + SpannerQueryDatabaseInstanceOperator, + SpannerUpdateDatabaseInstanceOperator, +) +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCP_SPANNER_INSTANCE_ID = os.environ.get("GCP_SPANNER_INSTANCE_ID", "testinstance") +GCP_SPANNER_DATABASE_ID = os.environ.get("GCP_SPANNER_DATABASE_ID", "testdatabase") +GCP_SPANNER_CONFIG_NAME = os.environ.get( + "GCP_SPANNER_CONFIG_NAME", + f"projects/{GCP_PROJECT_ID}/instanceConfigs/regional-europe-west3", +) +GCP_SPANNER_NODE_COUNT = os.environ.get("GCP_SPANNER_NODE_COUNT", "1") +GCP_SPANNER_DISPLAY_NAME = os.environ.get("GCP_SPANNER_DISPLAY_NAME", "Test Instance") +# OPERATION_ID should be unique per operation +OPERATION_ID = "unique_operation_id" + +with models.DAG( + "example_gcp_spanner", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # Create + # [START howto_operator_spanner_deploy] + spanner_instance_create_task = SpannerDeployInstanceOperator( + project_id=GCP_PROJECT_ID, + instance_id=GCP_SPANNER_INSTANCE_ID, + configuration_name=GCP_SPANNER_CONFIG_NAME, + node_count=int(GCP_SPANNER_NODE_COUNT), + display_name=GCP_SPANNER_DISPLAY_NAME, + task_id="spanner_instance_create_task", + ) + spanner_instance_update_task = SpannerDeployInstanceOperator( + instance_id=GCP_SPANNER_INSTANCE_ID, + configuration_name=GCP_SPANNER_CONFIG_NAME, + node_count=int(GCP_SPANNER_NODE_COUNT) + 1, + display_name=GCP_SPANNER_DISPLAY_NAME + "_updated", + task_id="spanner_instance_update_task", + ) + # [END howto_operator_spanner_deploy] + + # [START howto_operator_spanner_database_deploy] + spanner_database_deploy_task = SpannerDeployDatabaseInstanceOperator( + project_id=GCP_PROJECT_ID, + instance_id=GCP_SPANNER_INSTANCE_ID, + database_id=GCP_SPANNER_DATABASE_ID, + ddl_statements=[ + "CREATE TABLE my_table1 (id INT64, name STRING(MAX)) PRIMARY KEY (id)", + "CREATE TABLE my_table2 (id INT64, name STRING(MAX)) PRIMARY KEY (id)", + ], + task_id="spanner_database_deploy_task", + ) + spanner_database_deploy_task2 = SpannerDeployDatabaseInstanceOperator( + instance_id=GCP_SPANNER_INSTANCE_ID, + database_id=GCP_SPANNER_DATABASE_ID, + ddl_statements=[ + "CREATE TABLE my_table1 (id INT64, name STRING(MAX)) PRIMARY KEY (id)", + "CREATE TABLE my_table2 (id INT64, name STRING(MAX)) PRIMARY KEY (id)", + ], + task_id="spanner_database_deploy_task2", + ) + # [END howto_operator_spanner_database_deploy] + + # [START howto_operator_spanner_database_update] + spanner_database_update_task = SpannerUpdateDatabaseInstanceOperator( + project_id=GCP_PROJECT_ID, + instance_id=GCP_SPANNER_INSTANCE_ID, + database_id=GCP_SPANNER_DATABASE_ID, + ddl_statements=[ + "CREATE TABLE my_table3 (id INT64, name STRING(MAX)) PRIMARY KEY (id)", + ], + task_id="spanner_database_update_task", + ) + # [END howto_operator_spanner_database_update] + + # [START howto_operator_spanner_database_update_idempotent] + spanner_database_update_idempotent1_task = SpannerUpdateDatabaseInstanceOperator( + project_id=GCP_PROJECT_ID, + instance_id=GCP_SPANNER_INSTANCE_ID, + database_id=GCP_SPANNER_DATABASE_ID, + operation_id=OPERATION_ID, + ddl_statements=[ + "CREATE TABLE my_table_unique (id INT64, name STRING(MAX)) PRIMARY KEY (id)", + ], 
+ task_id="spanner_database_update_idempotent1_task", + ) + spanner_database_update_idempotent2_task = SpannerUpdateDatabaseInstanceOperator( + instance_id=GCP_SPANNER_INSTANCE_ID, + database_id=GCP_SPANNER_DATABASE_ID, + operation_id=OPERATION_ID, + ddl_statements=[ + "CREATE TABLE my_table_unique (id INT64, name STRING(MAX)) PRIMARY KEY (id)", + ], + task_id="spanner_database_update_idempotent2_task", + ) + # [END howto_operator_spanner_database_update_idempotent] + + # [START howto_operator_spanner_query] + spanner_instance_query_task = SpannerQueryDatabaseInstanceOperator( + project_id=GCP_PROJECT_ID, + instance_id=GCP_SPANNER_INSTANCE_ID, + database_id=GCP_SPANNER_DATABASE_ID, + query=["DELETE FROM my_table2 WHERE true"], + task_id="spanner_instance_query_task", + ) + spanner_instance_query_task2 = SpannerQueryDatabaseInstanceOperator( + instance_id=GCP_SPANNER_INSTANCE_ID, + database_id=GCP_SPANNER_DATABASE_ID, + query=["DELETE FROM my_table2 WHERE true"], + task_id="spanner_instance_query_task2", + ) + # [END howto_operator_spanner_query] + + # [START howto_operator_spanner_database_delete] + spanner_database_delete_task = SpannerDeleteDatabaseInstanceOperator( + project_id=GCP_PROJECT_ID, + instance_id=GCP_SPANNER_INSTANCE_ID, + database_id=GCP_SPANNER_DATABASE_ID, + task_id="spanner_database_delete_task", + ) + spanner_database_delete_task2 = SpannerDeleteDatabaseInstanceOperator( + instance_id=GCP_SPANNER_INSTANCE_ID, + database_id=GCP_SPANNER_DATABASE_ID, + task_id="spanner_database_delete_task2", + ) + # [END howto_operator_spanner_database_delete] + + # [START howto_operator_spanner_delete] + spanner_instance_delete_task = SpannerDeleteInstanceOperator( + project_id=GCP_PROJECT_ID, + instance_id=GCP_SPANNER_INSTANCE_ID, + task_id="spanner_instance_delete_task", + ) + spanner_instance_delete_task2 = SpannerDeleteInstanceOperator( + instance_id=GCP_SPANNER_INSTANCE_ID, task_id="spanner_instance_delete_task2" + ) + # [END howto_operator_spanner_delete] + + ( + spanner_instance_create_task + >> spanner_instance_update_task + >> spanner_database_deploy_task + >> spanner_database_deploy_task2 + >> spanner_database_update_task + >> spanner_database_update_idempotent1_task + >> spanner_database_update_idempotent2_task + >> spanner_instance_query_task + >> spanner_instance_query_task2 + >> spanner_database_delete_task + >> spanner_database_delete_task2 + >> spanner_instance_delete_task + >> spanner_instance_delete_task2 + ) diff --git a/reference/providers/google/cloud/example_dags/example_spanner.sql b/reference/providers/google/cloud/example_dags/example_spanner.sql new file mode 100644 index 0000000..54854a1 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_spanner.sql @@ -0,0 +1,23 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+*/ + + +INSERT my_table2 (id, name) VALUES (7, 'Seven'); +INSERT my_table2 (id, name) + VALUES (8, 'Eight'); diff --git a/reference/providers/google/cloud/example_dags/example_speech_to_text.py b/reference/providers/google/cloud/example_dags/example_speech_to_text.py new file mode 100644 index 0000000..ef5cb54 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_speech_to_text.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.speech_to_text import ( + CloudSpeechToTextRecognizeSpeechOperator, +) +from airflow.providers.google.cloud.operators.text_to_speech import ( + CloudTextToSpeechSynthesizeOperator, +) +from airflow.utils import dates + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +BUCKET_NAME = os.environ.get( + "GCP_SPEECH_TO_TEXT_TEST_BUCKET", "gcp-speech-to-text-test-bucket" +) + +# [START howto_operator_speech_to_text_gcp_filename] +FILENAME = "gcp-speech-test-file" +# [END howto_operator_speech_to_text_gcp_filename] + +# [START howto_operator_text_to_speech_api_arguments] +INPUT = {"text": "Sample text for demo purposes"} +VOICE = {"language_code": "en-US", "ssml_gender": "FEMALE"} +AUDIO_CONFIG = {"audio_encoding": "LINEAR16"} +# [END howto_operator_text_to_speech_api_arguments] + +# [START howto_operator_speech_to_text_api_arguments] +CONFIG = {"encoding": "LINEAR16", "language_code": "en_US"} +AUDIO = {"uri": f"gs://{BUCKET_NAME}/{FILENAME}"} +# [END howto_operator_speech_to_text_api_arguments] + +with models.DAG( + "example_gcp_speech_to_text", + start_date=dates.days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag: + text_to_speech_synthesize_task = CloudTextToSpeechSynthesizeOperator( + project_id=GCP_PROJECT_ID, + input_data=INPUT, + voice=VOICE, + audio_config=AUDIO_CONFIG, + target_bucket_name=BUCKET_NAME, + target_filename=FILENAME, + task_id="text_to_speech_synthesize_task", + ) + # [START howto_operator_speech_to_text_recognize] + speech_to_text_recognize_task2 = CloudSpeechToTextRecognizeSpeechOperator( + config=CONFIG, audio=AUDIO, task_id="speech_to_text_recognize_task" + ) + # [END howto_operator_speech_to_text_recognize] + + text_to_speech_synthesize_task >> speech_to_text_recognize_task2 diff --git a/reference/providers/google/cloud/example_dags/example_stackdriver.py b/reference/providers/google/cloud/example_dags/example_stackdriver.py new file mode 100644 index 0000000..7242b86 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_stackdriver.py @@ -0,0 +1,213 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google Cloud Stackdriver service. +""" + +import json +import os + +from airflow import models +from airflow.providers.google.cloud.operators.stackdriver import ( + StackdriverDeleteAlertOperator, + StackdriverDeleteNotificationChannelOperator, + StackdriverDisableAlertPoliciesOperator, + StackdriverDisableNotificationChannelsOperator, + StackdriverEnableAlertPoliciesOperator, + StackdriverEnableNotificationChannelsOperator, + StackdriverListAlertPoliciesOperator, + StackdriverListNotificationChannelsOperator, + StackdriverUpsertAlertOperator, + StackdriverUpsertNotificationChannelOperator, +) +from airflow.utils.dates import days_ago + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") + +TEST_ALERT_POLICY_1 = { + "combiner": "OR", + "enabled": True, + "display_name": "test alert 1", + "conditions": [ + { + "condition_threshold": { + "filter": ( + 'metric.label.state="blocked" AND ' + 'metric.type="agent.googleapis.com/processes/count_by_state" ' + 'AND resource.type="gce_instance"' + ), + "comparison": "COMPARISON_GT", + "threshold_value": 100, + "duration": {"seconds": 900}, + "trigger": {"percent": 0}, + "aggregations": [ + { + "alignment_period": {"seconds": 60}, + "per_series_aligner": "ALIGN_MEAN", + "cross_series_reducer": "REDUCE_MEAN", + "group_by_fields": [ + "project", + "resource.label.instance_id", + "resource.label.zone", + ], + } + ], + }, + "display_name": "test_alert_policy_1", + } + ], +} + +TEST_ALERT_POLICY_2 = { + "combiner": "OR", + "enabled": False, + "display_name": "test alert 2", + "conditions": [ + { + "condition_threshold": { + "filter": ( + 'metric.label.state="blocked" AND ' + 'metric.type="agent.googleapis.com/processes/count_by_state" AND ' + 'resource.type="gce_instance"' + ), + "comparison": "COMPARISON_GT", + "threshold_value": 100, + "duration": {"seconds": 900}, + "trigger": {"percent": 0}, + "aggregations": [ + { + "alignment_period": {"seconds": 60}, + "per_series_aligner": "ALIGN_MEAN", + "cross_series_reducer": "REDUCE_MEAN", + "group_by_fields": [ + "project", + "resource.label.instance_id", + "resource.label.zone", + ], + } + ], + }, + "display_name": "test_alert_policy_2", + } + ], +} + +TEST_NOTIFICATION_CHANNEL_1 = { + "display_name": "channel1", + "enabled": True, + "labels": {"auth_token": "top-secret", "channel_name": "#channel"}, + "type_": "slack", +} + +TEST_NOTIFICATION_CHANNEL_2 = { + "display_name": "channel2", + "enabled": False, + "labels": {"auth_token": "top-secret", "channel_name": "#channel"}, + "type_": "slack", +} + +with models.DAG( + "example_stackdriver", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_gcp_stackdriver_upsert_notification_channel] + create_notification_channel = 
StackdriverUpsertNotificationChannelOperator( + task_id="create-notification-channel", + channels=json.dumps( + {"channels": [TEST_NOTIFICATION_CHANNEL_1, TEST_NOTIFICATION_CHANNEL_2]} + ), + ) + # [END howto_operator_gcp_stackdriver_upsert_notification_channel] + + # [START howto_operator_gcp_stackdriver_enable_notification_channel] + enable_notification_channel = StackdriverEnableNotificationChannelsOperator( + task_id="enable-notification-channel", filter_='type="slack"' + ) + # [END howto_operator_gcp_stackdriver_enable_notification_channel] + + # [START howto_operator_gcp_stackdriver_disable_notification_channel] + disable_notification_channel = StackdriverDisableNotificationChannelsOperator( + task_id="disable-notification-channel", filter_='displayName="channel1"' + ) + # [END howto_operator_gcp_stackdriver_disable_notification_channel] + + # [START howto_operator_gcp_stackdriver_list_notification_channel] + list_notification_channel = StackdriverListNotificationChannelsOperator( + task_id="list-notification-channel", filter_='type="slack"' + ) + # [END howto_operator_gcp_stackdriver_list_notification_channel] + + # [START howto_operator_gcp_stackdriver_upsert_alert_policy] + create_alert_policy = StackdriverUpsertAlertOperator( + task_id="create-alert-policies", + alerts=json.dumps({"policies": [TEST_ALERT_POLICY_1, TEST_ALERT_POLICY_2]}), + ) + # [END howto_operator_gcp_stackdriver_upsert_alert_policy] + + # [START howto_operator_gcp_stackdriver_enable_alert_policy] + enable_alert_policy = StackdriverEnableAlertPoliciesOperator( + task_id="enable-alert-policies", + filter_='(displayName="test alert 1" OR displayName="test alert 2")', + ) + # [END howto_operator_gcp_stackdriver_enable_alert_policy] + + # [START howto_operator_gcp_stackdriver_disable_alert_policy] + disable_alert_policy = StackdriverDisableAlertPoliciesOperator( + task_id="disable-alert-policies", + filter_='displayName="test alert 1"', + ) + # [END howto_operator_gcp_stackdriver_disable_alert_policy] + + # [START howto_operator_gcp_stackdriver_list_alert_policy] + list_alert_policies = StackdriverListAlertPoliciesOperator( + task_id="list-alert-policies", + ) + # [END howto_operator_gcp_stackdriver_list_alert_policy] + + # [START howto_operator_gcp_stackdriver_delete_notification_channel] + delete_notification_channel = StackdriverDeleteNotificationChannelOperator( + task_id="delete-notification-channel", + name="{{ task_instance.xcom_pull('list-notification-channel')[0]['name'] }}", + ) + # [END howto_operator_gcp_stackdriver_delete_notification_channel] + + delete_notification_channel_2 = StackdriverDeleteNotificationChannelOperator( + task_id="delete-notification-channel-2", + name="{{ task_instance.xcom_pull('list-notification-channel')[1]['name'] }}", + ) + + # [START howto_operator_gcp_stackdriver_delete_alert_policy] + delete_alert_policy = StackdriverDeleteAlertOperator( + task_id="delete-alert-policy", + name="{{ task_instance.xcom_pull('list-alert-policies')[0]['name'] }}", + ) + # [END howto_operator_gcp_stackdriver_delete_alert_policy] + + delete_alert_policy_2 = StackdriverDeleteAlertOperator( + task_id="delete-alert-policy-2", + name="{{ task_instance.xcom_pull('list-alert-policies')[1]['name'] }}", + ) + + create_notification_channel >> enable_notification_channel >> disable_notification_channel + disable_notification_channel >> list_notification_channel >> create_alert_policy + create_alert_policy >> enable_alert_policy >> disable_alert_policy >> list_alert_policies + list_alert_policies >> 
delete_notification_channel >> delete_notification_channel_2 + delete_notification_channel_2 >> delete_alert_policy >> delete_alert_policy_2 diff --git a/reference/providers/google/cloud/example_dags/example_tasks.py b/reference/providers/google/cloud/example_dags/example_tasks.py new file mode 100644 index 0000000..fe69ee2 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_tasks.py @@ -0,0 +1,187 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that creates, gets, lists, updates, purges, pauses, resumes +and deletes Queues and creates, gets, lists, runs and deletes Tasks in the Google +Cloud Tasks service in the Google Cloud. +""" +import os +from datetime import datetime, timedelta + +from airflow import models +from airflow.models.baseoperator import chain +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.tasks import ( + CloudTasksQueueCreateOperator, + CloudTasksQueueDeleteOperator, + CloudTasksQueueGetOperator, + CloudTasksQueuePauseOperator, + CloudTasksQueuePurgeOperator, + CloudTasksQueueResumeOperator, + CloudTasksQueuesListOperator, + CloudTasksQueueUpdateOperator, + CloudTasksTaskCreateOperator, + CloudTasksTaskDeleteOperator, + CloudTasksTaskGetOperator, + CloudTasksTaskRunOperator, + CloudTasksTasksListOperator, +) +from airflow.utils.dates import days_ago +from google.api_core.retry import Retry +from google.cloud.tasks_v2.types import Queue +from google.protobuf import timestamp_pb2 + +timestamp = timestamp_pb2.Timestamp() +timestamp.FromDatetime( + datetime.now() + timedelta(hours=12) +) # pylint: disable=no-member + +LOCATION = "europe-west1" +QUEUE_ID = os.environ.get("GCP_TASKS_QUEUE_ID", "cloud-tasks-queue") +TASK_NAME = "task-to-run" + + +TASK = { + "app_engine_http_request": { # Specify the type of request. 
+ "http_method": "POST", + "relative_uri": "/example_task_handler", + "body": b"Hello", + }, + "schedule_time": timestamp, +} + +with models.DAG( + "example_gcp_tasks", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + + # Queue operations + create_queue = CloudTasksQueueCreateOperator( + location=LOCATION, + task_queue=Queue(stackdriver_logging_config=dict(sampling_ratio=0.5)), + queue_name=QUEUE_ID, + retry=Retry(maximum=10.0), + timeout=5, + task_id="create_queue", + ) + + delete_queue = CloudTasksQueueDeleteOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task_id="delete_queue", + ) + + resume_queue = CloudTasksQueueResumeOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task_id="resume_queue", + ) + + pause_queue = CloudTasksQueuePauseOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task_id="pause_queue", + ) + + purge_queue = CloudTasksQueuePurgeOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task_id="purge_queue", + ) + + get_queue = CloudTasksQueueGetOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task_id="get_queue", + ) + + get_queue_result = BashOperator( + task_id="get_queue_result", + bash_command="echo \"{{ task_instance.xcom_pull('get_queue') }}\"", + ) + get_queue >> get_queue_result + + update_queue = CloudTasksQueueUpdateOperator( + task_queue=Queue(stackdriver_logging_config=dict(sampling_ratio=1)), + location=LOCATION, + queue_name=QUEUE_ID, + update_mask={"paths": ["stackdriver_logging_config.sampling_ratio"]}, + task_id="update_queue", + ) + + list_queue = CloudTasksQueuesListOperator(location=LOCATION, task_id="list_queue") + + chain( + create_queue, + update_queue, + pause_queue, + resume_queue, + purge_queue, + get_queue, + list_queue, + delete_queue, + ) + + # Tasks operations + create_task = CloudTasksTaskCreateOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task=TASK, + task_name=TASK_NAME, + retry=Retry(maximum=10.0), + timeout=5, + task_id="create_task_to_run", + ) + + tasks_get = CloudTasksTaskGetOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task_name=TASK_NAME, + task_id="tasks_get", + ) + + run_task = CloudTasksTaskRunOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task_name=TASK_NAME, + task_id="run_task", + ) + + list_tasks = CloudTasksTasksListOperator( + location=LOCATION, queue_name=QUEUE_ID, task_id="list_tasks" + ) + + delete_task = CloudTasksTaskDeleteOperator( + location=LOCATION, + queue_name=QUEUE_ID, + task_name=TASK_NAME, + task_id="delete_task", + ) + + chain( + purge_queue, + create_task, + tasks_get, + list_tasks, + run_task, + delete_task, + delete_queue, + ) diff --git a/reference/providers/google/cloud/example_dags/example_text_to_speech.py b/reference/providers/google/cloud/example_dags/example_text_to_speech.py new file mode 100644 index 0000000..d322194 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_text_to_speech.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow import models +from airflow.providers.google.cloud.operators.text_to_speech import ( + CloudTextToSpeechSynthesizeOperator, +) +from airflow.utils import dates + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +BUCKET_NAME = os.environ.get( + "GCP_TEXT_TO_SPEECH_BUCKET", "gcp-text-to-speech-test-bucket" +) + +# [START howto_operator_text_to_speech_gcp_filename] +FILENAME = "gcp-speech-test-file" +# [END howto_operator_text_to_speech_gcp_filename] + +# [START howto_operator_text_to_speech_api_arguments] +INPUT = {"text": "Sample text for demo purposes"} +VOICE = {"language_code": "en-US", "ssml_gender": "FEMALE"} +AUDIO_CONFIG = {"audio_encoding": "LINEAR16"} +# [END howto_operator_text_to_speech_api_arguments] + +with models.DAG( + "example_gcp_text_to_speech", + start_date=dates.days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag: + + # [START howto_operator_text_to_speech_synthesize] + text_to_speech_synthesize_task = CloudTextToSpeechSynthesizeOperator( + project_id=GCP_PROJECT_ID, + input_data=INPUT, + voice=VOICE, + audio_config=AUDIO_CONFIG, + target_bucket_name=BUCKET_NAME, + target_filename=FILENAME, + task_id="text_to_speech_synthesize_task", + ) + text_to_speech_synthesize_task2 = CloudTextToSpeechSynthesizeOperator( + input_data=INPUT, + voice=VOICE, + audio_config=AUDIO_CONFIG, + target_bucket_name=BUCKET_NAME, + target_filename=FILENAME, + task_id="text_to_speech_synthesize_task2", + ) + # [END howto_operator_text_to_speech_synthesize] + + text_to_speech_synthesize_task >> text_to_speech_synthesize_task2 diff --git a/reference/providers/google/cloud/example_dags/example_translate.py b/reference/providers/google/cloud/example_dags/example_translate.py new file mode 100644 index 0000000..d269d76 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_translate.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that translates text in Google Cloud Translate +service in the Google Cloud. 
+ +""" + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.translate import ( + CloudTranslateTextOperator, +) +from airflow.utils.dates import days_ago + +with models.DAG( + "example_gcp_translate", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_translate_text] + product_set_create = CloudTranslateTextOperator( + task_id="translate", + values=["zażółć gęślą jaźń"], + target_language="en", + format_="text", + source_language=None, + model="base", + ) + # [END howto_operator_translate_text] + # [START howto_operator_translate_access] + translation_access = BashOperator( + task_id="access", + bash_command="echo '{{ task_instance.xcom_pull(\"translate\")[0] }}'", + ) + product_set_create >> translation_access + # [END howto_operator_translate_access] diff --git a/reference/providers/google/cloud/example_dags/example_translate_speech.py b/reference/providers/google/cloud/example_dags/example_translate_speech.py new file mode 100644 index 0000000..1d56475 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_translate_speech.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os + +from airflow import models +from airflow.providers.google.cloud.operators.text_to_speech import ( + CloudTextToSpeechSynthesizeOperator, +) +from airflow.providers.google.cloud.operators.translate_speech import ( + CloudTranslateSpeechOperator, +) +from airflow.utils import dates + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +BUCKET_NAME = os.environ.get( + "GCP_TRANSLATE_SPEECH_TEST_BUCKET", "gcp-translate-speech-test-bucket" +) + +# [START howto_operator_translate_speech_gcp_filename] +FILENAME = "gcp-speech-test-file" +# [END howto_operator_translate_speech_gcp_filename] + +# [START howto_operator_text_to_speech_api_arguments] +INPUT = {"text": "Sample text for demo purposes"} +VOICE = {"language_code": "en-US", "ssml_gender": "FEMALE"} +AUDIO_CONFIG = {"audio_encoding": "LINEAR16"} +# [END howto_operator_text_to_speech_api_arguments] + +# [START howto_operator_translate_speech_arguments] +CONFIG = {"encoding": "LINEAR16", "language_code": "en_US"} +AUDIO = {"uri": f"gs://{BUCKET_NAME}/{FILENAME}"} +TARGET_LANGUAGE = "pl" +FORMAT = "text" +MODEL = "base" +SOURCE_LANGUAGE = None # type: None +# [END howto_operator_translate_speech_arguments] + + +with models.DAG( + "example_gcp_translate_speech", + schedule_interval=None, # Override to match your needs + start_date=dates.days_ago(1), + tags=["example"], +) as dag: + text_to_speech_synthesize_task = CloudTextToSpeechSynthesizeOperator( + project_id=GCP_PROJECT_ID, + input_data=INPUT, + voice=VOICE, + audio_config=AUDIO_CONFIG, + target_bucket_name=BUCKET_NAME, + target_filename=FILENAME, + task_id="text_to_speech_synthesize_task", + ) + # [START howto_operator_translate_speech] + translate_speech_task = CloudTranslateSpeechOperator( + project_id=GCP_PROJECT_ID, + audio=AUDIO, + config=CONFIG, + target_language=TARGET_LANGUAGE, + format_=FORMAT, + source_language=SOURCE_LANGUAGE, + model=MODEL, + task_id="translate_speech_task", + ) + translate_speech_task2 = CloudTranslateSpeechOperator( + audio=AUDIO, + config=CONFIG, + target_language=TARGET_LANGUAGE, + format_=FORMAT, + source_language=SOURCE_LANGUAGE, + model=MODEL, + task_id="translate_speech_task2", + ) + # [END howto_operator_translate_speech] + text_to_speech_synthesize_task >> translate_speech_task >> translate_speech_task2 diff --git a/reference/providers/google/cloud/example_dags/example_video_intelligence.py b/reference/providers/google/cloud/example_dags/example_video_intelligence.py new file mode 100644 index 0000000..6b95056 --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_video_intelligence.py @@ -0,0 +1,118 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
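The `CONFIG` and `AUDIO` dictionaries in the translate-speech example above are dict renderings of Speech-to-Text request messages. A minimal sketch of the same values built as explicit client objects, assuming the google-cloud-speech 2.x module layout; whether the operator accepts these objects in place of the dicts is an assumption here, not something the diff states.

from google.cloud.speech_v1 import RecognitionAudio, RecognitionConfig

# Mirrors CONFIG = {"encoding": "LINEAR16", "language_code": "en_US"}
speech_config = RecognitionConfig(
    encoding=RecognitionConfig.AudioEncoding.LINEAR16,
    language_code="en_US",
)
# Mirrors AUDIO = {"uri": f"gs://{BUCKET_NAME}/{FILENAME}"}, using the default
# bucket and filename values shown in the diff.
speech_audio = RecognitionAudio(
    uri="gs://gcp-translate-speech-test-bucket/gcp-speech-test-file"
)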
+ +""" +Example Airflow DAG that demonstrates operators for the Google Cloud Video Intelligence service in the Google +Cloud Platform. + +This DAG relies on the following OS environment variables: + +* GCP_BUCKET_NAME - Google Cloud Storage bucket where the file exists. +""" +import os + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.video_intelligence import ( + CloudVideoIntelligenceDetectVideoExplicitContentOperator, + CloudVideoIntelligenceDetectVideoLabelsOperator, + CloudVideoIntelligenceDetectVideoShotsOperator, +) +from airflow.utils.dates import days_ago +from google.api_core.retry import Retry + +# [START howto_operator_video_intelligence_os_args] +GCP_BUCKET_NAME = os.environ.get( + "GCP_VIDEO_INTELLIGENCE_BUCKET_NAME", "test-bucket-name" +) +# [END howto_operator_video_intelligence_os_args] + + +# [START howto_operator_video_intelligence_other_args] +INPUT_URI = f"gs://{GCP_BUCKET_NAME}/video.mp4" +# [END howto_operator_video_intelligence_other_args] + + +with models.DAG( + "example_gcp_video_intelligence", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + + # [START howto_operator_video_intelligence_detect_labels] + detect_video_label = CloudVideoIntelligenceDetectVideoLabelsOperator( + input_uri=INPUT_URI, + output_uri=None, + video_context=None, + timeout=5, + task_id="detect_video_label", + ) + # [END howto_operator_video_intelligence_detect_labels] + + # [START howto_operator_video_intelligence_detect_labels_result] + detect_video_label_result = BashOperator( + bash_command="echo {{ task_instance.xcom_pull('detect_video_label')" + "['annotationResults'][0]['shotLabelAnnotations'][0]['entity']}}", + task_id="detect_video_label_result", + ) + # [END howto_operator_video_intelligence_detect_labels_result] + + # [START howto_operator_video_intelligence_detect_explicit_content] + detect_video_explicit_content = ( + CloudVideoIntelligenceDetectVideoExplicitContentOperator( + input_uri=INPUT_URI, + output_uri=None, + video_context=None, + retry=Retry(maximum=10.0), + timeout=5, + task_id="detect_video_explicit_content", + ) + ) + # [END howto_operator_video_intelligence_detect_explicit_content] + + # [START howto_operator_video_intelligence_detect_explicit_content_result] + detect_video_explicit_content_result = BashOperator( + bash_command="echo {{ task_instance.xcom_pull('detect_video_explicit_content')" + "['annotationResults'][0]['explicitAnnotation']['frames'][0]}}", + task_id="detect_video_explicit_content_result", + ) + # [END howto_operator_video_intelligence_detect_explicit_content_result] + + # [START howto_operator_video_intelligence_detect_video_shots] + detect_video_shots = CloudVideoIntelligenceDetectVideoShotsOperator( + input_uri=INPUT_URI, + output_uri=None, + video_context=None, + retry=Retry(maximum=10.0), + timeout=5, + task_id="detect_video_shots", + ) + # [END howto_operator_video_intelligence_detect_video_shots] + + # [START howto_operator_video_intelligence_detect_video_shots_result] + detect_video_shots_result = BashOperator( + bash_command="echo {{ task_instance.xcom_pull('detect_video_shots')" + "['annotationResults'][0]['shotAnnotations'][0]}}", + task_id="detect_video_shots_result", + ) + # [END howto_operator_video_intelligence_detect_video_shots_result] + + detect_video_label >> detect_video_label_result + detect_video_explicit_content >> detect_video_explicit_content_result + detect_video_shots >> 
detect_video_shots_result diff --git a/reference/providers/google/cloud/example_dags/example_vision.py b/reference/providers/google/cloud/example_dags/example_vision.py new file mode 100644 index 0000000..a53e21b --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_vision.py @@ -0,0 +1,535 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that creates, gets, updates and deletes Products and Product Sets in the Google Cloud +Vision service. + +This DAG relies on the following OS environment variables + +* GCP_VISION_LOCATION - Zone where the instance exists. +* GCP_VISION_PRODUCT_SET_ID - Product Set ID. +* GCP_VISION_PRODUCT_ID - Product ID. +* GCP_VISION_REFERENCE_IMAGE_ID - Reference Image ID. +* GCP_VISION_REFERENCE_IMAGE_URL - A link to the bucket that contains the reference image. +* GCP_VISION_ANNOTATE_IMAGE_URL - A link to the bucket that contains the file to be annotated. + +""" + +import os + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.operators.vision import ( + CloudVisionAddProductToProductSetOperator, + CloudVisionCreateProductOperator, + CloudVisionCreateProductSetOperator, + CloudVisionCreateReferenceImageOperator, + CloudVisionDeleteProductOperator, + CloudVisionDeleteProductSetOperator, + CloudVisionDeleteReferenceImageOperator, + CloudVisionDetectImageLabelsOperator, + CloudVisionDetectImageSafeSearchOperator, + CloudVisionDetectTextOperator, + CloudVisionGetProductOperator, + CloudVisionGetProductSetOperator, + CloudVisionImageAnnotateOperator, + CloudVisionRemoveProductFromProductSetOperator, + CloudVisionTextDetectOperator, + CloudVisionUpdateProductOperator, + CloudVisionUpdateProductSetOperator, +) +from airflow.utils.dates import days_ago + +# [END howto_operator_vision_product_import] +# [START howto_operator_vision_reference_image_import] +# [END howto_operator_vision_product_set_import] +# [START howto_operator_vision_product_import] +# [END howto_operator_vision_retry_import] +# [START howto_operator_vision_product_set_import] +from google.cloud.vision_v1.types import ( # isort:skip pylint: disable=wrong-import-order + Product, + ProductSet, + ReferenceImage, +) + +# [START howto_operator_vision_retry_import] +from google.api_core.retry import Retry # isort:skip pylint: disable=wrong-import-order + + + + +# [END howto_operator_vision_reference_image_import] +# [START howto_operator_vision_enums_import] +from google.cloud.vision import enums # isort:skip pylint: disable=wrong-import-order + +# [END howto_operator_vision_enums_import] + + +GCP_VISION_LOCATION = os.environ.get("GCP_VISION_LOCATION", "europe-west1") + +GCP_VISION_PRODUCT_SET_ID = os.environ.get( + "GCP_VISION_PRODUCT_SET_ID", 
"product_set_explicit_id" +) +GCP_VISION_PRODUCT_ID = os.environ.get("GCP_VISION_PRODUCT_ID", "product_explicit_id") +GCP_VISION_REFERENCE_IMAGE_ID = os.environ.get( + "GCP_VISION_REFERENCE_IMAGE_ID", "reference_image_explicit_id" +) +GCP_VISION_REFERENCE_IMAGE_URL = os.environ.get( + "GCP_VISION_REFERENCE_IMAGE_URL", "gs://bucket/image1.jpg" +) +GCP_VISION_ANNOTATE_IMAGE_URL = os.environ.get( + "GCP_VISION_ANNOTATE_IMAGE_URL", "gs://bucket/image2.jpg" +) + +# [START howto_operator_vision_product_set] +product_set = ProductSet(display_name="My Product Set") +# [END howto_operator_vision_product_set] + +# [START howto_operator_vision_product] +product = Product(display_name="My Product 1", product_category="toys") +# [END howto_operator_vision_product] + +# [START howto_operator_vision_reference_image] +reference_image = ReferenceImage(uri=GCP_VISION_REFERENCE_IMAGE_URL) +# [END howto_operator_vision_reference_image] + +# [START howto_operator_vision_annotate_image_request] +annotate_image_request = { + "image": {"source": {"image_uri": GCP_VISION_ANNOTATE_IMAGE_URL}}, + "features": [{"type": enums.Feature.Type.LOGO_DETECTION}], +} +# [END howto_operator_vision_annotate_image_request] + +# [START howto_operator_vision_detect_image_param] +DETECT_IMAGE = {"source": {"image_uri": GCP_VISION_ANNOTATE_IMAGE_URL}} +# [END howto_operator_vision_detect_image_param] + +with models.DAG( + "example_gcp_vision_autogenerated_id", + start_date=days_ago(1), + schedule_interval=None, +) as dag_autogenerated_id: + # ################################## # + # ### Autogenerated IDs examples ### # + # ################################## # + + # [START howto_operator_vision_product_set_create] + product_set_create = CloudVisionCreateProductSetOperator( + location=GCP_VISION_LOCATION, + product_set=product_set, + retry=Retry(maximum=10.0), + timeout=5, + task_id="product_set_create", + ) + # [END howto_operator_vision_product_set_create] + + # [START howto_operator_vision_product_set_get] + product_set_get = CloudVisionGetProductSetOperator( + location=GCP_VISION_LOCATION, + product_set_id="{{ task_instance.xcom_pull('product_set_create') }}", + task_id="product_set_get", + ) + # [END howto_operator_vision_product_set_get] + + # [START howto_operator_vision_product_set_update] + product_set_update = CloudVisionUpdateProductSetOperator( + location=GCP_VISION_LOCATION, + product_set_id="{{ task_instance.xcom_pull('product_set_create') }}", + product_set=ProductSet(display_name="My Product Set 2"), + task_id="product_set_update", + ) + # [END howto_operator_vision_product_set_update] + + # [START howto_operator_vision_product_set_delete] + product_set_delete = CloudVisionDeleteProductSetOperator( + location=GCP_VISION_LOCATION, + product_set_id="{{ task_instance.xcom_pull('product_set_create') }}", + task_id="product_set_delete", + ) + # [END howto_operator_vision_product_set_delete] + + # [START howto_operator_vision_product_create] + product_create = CloudVisionCreateProductOperator( + location=GCP_VISION_LOCATION, + product=product, + retry=Retry(maximum=10.0), + timeout=5, + task_id="product_create", + ) + # [END howto_operator_vision_product_create] + + # [START howto_operator_vision_product_get] + product_get = CloudVisionGetProductOperator( + location=GCP_VISION_LOCATION, + product_id="{{ task_instance.xcom_pull('product_create') }}", + task_id="product_get", + ) + # [END howto_operator_vision_product_get] + + # [START howto_operator_vision_product_update] + product_update = CloudVisionUpdateProductOperator( + 
location=GCP_VISION_LOCATION, + product_id="{{ task_instance.xcom_pull('product_create') }}", + product=Product( + display_name="My Product 2", description="My updated description" + ), + task_id="product_update", + ) + # [END howto_operator_vision_product_update] + + # [START howto_operator_vision_product_delete] + product_delete = CloudVisionDeleteProductOperator( + location=GCP_VISION_LOCATION, + product_id="{{ task_instance.xcom_pull('product_create') }}", + task_id="product_delete", + ) + # [END howto_operator_vision_product_delete] + + # [START howto_operator_vision_reference_image_create] + reference_image_create = CloudVisionCreateReferenceImageOperator( + location=GCP_VISION_LOCATION, + reference_image=reference_image, + product_id="{{ task_instance.xcom_pull('product_create') }}", + reference_image_id=GCP_VISION_REFERENCE_IMAGE_ID, + retry=Retry(maximum=10.0), + timeout=5, + task_id="reference_image_create", + ) + # [END howto_operator_vision_reference_image_create] + + # [START howto_operator_vision_reference_image_delete] + reference_image_delete = CloudVisionDeleteReferenceImageOperator( + location=GCP_VISION_LOCATION, + product_id="{{ task_instance.xcom_pull('product_create') }}", + reference_image_id=GCP_VISION_REFERENCE_IMAGE_ID, + retry=Retry(maximum=10.0), + timeout=5, + task_id="reference_image_delete", + ) + # [END howto_operator_vision_reference_image_delete] + + # [START howto_operator_vision_add_product_to_product_set] + add_product_to_product_set = CloudVisionAddProductToProductSetOperator( + location=GCP_VISION_LOCATION, + product_set_id="{{ task_instance.xcom_pull('product_set_create') }}", + product_id="{{ task_instance.xcom_pull('product_create') }}", + retry=Retry(maximum=10.0), + timeout=5, + task_id="add_product_to_product_set", + ) + # [END howto_operator_vision_add_product_to_product_set] + + # [START howto_operator_vision_remove_product_from_product_set] + remove_product_from_product_set = CloudVisionRemoveProductFromProductSetOperator( + location=GCP_VISION_LOCATION, + product_set_id="{{ task_instance.xcom_pull('product_set_create') }}", + product_id="{{ task_instance.xcom_pull('product_create') }}", + retry=Retry(maximum=10.0), + timeout=5, + task_id="remove_product_from_product_set", + ) + # [END howto_operator_vision_remove_product_from_product_set] + + # Product path + product_create >> product_get >> product_update >> product_delete + + # ProductSet path + product_set_create >> product_set_get >> product_set_update >> product_set_delete + + # ReferenceImage path + product_create >> reference_image_create >> reference_image_delete >> product_delete + + # Product/ProductSet path + product_create >> add_product_to_product_set + product_set_create >> add_product_to_product_set + add_product_to_product_set >> remove_product_from_product_set + remove_product_from_product_set >> product_delete + remove_product_from_product_set >> product_set_delete + +with models.DAG( + "example_gcp_vision_explicit_id", start_date=days_ago(1), schedule_interval=None +) as dag_explicit_id: + # ############################# # + # ### Explicit IDs examples ### # + # ############################# # + + # [START howto_operator_vision_product_set_create_2] + product_set_create_2 = CloudVisionCreateProductSetOperator( + product_set_id=GCP_VISION_PRODUCT_SET_ID, + location=GCP_VISION_LOCATION, + product_set=product_set, + retry=Retry(maximum=10.0), + timeout=5, + task_id="product_set_create_2", + ) + # [END howto_operator_vision_product_set_create_2] + + # Second 'create' task with 
the same product_set_id to demonstrate idempotence + product_set_create_2_idempotence = CloudVisionCreateProductSetOperator( + product_set_id=GCP_VISION_PRODUCT_SET_ID, + location=GCP_VISION_LOCATION, + product_set=product_set, + retry=Retry(maximum=10.0), + timeout=5, + task_id="product_set_create_2_idempotence", + ) + + # [START howto_operator_vision_product_set_get_2] + product_set_get_2 = CloudVisionGetProductSetOperator( + location=GCP_VISION_LOCATION, + product_set_id=GCP_VISION_PRODUCT_SET_ID, + task_id="product_set_get_2", + ) + # [END howto_operator_vision_product_set_get_2] + + # [START howto_operator_vision_product_set_update_2] + product_set_update_2 = CloudVisionUpdateProductSetOperator( + location=GCP_VISION_LOCATION, + product_set_id=GCP_VISION_PRODUCT_SET_ID, + product_set=ProductSet(display_name="My Product Set 2"), + task_id="product_set_update_2", + ) + # [END howto_operator_vision_product_set_update_2] + + # [START howto_operator_vision_product_set_delete_2] + product_set_delete_2 = CloudVisionDeleteProductSetOperator( + location=GCP_VISION_LOCATION, + product_set_id=GCP_VISION_PRODUCT_SET_ID, + task_id="product_set_delete_2", + ) + # [END howto_operator_vision_product_set_delete_2] + + # [START howto_operator_vision_product_create_2] + product_create_2 = CloudVisionCreateProductOperator( + product_id=GCP_VISION_PRODUCT_ID, + location=GCP_VISION_LOCATION, + product=product, + retry=Retry(maximum=10.0), + timeout=5, + task_id="product_create_2", + ) + # [END howto_operator_vision_product_create_2] + + # Second 'create' task with the same product_id to demonstrate idempotence + product_create_2_idempotence = CloudVisionCreateProductOperator( + product_id=GCP_VISION_PRODUCT_ID, + location=GCP_VISION_LOCATION, + product=product, + retry=Retry(maximum=10.0), + timeout=5, + task_id="product_create_2_idempotence", + ) + + # [START howto_operator_vision_product_get_2] + product_get_2 = CloudVisionGetProductOperator( + location=GCP_VISION_LOCATION, + product_id=GCP_VISION_PRODUCT_ID, + task_id="product_get_2", + ) + # [END howto_operator_vision_product_get_2] + + # [START howto_operator_vision_product_update_2] + product_update_2 = CloudVisionUpdateProductOperator( + location=GCP_VISION_LOCATION, + product_id=GCP_VISION_PRODUCT_ID, + product=Product( + display_name="My Product 2", description="My updated description" + ), + task_id="product_update_2", + ) + # [END howto_operator_vision_product_update_2] + + # [START howto_operator_vision_product_delete_2] + product_delete_2 = CloudVisionDeleteProductOperator( + location=GCP_VISION_LOCATION, + product_id=GCP_VISION_PRODUCT_ID, + task_id="product_delete_2", + ) + # [END howto_operator_vision_product_delete_2] + + # [START howto_operator_vision_reference_image_create_2] + reference_image_create_2 = CloudVisionCreateReferenceImageOperator( + location=GCP_VISION_LOCATION, + reference_image=reference_image, + product_id=GCP_VISION_PRODUCT_ID, + reference_image_id=GCP_VISION_REFERENCE_IMAGE_ID, + retry=Retry(maximum=10.0), + timeout=5, + task_id="reference_image_create_2", + ) + # [END howto_operator_vision_reference_image_create_2] + + # [START howto_operator_vision_reference_image_delete_2] + reference_image_delete_2 = CloudVisionDeleteReferenceImageOperator( + location=GCP_VISION_LOCATION, + reference_image_id=GCP_VISION_REFERENCE_IMAGE_ID, + product_id=GCP_VISION_PRODUCT_ID, + retry=Retry(maximum=10.0), + timeout=5, + task_id="reference_image_delete_2", + ) + # [END howto_operator_vision_reference_image_delete_2] + + # Second 
'create' task with the same product_id to demonstrate idempotence + reference_image_create_2_idempotence = CloudVisionCreateReferenceImageOperator( + location=GCP_VISION_LOCATION, + reference_image=reference_image, + product_id=GCP_VISION_PRODUCT_ID, + reference_image_id=GCP_VISION_REFERENCE_IMAGE_ID, + retry=Retry(maximum=10.0), + timeout=5, + task_id="reference_image_create_2_idempotence", + ) + + # [START howto_operator_vision_add_product_to_product_set_2] + add_product_to_product_set_2 = CloudVisionAddProductToProductSetOperator( + location=GCP_VISION_LOCATION, + product_set_id=GCP_VISION_PRODUCT_SET_ID, + product_id=GCP_VISION_PRODUCT_ID, + retry=Retry(maximum=10.0), + timeout=5, + task_id="add_product_to_product_set_2", + ) + # [END howto_operator_vision_add_product_to_product_set_2] + + # [START howto_operator_vision_remove_product_from_product_set_2] + remove_product_from_product_set_2 = CloudVisionRemoveProductFromProductSetOperator( + location=GCP_VISION_LOCATION, + product_set_id=GCP_VISION_PRODUCT_SET_ID, + product_id=GCP_VISION_PRODUCT_ID, + retry=Retry(maximum=10.0), + timeout=5, + task_id="remove_product_from_product_set_2", + ) + # [END howto_operator_vision_remove_product_from_product_set_2] + + # Product path + product_create_2 >> product_create_2_idempotence >> product_get_2 >> product_update_2 >> product_delete_2 + + # ProductSet path + product_set_create_2 >> product_set_get_2 >> product_set_update_2 >> product_set_delete_2 + product_set_create_2 >> product_set_create_2_idempotence >> product_set_delete_2 + + # ReferenceImage path + product_create_2 >> reference_image_create_2 >> reference_image_create_2_idempotence + reference_image_create_2_idempotence >> reference_image_delete_2 >> product_delete_2 + + # Product/ProductSet path + add_product_to_product_set_2 >> remove_product_from_product_set_2 + product_set_create_2 >> add_product_to_product_set_2 + product_create_2 >> add_product_to_product_set_2 + remove_product_from_product_set_2 >> product_set_delete_2 + remove_product_from_product_set_2 >> product_delete_2 + +with models.DAG( + "example_gcp_vision_annotate_image", start_date=days_ago(1), schedule_interval=None +) as dag_annotate_image: + # ############################## # + # ### Annotate image example ### # + # ############################## # + + # [START howto_operator_vision_annotate_image] + annotate_image = CloudVisionImageAnnotateOperator( + request=annotate_image_request, + retry=Retry(maximum=10.0), + timeout=5, + task_id="annotate_image", + ) + # [END howto_operator_vision_annotate_image] + + # [START howto_operator_vision_annotate_image_result] + annotate_image_result = BashOperator( + bash_command="echo {{ task_instance.xcom_pull('annotate_image')" + "['logoAnnotations'][0]['description'] }}", + task_id="annotate_image_result", + ) + # [END howto_operator_vision_annotate_image_result] + + # [START howto_operator_vision_detect_text] + detect_text = CloudVisionDetectTextOperator( + image=DETECT_IMAGE, + retry=Retry(maximum=10.0), + timeout=5, + task_id="detect_text", + language_hints="en", + web_detection_params={"include_geo_results": True}, + ) + # [END howto_operator_vision_detect_text] + + # [START howto_operator_vision_detect_text_result] + detect_text_result = BashOperator( + bash_command="echo {{ task_instance.xcom_pull('detect_text')['textAnnotations'][0] }}", + task_id="detect_text_result", + ) + # [END howto_operator_vision_detect_text_result] + + # [START howto_operator_vision_document_detect_text] + document_detect_text = 
CloudVisionTextDetectOperator( + image=DETECT_IMAGE, + retry=Retry(maximum=10.0), + timeout=5, + task_id="document_detect_text", + ) + # [END howto_operator_vision_document_detect_text] + + # [START howto_operator_vision_document_detect_text_result] + document_detect_text_result = BashOperator( + bash_command="echo {{ task_instance.xcom_pull('document_detect_text')['textAnnotations'][0] }}", + task_id="document_detect_text_result", + ) + # [END howto_operator_vision_document_detect_text_result] + + # [START howto_operator_vision_detect_labels] + detect_labels = CloudVisionDetectImageLabelsOperator( + image=DETECT_IMAGE, + retry=Retry(maximum=10.0), + timeout=5, + task_id="detect_labels", + ) + # [END howto_operator_vision_detect_labels] + + # [START howto_operator_vision_detect_labels_result] + detect_labels_result = BashOperator( + bash_command="echo {{ task_instance.xcom_pull('detect_labels')['labelAnnotations'][0] }}", + task_id="detect_labels_result", + ) + # [END howto_operator_vision_detect_labels_result] + + # [START howto_operator_vision_detect_safe_search] + detect_safe_search = CloudVisionDetectImageSafeSearchOperator( + image=DETECT_IMAGE, + retry=Retry(maximum=10.0), + timeout=5, + task_id="detect_safe_search", + ) + # [END howto_operator_vision_detect_safe_search] + + # [START howto_operator_vision_detect_safe_search_result] + detect_safe_search_result = BashOperator( + bash_command="echo {{ task_instance.xcom_pull('detect_safe_search') }}", + task_id="detect_safe_search_result", + ) + # [END howto_operator_vision_detect_safe_search_result] + + annotate_image >> annotate_image_result + + detect_text >> detect_text_result + document_detect_text >> document_detect_text_result + detect_labels >> detect_labels_result + detect_safe_search >> detect_safe_search_result diff --git a/reference/providers/google/cloud/example_dags/example_workflows.py b/reference/providers/google/cloud/example_dags/example_workflows.py new file mode 100644 index 0000000..06353aa --- /dev/null +++ b/reference/providers/google/cloud/example_dags/example_workflows.py @@ -0,0 +1,208 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
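For reference, the logo-detection request defined as `annotate_image_request` in the `example_gcp_vision_annotate_image` DAG above corresponds to a single direct Vision client call. A minimal sketch of that call, assuming the 1.x google-cloud-vision client (consistent with the `enums` import in that file) and application-default credentials; it is illustrative only and not part of the example DAG.

from google.cloud import vision
from google.cloud.vision import enums

client = vision.ImageAnnotatorClient()
response = client.annotate_image(
    {
        "image": {"source": {"image_uri": "gs://bucket/image2.jpg"}},
        "features": [{"type": enums.Feature.Type.LOGO_DETECTION}],
    }
)
# Each logo annotation carries a description, which is the field the
# annotate_image_result task above echoes from XCom.
for logo in response.logo_annotations:
    print(logo.description)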
+ +import os + +from airflow import DAG +from airflow.providers.google.cloud.operators.workflows import ( + WorkflowsCancelExecutionOperator, + WorkflowsCreateExecutionOperator, + WorkflowsCreateWorkflowOperator, + WorkflowsDeleteWorkflowOperator, + WorkflowsGetExecutionOperator, + WorkflowsGetWorkflowOperator, + WorkflowsListExecutionsOperator, + WorkflowsListWorkflowsOperator, + WorkflowsUpdateWorkflowOperator, +) +from airflow.providers.google.cloud.sensors.workflows import WorkflowExecutionSensor +from airflow.utils.dates import days_ago + +LOCATION = os.environ.get("GCP_WORKFLOWS_LOCATION", "us-central1") +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id") + +WORKFLOW_ID = os.getenv("GCP_WORKFLOWS_WORKFLOW_ID", "airflow-test-workflow") + +# [START how_to_define_workflow] +WORKFLOW_CONTENT = """ +- getCurrentTime: + call: http.get + args: + url: https://us-central1-workflowsample.cloudfunctions.net/datetime + result: currentTime +- readWikipedia: + call: http.get + args: + url: https://en.wikipedia.org/w/api.php + query: + action: opensearch + search: ${currentTime.body.dayOfTheWeek} + result: wikiResult +- returnResult: + return: ${wikiResult.body[1]} +""" + +WORKFLOW = { + "description": "Test workflow", + "labels": {"airflow-version": "dev"}, + "source_contents": WORKFLOW_CONTENT, +} +# [END how_to_define_workflow] + +EXECUTION = {"argument": ""} + +SLEEP_WORKFLOW_ID = os.getenv("GCP_WORKFLOWS_SLEEP_WORKFLOW_ID", "sleep_workflow") +SLEEP_WORKFLOW_CONTENT = """ +- someSleep: + call: sys.sleep + args: + seconds: 120 +""" + +SLEEP_WORKFLOW = { + "description": "Test workflow", + "labels": {"airflow-version": "dev"}, + "source_contents": SLEEP_WORKFLOW_CONTENT, +} + + +with DAG( + "example_cloud_workflows", start_date=days_ago(1), schedule_interval=None +) as dag: + # [START how_to_create_workflow] + create_workflow = WorkflowsCreateWorkflowOperator( + task_id="create_workflow", + location=LOCATION, + project_id=PROJECT_ID, + workflow=WORKFLOW, + workflow_id=WORKFLOW_ID, + ) + # [END how_to_create_workflow] + + # [START how_to_update_workflow] + update_workflows = WorkflowsUpdateWorkflowOperator( + task_id="update_workflows", + location=LOCATION, + project_id=PROJECT_ID, + workflow_id=WORKFLOW_ID, + update_mask={"paths": ["name", "description"]}, + ) + # [END how_to_update_workflow] + + # [START how_to_get_workflow] + get_workflow = WorkflowsGetWorkflowOperator( + task_id="get_workflow", + location=LOCATION, + project_id=PROJECT_ID, + workflow_id=WORKFLOW_ID, + ) + # [END how_to_get_workflow] + + # [START how_to_list_workflows] + list_workflows = WorkflowsListWorkflowsOperator( + task_id="list_workflows", + location=LOCATION, + project_id=PROJECT_ID, + ) + # [END how_to_list_workflows] + + # [START how_to_delete_workflow] + delete_workflow = WorkflowsDeleteWorkflowOperator( + task_id="delete_workflow", + location=LOCATION, + project_id=PROJECT_ID, + workflow_id=WORKFLOW_ID, + ) + # [END how_to_delete_workflow] + + # [START how_to_create_execution] + create_execution = WorkflowsCreateExecutionOperator( + task_id="create_execution", + location=LOCATION, + project_id=PROJECT_ID, + execution=EXECUTION, + workflow_id=WORKFLOW_ID, + ) + # [END how_to_create_execution] + + # [START how_to_wait_for_execution] + wait_for_execution = WorkflowExecutionSensor( + task_id="wait_for_execution", + location=LOCATION, + project_id=PROJECT_ID, + workflow_id=WORKFLOW_ID, + execution_id='{{ task_instance.xcom_pull("create_execution", key="execution_id") }}', + ) + # [END how_to_wait_for_execution] + + # 
[START how_to_get_execution] + get_execution = WorkflowsGetExecutionOperator( + task_id="get_execution", + location=LOCATION, + project_id=PROJECT_ID, + workflow_id=WORKFLOW_ID, + execution_id='{{ task_instance.xcom_pull("create_execution", key="execution_id") }}', + ) + # [END how_to_get_execution] + + # [START how_to_list_executions] + list_executions = WorkflowsListExecutionsOperator( + task_id="list_executions", + location=LOCATION, + project_id=PROJECT_ID, + workflow_id=WORKFLOW_ID, + ) + # [END how_to_list_executions] + + create_workflow_for_cancel = WorkflowsCreateWorkflowOperator( + task_id="create_workflow_for_cancel", + location=LOCATION, + project_id=PROJECT_ID, + workflow=SLEEP_WORKFLOW, + workflow_id=SLEEP_WORKFLOW_ID, + ) + + create_execution_for_cancel = WorkflowsCreateExecutionOperator( + task_id="create_execution_for_cancel", + location=LOCATION, + project_id=PROJECT_ID, + execution=EXECUTION, + workflow_id=SLEEP_WORKFLOW_ID, + ) + + # [START how_to_cancel_execution] + cancel_execution = WorkflowsCancelExecutionOperator( + task_id="cancel_execution", + location=LOCATION, + project_id=PROJECT_ID, + workflow_id=SLEEP_WORKFLOW_ID, + execution_id='{{ task_instance.xcom_pull("create_execution_for_cancel", key="execution_id") }}', + ) + # [END how_to_cancel_execution] + + create_workflow >> update_workflows >> [get_workflow, list_workflows] + update_workflows >> [create_execution, create_execution_for_cancel] + + create_execution >> wait_for_execution >> [get_execution, list_executions] + create_workflow_for_cancel >> create_execution_for_cancel >> cancel_execution + + [cancel_execution, list_executions] >> delete_workflow + + +if __name__ == "__main__": + dag.clear(dag_run_state=None) + dag.run() diff --git a/reference/providers/google/cloud/hooks/__init__.py b/reference/providers/google/cloud/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/cloud/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/cloud/hooks/automl.py b/reference/providers/google/cloud/hooks/automl.py new file mode 100644 index 0000000..79d3d23 --- /dev/null +++ b/reference/providers/google/cloud/hooks/automl.py @@ -0,0 +1,717 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains a Google AutoML hook.""" +from typing import Dict, List, Optional, Sequence, Tuple, Union + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.operation import Operation +from google.api_core.retry import Retry +from google.cloud.automl_v1beta1 import ( + AutoMlClient, + BatchPredictInputConfig, + BatchPredictOutputConfig, + ColumnSpec, + Dataset, + ExamplePayload, + ImageObjectDetectionModelDeploymentMetadata, + InputConfig, + Model, + PredictionServiceClient, + PredictResponse, + TableSpec, +) +from google.protobuf.field_mask_pb2 import FieldMask + + +class CloudAutoMLHook(GoogleBaseHook): + """ + Google Cloud AutoML hook. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client = None # type: Optional[AutoMlClient] + + @staticmethod + def extract_object_id(obj: Dict) -> str: + """Returns unique id of the object.""" + return obj["name"].rpartition("/")[-1] + + def get_conn(self) -> AutoMlClient: + """ + Retrieves connection to AutoML. + + :return: Google Cloud AutoML client object. + :rtype: google.cloud.automl_v1beta1.AutoMlClient + """ + if self._client is None: + self._client = AutoMlClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._client + + @cached_property + def prediction_client(self) -> PredictionServiceClient: + """ + Creates PredictionServiceClient. + + :return: Google Cloud AutoML PredictionServiceClient client object. + :rtype: google.cloud.automl_v1beta1.PredictionServiceClient + """ + return PredictionServiceClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_model( + self, + model: Union[dict, Model], + location: str, + project_id: str, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + retry: Optional[Retry] = None, + ) -> Operation: + """ + Creates a model_id. Returns a Model in the `response` field when it + completes. When you create a model, several model evaluations are + created for it: a global evaluation, and one evaluation for each + annotation spec. + + :param model: The model_id to create. If a dict is provided, it must be of the same form + as the protobuf message `google.cloud.automl_v1beta1.types.Model` + :type model: Union[dict, google.cloud.automl_v1beta1.types.Model] + :param project_id: ID of the Google Cloud project where model will be created if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. 
+ :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}" + return client.create_model( + request={"parent": parent, "model": model}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def batch_predict( + self, + model_id: str, + input_config: Union[dict, BatchPredictInputConfig], + output_config: Union[dict, BatchPredictOutputConfig], + location: str, + project_id: str, + params: Optional[Dict[str, str]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Perform a batch prediction. Unlike the online `Predict`, batch + prediction result won't be immediately available in the response. + Instead, a long running operation object is returned. + + :param model_id: Name of the model_id requested to serve the batch prediction. + :type model_id: str + :param input_config: Required. The input configuration for batch prediction. + If a dict is provided, it must be of the same form as the protobuf message + `google.cloud.automl_v1beta1.types.BatchPredictInputConfig` + :type input_config: Union[dict, google.cloud.automl_v1beta1.types.BatchPredictInputConfig] + :param output_config: Required. The Configuration specifying where output predictions should be + written. If a dict is provided, it must be of the same form as the protobuf message + `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig` + :type output_config: Union[dict, google.cloud.automl_v1beta1.types.BatchPredictOutputConfig] + :param params: Additional domain-specific parameters for the predictions, any string must be up to + 25000 characters long. + :type params: Optional[Dict[str, str]] + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance + """ + client = self.prediction_client + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.batch_predict( + request={ + "name": name, + "input_config": input_config, + "output_config": output_config, + "params": params, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def predict( + self, + model_id: str, + payload: Union[dict, ExamplePayload], + location: str, + project_id: str, + params: Optional[Dict[str, str]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> PredictResponse: + """ + Perform an online prediction. The prediction result will be directly + returned in the response. + + :param model_id: Name of the model_id requested to serve the prediction. + :type model_id: str + :param payload: Required. Payload to perform a prediction on. The payload must match the problem type + that the model_id was trained to solve. If a dict is provided, it must be of + the same form as the protobuf message `google.cloud.automl_v1beta1.types.ExamplePayload` + :type payload: Union[dict, google.cloud.automl_v1beta1.types.ExamplePayload] + :param params: Additional domain-specific parameters, any string must be up to 25000 characters long. + :type params: Optional[Dict[str, str]] + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types.PredictResponse` instance + """ + client = self.prediction_client + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.predict( + request={"name": name, "payload": payload, "params": params}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_dataset( + self, + dataset: Union[dict, Dataset], + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Dataset: + """ + Creates a dataset. + + :param dataset: The dataset to create. If a dict is provided, it must be of the + same form as the protobuf message Dataset. + :type dataset: Union[dict, Dataset] + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types.Dataset` instance. + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}" + result = client.create_dataset( + request={"parent": parent, "dataset": dataset}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def import_data( + self, + dataset_id: str, + location: str, + input_config: Union[dict, InputConfig], + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Imports data into a dataset. For Tables this method can only be called on an empty Dataset. + + :param dataset_id: Name of the AutoML dataset. + :type dataset_id: str + :param input_config: The desired input location and its domain specific semantics, if any. + If a dict is provided, it must be of the same form as the protobuf message InputConfig. + :type input_config: Union[dict, InputConfig] + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.import_data( + request={"name": name, "input_config": input_config}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_column_specs( # pylint: disable=too-many-arguments + self, + dataset_id: str, + table_spec_id: str, + location: str, + project_id: str, + field_mask: Union[dict, FieldMask] = None, + filter_: Optional[str] = None, + page_size: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ColumnSpec: + """ + Lists column specs in a table spec. + + :param dataset_id: Name of the AutoML dataset. + :type dataset_id: str + :param table_spec_id: table_spec_id for path builder. + :type table_spec_id: str + :param field_mask: Mask specifying which fields to read. If a dict is provided, it must be of the same + form as the protobuf message `google.cloud.automl_v1beta1.types.FieldMask` + :type field_mask: Union[dict, google.cloud.automl_v1beta1.types.FieldMask] + :param filter_: Filter expression, see go/filtering. + :type filter_: str + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per + resource, this parameter does not affect the return value. 
If page + streaming is performed per-page, this determines the maximum number + of resources in a page. + :type page_size: int + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types.ColumnSpec` instance. + """ + client = self.get_conn() + parent = client.table_spec_path( + project=project_id, + location=location, + dataset=dataset_id, + table_spec=table_spec_id, + ) + result = client.list_column_specs( + request={ + "parent": parent, + "field_mask": field_mask, + "filter": filter_, + "page_size": page_size, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_model( + self, + model_id: str, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Model: + """ + Gets a AutoML model. + + :param model_id: Name of the model. + :type model_id: str + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types.Model` instance. + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.get_model( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_model( + self, + model_id: str, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Model: + """ + Deletes a AutoML model. + + :param model_id: Name of the model. + :type model_id: str + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance. + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.delete_model( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + def update_dataset( + self, + dataset: Union[dict, Dataset], + update_mask: Union[dict, FieldMask] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Dataset: + """ + Updates a dataset. + + :param dataset: The dataset which replaces the resource on the server. + If a dict is provided, it must be of the same form as the protobuf message Dataset. + :type dataset: Union[dict, Dataset] + :param update_mask: The update mask applies to the resource. If a dict is provided, it must + be of the same form as the protobuf message FieldMask. + :type update_mask: Union[dict, FieldMask] + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types.Dataset` instance.. + """ + client = self.get_conn() + result = client.update_dataset( + request={"dataset": dataset, "update_mask": update_mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def deploy_model( + self, + model_id: str, + location: str, + project_id: str, + image_detection_metadata: Union[ + ImageObjectDetectionModelDeploymentMetadata, dict + ] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Deploys a model. If a model is already deployed, deploying it with the same parameters + has no effect. Deploying with different parameters (as e.g. changing node_number) will + reset the deployment state without pausing the model_id’s availability. + + Only applicable for Text Classification, Image Object Detection and Tables; all other + domains manage deployment automatically. + + :param model_id: Name of the model requested to serve the prediction. + :type model_id: str + :param image_detection_metadata: Model deployment metadata specific to Image Object Detection. + If a dict is provided, it must be of the same form as the protobuf message + ImageObjectDetectionModelDeploymentMetadata + :type image_detection_metadata: Union[ImageObjectDetectionModelDeploymentMetadata, dict] + :param project_id: ID of the Google Cloud project where model will be created if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. 
+ :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance. + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.deploy_model( + request={ + "name": name, + "image_object_detection_model_deployment_metadata": image_detection_metadata, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + def list_table_specs( + self, + dataset_id: str, + location: str, + project_id: Optional[str] = None, + filter_: Optional[str] = None, + page_size: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[TableSpec]: + """ + Lists table specs in a dataset_id. + + :param dataset_id: Name of the dataset. + :type dataset_id: str + :param filter_: Filter expression, see go/filtering. + :type filter_: str + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per + resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number + of resources in a page. + :type page_size: int + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: A `google.gax.PageIterator` instance. By default, this + is an iterable of `google.cloud.automl_v1beta1.types.TableSpec` instances. + This object can also be configured to iterate over the pages + of the response through the `options` parameter. + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.list_table_specs( + request={"parent": parent, "filter": filter_, "page_size": page_size}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_datasets( + self, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Dataset: + """ + Lists datasets in a project. + + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. 
+ :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: A `google.gax.PageIterator` instance. By default, this + is an iterable of `google.cloud.automl_v1beta1.types.Dataset` instances. + This object can also be configured to iterate over the pages + of the response through the `options` parameter. + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}" + result = client.list_datasets( + request={"parent": parent}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_dataset( + self, + dataset_id: str, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Deletes a dataset and all of its contents. + + :param dataset_id: ID of dataset to be deleted. + :type dataset_id: str + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + + :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.delete_dataset( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return result diff --git a/reference/providers/google/cloud/hooks/bigquery.py b/reference/providers/google/cloud/hooks/bigquery.py new file mode 100644 index 0000000..415e176 --- /dev/null +++ b/reference/providers/google/cloud/hooks/bigquery.py @@ -0,0 +1,3194 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +This module contains a BigQuery Hook, as well as a very basic PEP 249 +implementation for BigQuery. 
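+
+A minimal usage sketch (the connection id, project id and SQL below are purely
+illustrative assumptions, not defaults enforced by this module)::
+
+    hook = BigQueryHook(gcp_conn_id="google_cloud_default", use_legacy_sql=False)
+    job = hook.insert_job(
+        configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
+        project_id="my-project",
+    )
+    print(job.job_id)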
+""" +import hashlib +import json +import logging +import time +import warnings +from copy import deepcopy +from datetime import datetime, timedelta +from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + NoReturn, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from airflow.exceptions import AirflowException +from airflow.hooks.dbapi import DbApiHook +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.utils.helpers import convert_camel_to_snake +from airflow.utils.log.logging_mixin import LoggingMixin +from google.api_core.retry import Retry +from google.cloud.bigquery import ( + DEFAULT_RETRY, + Client, + CopyJob, + ExternalConfig, + ExtractJob, + LoadJob, + QueryJob, + SchemaField, +) +from google.cloud.bigquery.dataset import ( + AccessEntry, + Dataset, + DatasetListItem, + DatasetReference, +) +from google.cloud.bigquery.table import ( + EncryptionConfiguration, + Row, + Table, + TableReference, +) +from google.cloud.exceptions import NotFound +from googleapiclient.discovery import Resource, build +from pandas import DataFrame +from pandas_gbq import read_gbq +from pandas_gbq.gbq import GbqConnector +from pandas_gbq.gbq import ( + _check_google_client_version as gbq_check_google_client_version, +) +from pandas_gbq.gbq import _test_google_api_imports as gbq_test_google_api_imports + +log = logging.getLogger(__name__) + +BigQueryJob = Union[CopyJob, QueryJob, LoadJob, ExtractJob] + + +# pylint: disable=too-many-public-methods +class BigQueryHook(GoogleBaseHook, DbApiHook): + """Interact with BigQuery. This hook uses the Google Cloud connection.""" + + conn_name_attr = "gcp_conn_id" + default_conn_name = "google_cloud_default" + conn_type = "google_cloud_platform" + hook_name = "Google Cloud" + + def __init__( + self, + gcp_conn_id: str = default_conn_name, + delegate_to: Optional[str] = None, + use_legacy_sql: bool = True, + location: Optional[str] = None, + bigquery_conn_id: Optional[str] = None, + api_resource_configs: Optional[Dict] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + labels: Optional[Dict] = None, + ) -> None: + # To preserve backward compatibility + # TODO: remove one day + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=2, + ) + gcp_conn_id = bigquery_conn_id + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.use_legacy_sql = use_legacy_sql + self.location = location + self.running_job_id = None # type: Optional[str] + self.api_resource_configs = ( + api_resource_configs if api_resource_configs else {} + ) # type Dict + self.labels = labels + + def get_conn(self) -> "BigQueryConnection": + """Returns a BigQuery PEP 249 connection object.""" + service = self.get_service() + return BigQueryConnection( + service=service, + project_id=self.project_id, + use_legacy_sql=self.use_legacy_sql, + location=self.location, + num_retries=self.num_retries, + hook=self, + ) + + def get_service(self) -> Re# + """Returns a BigQuery service object.""" + warnings.warn( + "This method will be deprecated. 
Please use `BigQueryHook.get_client` method", + DeprecationWarning, + ) + http_authorized = self._authorize() + return build("bigquery", "v2", http=http_authorized, cache_discovery=False) + + def get_client( + self, project_id: Optional[str] = None, location: Optional[str] = None + ) -> Client: + """ + Returns authenticated BigQuery Client. + + :param project_id: Project ID for the project which the client acts on behalf of. + :type project_id: str + :param location: Default location for jobs / datasets / tables. + :type location: str + :return: + """ + return Client( + client_info=self.client_info, + project=project_id, + location=location, + credentials=self._get_credentials(), + ) + + @staticmethod + def _resolve_table_reference( + table_re# Dict[str, Any], + project_id: Optional[str] = None, + dataset_id: Optional[str] = None, + table_id: Optional[str] = None, + ) -> Dict[str, Any]: + try: + # Check if tableReference is present and is valid + TableReference.from_api_repr(table_resource["tableReference"]) + except KeyError: + # Something is wrong so we try to build the reference + table_resource["tableReference"] = table_resource.get("tableReference", {}) + values = [ + ("projectId", project_id), + ("tableId", table_id), + ("datasetId", dataset_id), + ] + for key, value in values: + # Check if value is already present if no use the provided one + resolved_value = table_resource["tableReference"].get(key, value) + if not resolved_value: + # If there's no value in tableReference and provided one is None raise error + raise AirflowException( + f"Table resource is missing proper `tableReference` and `{key}` is None" + ) + table_resource["tableReference"][key] = resolved_value + return table_resource + + def insert_rows( + self, + table: Any, + rows: Any, + target_fields: Any = None, + commit_every: Any = 1000, + replace: Any = False, + **kwargs, + ) -> None: + """ + Insertion is currently unsupported. Theoretically, you could use + BigQuery's streaming API to insert rows into a table, but this hasn't + been implemented. + """ + raise NotImplementedError() + + def get_pandas_df( + self, + sql: str, + parameters: Optional[Union[Iterable, Mapping]] = None, + dialect: Optional[str] = None, + **kwargs, + ) -> DataFrame: + """ + Returns a Pandas DataFrame for the results produced by a BigQuery + query. The DbApiHook method must be overridden because Pandas + doesn't support PEP 249 connections, except for SQLite. See: + + https://github.com/pydata/pandas/blob/master/pandas/io/sql.py#L447 + https://github.com/pydata/pandas/issues/6900 + + :param sql: The BigQuery SQL to execute. + :type sql: str + :param parameters: The parameters to render the SQL query with (not + used, leave to override superclass method) + :type parameters: mapping or iterable + :param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL + defaults to use `self.use_legacy_sql` if not specified + :type dialect: str in {'legacy', 'standard'} + :param kwargs: (optional) passed into pandas_gbq.read_gbq method + :type kwargs: dict + """ + if dialect is None: + dialect = "legacy" if self.use_legacy_sql else "standard" + + credentials, project_id = self._get_credentials_and_project_id() + + return read_gbq( + sql, + project_id=project_id, + dialect=dialect, + verbose=False, + credentials=credentials, + **kwargs, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def table_exists(self, dataset_id: str, table_id: str, project_id: str) -> bool: + """ + Checks for the existence of a table in Google BigQuery. 
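+
+        A minimal usage sketch (the connection id, project and table names are
+        illustrative assumptions, not values required by this hook):
+
+        **Example**: ::
+
+            hook = BigQueryHook(gcp_conn_id="google_cloud_default")
+            if hook.table_exists(
+                dataset_id="my_dataset", table_id="my_table", project_id="my-project"
+            ):
+                print("table is present")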
+ + :param project_id: The Google cloud project in which to look for the + table. The connection supplied to the hook must provide access to + the specified project. + :type project_id: str + :param dataset_id: The name of the dataset in which to look for the + table. + :type dataset_id: str + :param table_id: The name of the table to check the existence of. + :type table_id: str + """ + table_reference = TableReference( + DatasetReference(project_id, dataset_id), table_id + ) + try: + self.get_client(project_id=project_id).get_table(table_reference) + return True + except NotFound: + return False + + @GoogleBaseHook.fallback_to_default_project_id + def table_partition_exists( + self, dataset_id: str, table_id: str, partition_id: str, project_id: str + ) -> bool: + """ + Checks for the existence of a partition in a table in Google BigQuery. + + :param project_id: The Google cloud project in which to look for the + table. The connection supplied to the hook must provide access to + the specified project. + :type project_id: str + :param dataset_id: The name of the dataset in which to look for the + table. + :type dataset_id: str + :param table_id: The name of the table to check the existence of. + :type table_id: str + :param partition_id: The name of the partition to check the existence of. + :type partition_id: str + """ + table_reference = TableReference( + DatasetReference(project_id, dataset_id), table_id + ) + try: + return partition_id in self.get_client( + project_id=project_id + ).list_partitions(table_reference) + except NotFound: + return False + + @GoogleBaseHook.fallback_to_default_project_id + def create_empty_table( # pylint: disable=too-many-arguments + self, + project_id: Optional[str] = None, + dataset_id: Optional[str] = None, + table_id: Optional[str] = None, + table_re# Optional[Dict[str, Any]] = None, + schema_fields: Optional[List] = None, + time_partitioning: Optional[Dict] = None, + cluster_fields: Optional[List[str]] = None, + labels: Optional[Dict] = None, + view: Optional[Dict] = None, + materialized_view: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None, + retry: Optional[Retry] = DEFAULT_RETRY, + num_retries: Optional[int] = None, + location: Optional[str] = None, + exists_ok: bool = True, + ) -> Table: + """ + Creates a new, empty table in the dataset. + To create a view, which is defined by a SQL query, parse a dictionary to 'view' kwarg + + :param project_id: The project to create the table into. + :type project_id: str + :param dataset_id: The dataset to create the table into. + :type dataset_id: str + :param table_id: The Name of the table to be created. + :type table_id: str + :param table_re# Table resource as described in documentation: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#Table + If provided all other parameters are ignored. + :type table_re# Dict[str, Any] + :param schema_fields: If set, the schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema + :type schema_fields: list + :param labels: a dictionary containing labels for the table, passed to BigQuery + :type labels: dict + :param retry: Optional. How to retry the RPC. + :type retry: google.api_core.retry.Retry + + **Example**: :: + + schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] + + :param time_partitioning: configure optional time partitioning fields i.e. 
+ partition by field, type and expiration as per API specifications. + + .. seealso:: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#timePartitioning + :type time_partitioning: dict + :param cluster_fields: [Optional] The fields used for clustering. + BigQuery supports clustering for both partitioned and + non-partitioned tables. + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#clustering.fields + :type cluster_fields: list + :param view: [Optional] A dictionary containing definition for the view. + If set, it will create a view instead of a table: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#ViewDefinition + :type view: dict + + **Example**: :: + + view = { + "query": "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 1000", + "useLegacySql": False + } + + :param materialized_view: [Optional] The materialized view definition. + :type materialized_view: dict + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + :param num_retries: Maximum number of retries in case of connection problems. + :type num_retries: int + :param exists_ok: If ``True``, ignore "already exists" errors when creating the table. + :type exists_ok: bool + :return: Created table + """ + if num_retries: + warnings.warn("Parameter `num_retries` is deprecated", DeprecationWarning) + + _table_re# Dict[str, Any] = {} + + if self.location: + _table_resource["location"] = self.location + + if schema_fields: + _table_resource["schema"] = {"fields": schema_fields} + + if time_partitioning: + _table_resource["timePartitioning"] = time_partitioning + + if cluster_fields: + _table_resource["clustering"] = {"fields": cluster_fields} + + if labels: + _table_resource["labels"] = labels + + if view: + _table_resource["view"] = view + + if materialized_view: + _table_resource["materializedView"] = materialized_view + + if encryption_configuration: + _table_resource["encryptionConfiguration"] = encryption_configuration + + table_resource = table_resource or _table_resource + table_resource = self._resolve_table_reference( + table_resource=table_resource, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, + ) + table = Table.from_api_repr(table_resource) + return self.get_client(project_id=project_id, location=location).create_table( + table=table, exists_ok=exists_ok, retry=retry + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_empty_dataset( + self, + dataset_id: Optional[str] = None, + project_id: Optional[str] = None, + location: Optional[str] = None, + dataset_reference: Optional[Dict[str, Any]] = None, + exists_ok: bool = True, + ) -> None: + """ + Create a new empty dataset: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/insert + + :param project_id: The name of the project where we want to create + an empty a dataset. Don't need to provide, if projectId in dataset_reference. + :type project_id: str + :param dataset_id: The id of dataset. Don't need to provide, if datasetId in dataset_reference. + :type dataset_id: str + :param location: (Optional) The geographic location where the dataset should reside. + There is no default value but the dataset will be created in US if nothing is provided. 
+ :type location: str + :param dataset_reference: Dataset reference that could be provided with request body. More info: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + :type dataset_reference: dict + :param exists_ok: If ``True``, ignore "already exists" errors when creating the dataset. + :type exists_ok: bool + """ + dataset_reference = dataset_reference or {"datasetReference": {}} + + for param, value in zip(["datasetId", "projectId"], [dataset_id, project_id]): + specified_param = dataset_reference["datasetReference"].get(param) + if specified_param: + if value: + self.log.info( + "`%s` was provided in both `dataset_reference` and as `%s`. " + "Using value from `dataset_reference`", + param, + convert_camel_to_snake(param), + ) + continue # use specified value + if not value: + raise ValueError( + f"Please specify `{param}` either in `dataset_reference` " + f"or by providing `{convert_camel_to_snake(param)}`", + ) + # dataset_reference has no param but we can fallback to default value + self.log.info( + "%s was not specified in `dataset_reference`. Will use default value %s.", + param, + value, + ) + dataset_reference["datasetReference"][param] = value + + location = location or self.location + if location: + dataset_reference["location"] = dataset_reference.get("location", location) + + dataset: Dataset = Dataset.from_api_repr(dataset_reference) + self.log.info( + "Creating dataset: %s in project: %s ", dataset.dataset_id, dataset.project + ) + self.get_client(location=location).create_dataset( + dataset=dataset, exists_ok=exists_ok + ) + self.log.info("Dataset created successfully.") + + @GoogleBaseHook.fallback_to_default_project_id + def get_dataset_tables( + self, + dataset_id: str, + project_id: Optional[str] = None, + max_results: Optional[int] = None, + retry: Retry = DEFAULT_RETRY, + ) -> List[Dict[str, Any]]: + """ + Get the list of tables for a given dataset. + + For more information, see: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables/list + + :param dataset_id: the dataset ID of the requested dataset. + :type dataset_id: str + :param project_id: (Optional) the project of the requested dataset. If None, + self.project_id will be used. + :type project_id: str + :param max_results: (Optional) the maximum number of tables to return. + :type max_results: int + :param retry: How to retry the RPC. + :type retry: google.api_core.retry.Retry + :return: List of tables associated with the dataset. + """ + self.log.info( + "Start getting tables list from dataset: %s.%s", project_id, dataset_id + ) + tables = self.get_client().list_tables( + dataset=DatasetReference(project=project_id, dataset_id=dataset_id), + max_results=max_results, + retry=retry, + ) + # Convert to a list (consumes all values) + return [t.reference.to_api_repr() for t in tables] + + @GoogleBaseHook.fallback_to_default_project_id + def delete_dataset( + self, + dataset_id: str, + project_id: Optional[str] = None, + delete_contents: bool = False, + retry: Retry = DEFAULT_RETRY, + ) -> None: + """ + Delete a dataset of Big query in your project. + + :param project_id: The name of the project where we have the dataset. + :type project_id: str + :param dataset_id: The dataset to be delete. + :type dataset_id: str + :param delete_contents: If True, delete all the tables in the dataset. + If False and the dataset contains tables, the request will fail. + :type delete_contents: bool + :param retry: How to retry the RPC. 
+ :type retry: google.api_core.retry.Retry + """ + self.log.info("Deleting from project: %s Dataset:%s", project_id, dataset_id) + self.get_client(project_id=project_id).delete_dataset( + dataset=DatasetReference(project=project_id, dataset_id=dataset_id), + delete_contents=delete_contents, + retry=retry, + not_found_ok=True, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_external_table( # pylint: disable=too-many-locals,too-many-arguments + self, + external_project_dataset_table: str, + schema_fields: List, + source_uris: List, + source_format: str = "CSV", + autodetect: bool = False, + compression: str = "NONE", + ignore_unknown_values: bool = False, + max_bad_records: int = 0, + skip_leading_rows: int = 0, + field_delimiter: str = ",", + quote_character: Optional[str] = None, + allow_quoted_newlines: bool = False, + allow_jagged_rows: bool = False, + encoding: str = "UTF-8", + src_fmt_configs: Optional[Dict] = None, + labels: Optional[Dict] = None, + description: Optional[str] = None, + encryption_configuration: Optional[Dict] = None, + location: Optional[str] = None, + project_id: Optional[str] = None, + ) -> None: + """ + Creates a new external table in the dataset with the data from Google + Cloud Storage. See here: + + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#resource + + for more details about these parameters. + + :param external_project_dataset_table: + The dotted ``(.|:).($)`` BigQuery + table name to create external table. + If ```` is not included, project will be the + project defined in the connection json. + :type external_project_dataset_table: str + :param schema_fields: The schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#resource + :type schema_fields: list + :param source_uris: The source Google Cloud + Storage URI (e.g. gs://some-bucket/some-file.txt). A single wild + per-object name can be used. + :type source_uris: list + :param source_format: File format to export. + :type source_format: str + :param autodetect: Try to detect schema and format options automatically. + Any option specified explicitly will be honored. + :type autodetect: bool + :param compression: [Optional] The compression type of the data source. + Possible values include GZIP and NONE. + The default value is NONE. + This setting is ignored for Google Cloud Bigtable, + Google Cloud Datastore backups and Avro formats. + :type compression: str + :param ignore_unknown_values: [Optional] Indicates if BigQuery should allow + extra values that are not represented in the table schema. + If true, the extra values are ignored. If false, records with extra columns + are treated as bad records, and if there are too many bad records, an + invalid error is returned in the job result. + :type ignore_unknown_values: bool + :param max_bad_records: The maximum number of bad records that BigQuery can + ignore when running the job. + :type max_bad_records: int + :param skip_leading_rows: Number of rows to skip when loading from a CSV. + :type skip_leading_rows: int + :param field_delimiter: The delimiter to use when loading from a CSV. + :type field_delimiter: str + :param quote_character: The value that is used to quote data sections in a CSV + file. + :type quote_character: str + :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not + (false). + :type allow_quoted_newlines: bool + :param allow_jagged_rows: Accept rows that are missing trailing optional columns. 
+ The missing values are treated as nulls. If false, records with missing + trailing columns are treated as bad records, and if there are too many bad + records, an invalid error is returned in the job result. Only applicable when + source_format is CSV. + :type allow_jagged_rows: bool + :param encoding: The character encoding of the data. See: + + .. seealso:: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#externalDataConfiguration.csvOptions.encoding + :type encoding: str + :param src_fmt_configs: configure optional fields specific to the source format + :type src_fmt_configs: dict + :param labels: A dictionary containing labels for the BiqQuery table. + :type labels: dict + :param description: A string containing the description for the BigQuery table. + :type descriptin: str + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + """ + warnings.warn( + "This method is deprecated. Please use `BigQueryHook.create_empty_table` method with" + "pass passing the `table_resource` object. This gives more flexibility than this method.", + DeprecationWarning, + ) + location = location or self.location + src_fmt_configs = src_fmt_configs or {} + source_format = source_format.upper() + compression = compression.upper() + + external_config_api_repr = { + "autodetect": autodetect, + "sourceFormat": source_format, + "sourceUris": source_uris, + "compression": compression, + "ignoreUnknownValues": ignore_unknown_values, + } + + # if following fields are not specified in src_fmt_configs, + # honor the top-level params for backward-compatibility + backward_compatibility_configs = { + "skipLeadingRows": skip_leading_rows, + "fieldDelimiter": field_delimiter, + "quote": quote_character, + "allowQuotedNewlines": allow_quoted_newlines, + "allowJaggedRows": allow_jagged_rows, + "encoding": encoding, + } + src_fmt_to_param_mapping = { + "CSV": "csvOptions", + "GOOGLE_SHEETS": "googleSheetsOptions", + } + src_fmt_to_configs_mapping = { + "csvOptions": [ + "allowJaggedRows", + "allowQuotedNewlines", + "fieldDelimiter", + "skipLeadingRows", + "quote", + "encoding", + ], + "googleSheetsOptions": ["skipLeadingRows"], + } + if source_format in src_fmt_to_param_mapping.keys(): + valid_configs = src_fmt_to_configs_mapping[ + src_fmt_to_param_mapping[source_format] + ] + src_fmt_configs = _validate_src_fmt_configs( + source_format, + src_fmt_configs, + valid_configs, + backward_compatibility_configs, + ) + external_config_api_repr[ + src_fmt_to_param_mapping[source_format] + ] = src_fmt_configs + + # build external config + external_config = ExternalConfig.from_api_repr(external_config_api_repr) + if schema_fields: + external_config.schema = [ + SchemaField.from_api_repr(f) for f in schema_fields + ] + if max_bad_records: + external_config.max_bad_records = max_bad_records + + # build table definition + table = Table( + table_ref=TableReference.from_string( + external_project_dataset_table, project_id + ) + ) + table.external_data_configuration = external_config + if labels: + table.labels = labels + + if description: + table.description = description + + if encryption_configuration: + table.encryption_configuration = EncryptionConfiguration.from_api_repr( + encryption_configuration + ) + + self.log.info("Creating external table: %s", external_project_dataset_table) + 
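+        # The external table is created by delegating to `create_empty_table` with the
+        # assembled table resource; its API representation roughly looks like the
+        # following (illustrative shape only, actual values depend on the arguments above):
+        #
+        #   {
+        #       "tableReference": {"projectId": ..., "datasetId": ..., "tableId": ...},
+        #       "externalDataConfiguration": {"sourceFormat": "CSV", "sourceUris": [...]},
+        #   }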
self.create_empty_table( + table_resource=table.to_api_repr(), + project_id=project_id, + location=location, + exists_ok=True, + ) + self.log.info( + "External table created successfully: %s", external_project_dataset_table + ) + + @GoogleBaseHook.fallback_to_default_project_id + def update_table( + self, + table_re# Dict[str, Any], + fields: Optional[List[str]] = None, + dataset_id: Optional[str] = None, + table_id: Optional[str] = None, + project_id: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Change some fields of a table. + + Use ``fields`` to specify which fields to update. At least one field + must be provided. If a field is listed in ``fields`` and is ``None`` + in ``table``, the field value will be deleted. + + If ``table.etag`` is not ``None``, the update will only succeed if + the table on the server has the same ETag. Thus reading a table with + ``get_table``, changing its fields, and then passing it to + ``update_table`` will ensure that the changes will only be saved if + no modifications to the table occurred since the read. + + :param project_id: The project to create the table into. + :type project_id: str + :param dataset_id: The dataset to create the table into. + :type dataset_id: str + :param table_id: The Name of the table to be created. + :type table_id: str + :param table_re# Table resource as described in documentation: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#Table + The table has to contain ``tableReference`` or ``project_id``, ``dataset_id`` and ``table_id`` + have to be provided. + :type table_re# Dict[str, Any] + :param fields: The fields of ``table`` to change, spelled as the Table + properties (e.g. "friendly_name"). + :type fields: List[str] + """ + fields = fields or list(table_resource.keys()) + table_resource = self._resolve_table_reference( + table_resource=table_resource, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, + ) + + table = Table.from_api_repr(table_resource) + self.log.info("Updating table: %s", table_resource["tableReference"]) + table_object = self.get_client(project_id=project_id).update_table( + table=table, fields=fields + ) + self.log.info( + "Table %s.%s.%s updated successfully", project_id, dataset_id, table_id + ) + return table_object.to_api_repr() + + @GoogleBaseHook.fallback_to_default_project_id + def patch_table( # pylint: disable=too-many-arguments + self, + dataset_id: str, + table_id: str, + project_id: Optional[str] = None, + description: Optional[str] = None, + expiration_time: Optional[int] = None, + external_data_configuration: Optional[Dict] = None, + friendly_name: Optional[str] = None, + labels: Optional[Dict] = None, + schema: Optional[List] = None, + time_partitioning: Optional[Dict] = None, + view: Optional[Dict] = None, + require_partition_filter: Optional[bool] = None, + encryption_configuration: Optional[Dict] = None, + ) -> None: + """ + Patch information in an existing table. + It only updates fields that are provided in the request object. + + Reference: https://cloud.google.com/bigquery/docs/reference/rest/v2/tables/patch + + :param dataset_id: The dataset containing the table to be patched. + :type dataset_id: str + :param table_id: The Name of the table to be patched. + :type table_id: str + :param project_id: The project containing the table to be patched. + :type project_id: str + :param description: [Optional] A user-friendly description of this table. 
+ :type description: str + :param expiration_time: [Optional] The time when this table expires, + in milliseconds since the epoch. + :type expiration_time: int + :param external_data_configuration: [Optional] A dictionary containing + properties of a table stored outside of BigQuery. + :type external_data_configuration: dict + :param friendly_name: [Optional] A descriptive name for this table. + :type friendly_name: str + :param labels: [Optional] A dictionary containing labels associated with this table. + :type labels: dict + :param schema: [Optional] If set, the schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema + The supported schema modifications and unsupported schema modification are listed here: + https://cloud.google.com/bigquery/docs/managing-table-schemas + **Example**: :: + + schema=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] + + :type schema: list + :param time_partitioning: [Optional] A dictionary containing time-based partitioning + definition for the table. + :type time_partitioning: dict + :param view: [Optional] A dictionary containing definition for the view. + If set, it will patch a view instead of a table: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#ViewDefinition + **Example**: :: + + view = { + "query": "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 500", + "useLegacySql": False + } + + :type view: dict + :param require_partition_filter: [Optional] If true, queries over the this table require a + partition filter. If false, queries over the table + :type require_partition_filter: bool + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + + """ + warnings.warn( + "This method is deprecated, please use ``BigQueryHook.update_table`` method.", + DeprecationWarning, + ) + table_re# Dict[str, Any] = {} + + if description is not None: + table_resource["description"] = description + if expiration_time is not None: + table_resource["expirationTime"] = expiration_time + if external_data_configuration: + table_resource["externalDataConfiguration"] = external_data_configuration + if friendly_name is not None: + table_resource["friendlyName"] = friendly_name + if labels: + table_resource["labels"] = labels + if schema: + table_resource["schema"] = {"fields": schema} + if time_partitioning: + table_resource["timePartitioning"] = time_partitioning + if view: + table_resource["view"] = view + if require_partition_filter is not None: + table_resource["requirePartitionFilter"] = require_partition_filter + if encryption_configuration: + table_resource["encryptionConfiguration"] = encryption_configuration + + self.update_table( + table_resource=table_resource, + fields=list(table_resource.keys()), + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def insert_all( + self, + project_id: str, + dataset_id: str, + table_id: str, + rows: List, + ignore_unknown_values: bool = False, + skip_invalid_rows: bool = False, + fail_on_error: bool = False, + ) -> None: + """ + Method to stream data into BigQuery one record at a time without needing + to run a load job + + .. 
seealso:: + For more information, see: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/insertAll + + :param project_id: The name of the project where we have the table + :type project_id: str + :param dataset_id: The name of the dataset where we have the table + :type dataset_id: str + :param table_id: The name of the table + :type table_id: str + :param rows: the rows to insert + :type rows: list + + **Example or rows**: + rows=[{"json": {"a_key": "a_value_0"}}, {"json": {"a_key": "a_value_1"}}] + + :param ignore_unknown_values: [Optional] Accept rows that contain values + that do not match the schema. The unknown values are ignored. + The default value is false, which treats unknown values as errors. + :type ignore_unknown_values: bool + :param skip_invalid_rows: [Optional] Insert all valid rows of a request, + even if invalid rows exist. The default value is false, which causes + the entire request to fail if any invalid rows exist. + :type skip_invalid_rows: bool + :param fail_on_error: [Optional] Force the task to fail if any errors occur. + The default value is false, which indicates the task should not fail + even if any insertion errors occur. + :type fail_on_error: bool + """ + self.log.info( + "Inserting %s row(s) into table %s:%s.%s", + len(rows), + project_id, + dataset_id, + table_id, + ) + + table_ref = TableReference( + dataset_ref=DatasetReference(project_id, dataset_id), table_id=table_id + ) + bq_client = self.get_client(project_id=project_id) + table = bq_client.get_table(table_ref) + errors = bq_client.insert_rows( + table=table, + rows=rows, + ignore_unknown_values=ignore_unknown_values, + skip_invalid_rows=skip_invalid_rows, + ) + if errors: + error_msg = f"{len(errors)} insert error(s) occurred. Details: {errors}" + self.log.error(error_msg) + if fail_on_error: + raise AirflowException(f"BigQuery job failed. Error was: {error_msg}") + else: + self.log.info( + "All row(s) inserted successfully: %s:%s.%s", + project_id, + dataset_id, + table_id, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def update_dataset( + self, + fields: Sequence[str], + dataset_re# Dict[str, Any], + dataset_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Retry = DEFAULT_RETRY, + ) -> Dataset: + """ + Change some fields of a dataset. + + Use ``fields`` to specify which fields to update. At least one field + must be provided. If a field is listed in ``fields`` and is ``None`` in + ``dataset``, it will be deleted. + + If ``dataset.etag`` is not ``None``, the update will only + succeed if the dataset on the server has the same ETag. Thus + reading a dataset with ``get_dataset``, changing its fields, + and then passing it to ``update_dataset`` will ensure that the changes + will only be saved if no modifications to the dataset occurred + since the read. + + :param dataset_re# Dataset resource that will be provided + in request body. + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + :type dataset_re# dict + :param dataset_id: The id of the dataset. + :type dataset_id: str + :param fields: The properties of ``dataset`` to change (e.g. "friendly_name"). + :type fields: Sequence[str] + :param project_id: The Google Cloud Project ID + :type project_id: str + :param retry: How to retry the RPC. 
+ :type retry: google.api_core.retry.Retry + """ + dataset_resource["datasetReference"] = dataset_resource.get( + "datasetReference", {} + ) + + for key, value in zip(["datasetId", "projectId"], [dataset_id, project_id]): + spec_value = dataset_resource["datasetReference"].get(key) + if value and not spec_value: + dataset_resource["datasetReference"][key] = value + + self.log.info("Start updating dataset") + dataset = self.get_client(project_id=project_id).update_dataset( + dataset=Dataset.from_api_repr(dataset_resource), + fields=fields, + retry=retry, + ) + self.log.info("Dataset successfully updated: %s", dataset) + return dataset + + def patch_dataset( + self, dataset_id: str, dataset_re# Dict, project_id: Optional[str] = None + ) -> Dict: + """ + Patches information in an existing dataset. + It only replaces fields that are provided in the submitted dataset resource. + More info: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/patch + + :param dataset_id: The BigQuery Dataset ID + :type dataset_id: str + :param dataset_re# Dataset resource that will be provided + in request body. + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + :type dataset_re# dict + :param project_id: The Google Cloud Project ID + :type project_id: str + :rtype: dataset + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + """ + warnings.warn( + "This method is deprecated. Please use ``update_dataset``.", + DeprecationWarning, + ) + project_id = project_id or self.project_id + if not dataset_id or not isinstance(dataset_id, str): + raise ValueError( + "dataset_id argument must be provided and has " + "a type 'str'. You provided: {}".format(dataset_id) + ) + + service = self.get_service() + dataset_project_id = project_id or self.project_id + + self.log.info("Start patching dataset: %s:%s", dataset_project_id, dataset_id) + dataset = ( + service.datasets() # pylint: disable=no-member + .patch( + datasetId=dataset_id, + projectId=dataset_project_id, + body=dataset_resource, + ) + .execute(num_retries=self.num_retries) + ) + self.log.info("Dataset successfully patched: %s", dataset) + + return dataset + + def get_dataset_tables_list( + self, + dataset_id: str, + project_id: Optional[str] = None, + table_prefix: Optional[str] = None, + max_results: Optional[int] = None, + ) -> List[Dict[str, Any]]: + """ + Method returns tables list of a BigQuery tables. If table prefix is specified, + only tables beginning by it are returned. + + For more information, see: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables/list + + :param dataset_id: The BigQuery Dataset ID + :type dataset_id: str + :param project_id: The Google Cloud Project ID + :type project_id: str + :param table_prefix: Tables must begin by this prefix to be returned (case sensitive) + :type table_prefix: str + :param max_results: The maximum number of results to return in a single response page. + Leverage the page tokens to iterate through the entire collection. + :type max_results: int + :return: List of tables associated with the dataset + """ + warnings.warn( + "This method is deprecated. 
Please use ``get_dataset_tables``.", + DeprecationWarning, + ) + project_id = project_id or self.project_id + tables = self.get_client().list_tables( + dataset=DatasetReference(project=project_id, dataset_id=dataset_id), + max_results=max_results, + ) + + if table_prefix: + result = [ + t.reference.to_api_repr() + for t in tables + if t.table_id.startswith(table_prefix) + ] + else: + result = [t.reference.to_api_repr() for t in tables] + + self.log.info("%s tables found", len(result)) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_datasets_list( + self, + project_id: Optional[str] = None, + include_all: bool = False, + filter_: Optional[str] = None, + max_results: Optional[int] = None, + page_token: Optional[str] = None, + retry: Retry = DEFAULT_RETRY, + ) -> List[DatasetListItem]: + """ + Method returns full list of BigQuery datasets in the current project + + For more information, see: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/list + + :param project_id: Google Cloud Project for which you try to get all datasets + :type project_id: str + :param include_all: True if results include hidden datasets. Defaults to False. + :param filter_: An expression for filtering the results by label. For syntax, see + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/list#filter. + :param filter_: str + :param max_results: Maximum number of datasets to return. + :param max_results: int + :param page_token: Token representing a cursor into the datasets. If not passed, + the API will return the first page of datasets. The token marks the beginning of the + iterator to be returned and the value of the ``page_token`` can be accessed at + ``next_page_token`` of the :class:`~google.api_core.page_iterator.HTTPIterator`. + :param page_token: str + :param retry: How to retry the RPC. + :type retry: google.api_core.retry.Retry + """ + datasets = self.get_client(project_id=project_id).list_datasets( + project=project_id, + include_all=include_all, + filter=filter_, + max_results=max_results, + page_token=page_token, + retry=retry, + ) + datasets_list = list(datasets) + + self.log.info("Datasets List: %s", len(datasets_list)) + return datasets_list + + @GoogleBaseHook.fallback_to_default_project_id + def get_dataset(self, dataset_id: str, project_id: Optional[str] = None) -> Dataset: + """ + Fetch the dataset referenced by dataset_id. + + :param dataset_id: The BigQuery Dataset ID + :type dataset_id: str + :param project_id: The Google Cloud Project ID + :type project_id: str + :return: dataset_resource + + .. seealso:: + For more information, see Dataset Resource content: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + """ + dataset = self.get_client(project_id=project_id).get_dataset( + dataset_ref=DatasetReference(project_id, dataset_id) + ) + self.log.info("Dataset Re# %s", dataset) + return dataset + + @GoogleBaseHook.fallback_to_default_project_id + def run_grant_dataset_view_access( + self, + source_dataset: str, + view_dataset: str, + view_table: str, + source_project: Optional[str] = None, + view_project: Optional[str] = None, + project_id: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Grant authorized view access of a dataset to a view table. + If this view has already been granted access to the dataset, do nothing. + This method is not atomic. Running it may clobber a simultaneous update. 
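+
+        A minimal usage sketch (the dataset, table and project names below are
+        illustrative assumptions):
+
+        **Example**: ::
+
+            hook = BigQueryHook()
+            hook.run_grant_dataset_view_access(
+                source_dataset="source_ds",
+                view_dataset="view_ds",
+                view_table="my_view",
+                project_id="my-project",
+            )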
+ + :param source_dataset: the source dataset + :type source_dataset: str + :param view_dataset: the dataset that the view is in + :type view_dataset: str + :param view_table: the table of the view + :type view_table: str + :param project_id: the project of the source dataset. If None, + self.project_id will be used. + :type project_id: str + :param view_project: the project that the view is in. If None, + self.project_id will be used. + :type view_project: str + :return: the datasets resource of the source dataset. + """ + if source_project: + project_id = source_project + warnings.warn( + "Parameter ``source_project`` is deprecated. Use ``project_id``.", + DeprecationWarning, + ) + view_project = view_project or project_id + view_access = AccessEntry( + role=None, + entity_type="view", + entity_id={ + "projectId": view_project, + "datasetId": view_dataset, + "tableId": view_table, + }, + ) + + dataset = self.get_dataset(project_id=project_id, dataset_id=source_dataset) + + # Check to see if the view we want to add already exists. + if view_access not in dataset.access_entries: + self.log.info( + "Granting table %s:%s.%s authorized view access to %s:%s dataset.", + view_project, + view_dataset, + view_table, + project_id, + source_dataset, + ) + dataset.access_entries += [view_access] + dataset = self.update_dataset( + fields=["access"], + dataset_resource=dataset.to_api_repr(), + project_id=project_id, + ) + else: + self.log.info( + "Table %s:%s.%s already has authorized view access to %s:%s dataset.", + view_project, + view_dataset, + view_table, + project_id, + source_dataset, + ) + return dataset.to_api_repr() + + @GoogleBaseHook.fallback_to_default_project_id + def run_table_upsert( + self, + dataset_id: str, + table_re# Dict[str, Any], + project_id: Optional[str] = None, + ) -> Dict[str, Any]: + """ + If the table already exists, update the existing table if not create new. + Since BigQuery does not natively allow table upserts, this is not an + atomic operation. + + :param dataset_id: the dataset to upsert the table into. + :type dataset_id: str + :param table_re# a table resource. see + https://cloud.google.com/bigquery/docs/reference/v2/tables#resource + :type table_re# dict + :param project_id: the project to upsert the table into. If None, + project will be self.project_id. + :return: + """ + table_id = table_resource["tableReference"]["tableId"] + table_resource = self._resolve_table_reference( + table_resource=table_resource, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, + ) + + tables_list_resp = self.get_dataset_tables( + dataset_id=dataset_id, project_id=project_id + ) + if any(table["tableId"] == table_id for table in tables_list_resp): + self.log.info( + "Table %s:%s.%s exists, updating.", project_id, dataset_id, table_id + ) + table = self.update_table(table_resource=table_resource) + else: + self.log.info( + "Table %s:%s.%s does not exist. creating.", + project_id, + dataset_id, + table_id, + ) + table = self.create_empty_table( + table_resource=table_resource, project_id=project_id + ).to_api_repr() + return table + + def run_table_delete( + self, deletion_dataset_table: str, ignore_if_missing: bool = False + ) -> None: + """ + Delete an existing table from the dataset; + If the table does not exist, return an error unless ignore_if_missing + is set to True. + + :param deletion_dataset_table: A dotted + ``(.|:).
`` that indicates which table + will be deleted. + :type deletion_dataset_table: str + :param ignore_if_missing: if True, then return success even if the + requested table does not exist. + :type ignore_if_missing: bool + :return: + """ + warnings.warn( + "This method is deprecated. Please use `delete_table`.", DeprecationWarning + ) + return self.delete_table( + table_id=deletion_dataset_table, not_found_ok=ignore_if_missing + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_table( + self, + table_id: str, + not_found_ok: bool = True, + project_id: Optional[str] = None, + ) -> None: + """ + Delete an existing table from the dataset. If the table does not exist, return an error + unless not_found_ok is set to True. + + :param table_id: A dotted ``(<project>.|<project>:)<dataset>.<table>
`` + that indicates which table will be deleted. + :type table_id: str + :param not_found_ok: if True, then return success even if the + requested table does not exist. + :type not_found_ok: bool + :param project_id: the project used to perform the request + :type project_id: str + """ + self.get_client(project_id=project_id).delete_table( + table=Table.from_string(table_id), + not_found_ok=not_found_ok, + ) + self.log.info("Deleted table %s", table_id) + + def get_tabledata( + self, + dataset_id: str, + table_id: str, + max_results: Optional[int] = None, + selected_fields: Optional[str] = None, + page_token: Optional[str] = None, + start_index: Optional[int] = None, + ) -> List[Dict]: + """ + Get the data of a given dataset.table and optionally with selected columns. + see https://cloud.google.com/bigquery/docs/reference/v2/tabledata/list + + :param dataset_id: the dataset ID of the requested table. + :param table_id: the table ID of the requested table. + :param max_results: the maximum results to return. + :param selected_fields: List of fields to return (comma-separated). If + unspecified, all fields are returned. + :param page_token: page token, returned from a previous call, + identifying the result set. + :param start_index: zero based index of the starting row to read. + :return: list of rows + """ + warnings.warn( + "This method is deprecated. Please use `list_rows`.", DeprecationWarning + ) + rows = self.list_rows( + dataset_id, table_id, max_results, selected_fields, page_token, start_index + ) + return [dict(r) for r in rows] + + @GoogleBaseHook.fallback_to_default_project_id + def list_rows( + self, + dataset_id: str, + table_id: str, + max_results: Optional[int] = None, + selected_fields: Optional[Union[List[str], str]] = None, + page_token: Optional[str] = None, + start_index: Optional[int] = None, + project_id: Optional[str] = None, + location: Optional[str] = None, + ) -> List[Row]: + """ + List the rows of the table. + See https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/list + + :param dataset_id: the dataset ID of the requested table. + :param table_id: the table ID of the requested table. + :param max_results: the maximum results to return. + :param selected_fields: List of fields to return (comma-separated). If + unspecified, all fields are returned. + :param page_token: page token, returned from a previous call, + identifying the result set. + :param start_index: zero based index of the starting row to read. + :param project_id: Project ID for the project which the client acts on behalf of. + :param location: Default location for job. + :return: list of rows + """ + location = location or self.location + if isinstance(selected_fields, str): + selected_fields = selected_fields.split(",") + + if selected_fields: + selected_fields = [SchemaField(n, "") for n in selected_fields] + else: + selected_fields = None + + table = self._resolve_table_reference( + table_resource={}, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, + ) + + result = self.get_client(project_id=project_id, location=location).list_rows( + table=Table.from_api_repr(table), + selected_fields=selected_fields, + max_results=max_results, + page_token=page_token, + start_index=start_index, + ) + return list(result) + + @GoogleBaseHook.fallback_to_default_project_id + def get_schema( + self, dataset_id: str, table_id: str, project_id: Optional[str] = None + ) -> dict: + """ + Get the schema for a given dataset and table. 
+ see https://cloud.google.com/bigquery/docs/reference/v2/tables#resource + + :param dataset_id: the dataset ID of the requested table + :param table_id: the table ID of the requested table + :param project_id: the optional project ID of the requested table. + If not provided, the connector's configured project will be used. + :return: a table schema + """ + table_ref = TableReference( + dataset_ref=DatasetReference(project_id, dataset_id), table_id=table_id + ) + table = self.get_client(project_id=project_id).get_table(table_ref) + return {"fields": [s.to_api_repr() for s in table.schema]} + + @GoogleBaseHook.fallback_to_default_project_id + def poll_job_complete( + self, + job_id: str, + project_id: Optional[str] = None, + location: Optional[str] = None, + retry: Retry = DEFAULT_RETRY, + ) -> bool: + """ + Check if jobs completed. + + :param job_id: id of the job. + :type job_id: str + :param project_id: Google Cloud Project where the job is running + :type project_id: str + :param location: location the job is running + :type location: str + :param retry: How to retry the RPC. + :type retry: google.api_core.retry.Retry + :rtype: bool + """ + location = location or self.location + job = self.get_client(project_id=project_id, location=location).get_job( + job_id=job_id + ) + return job.done(retry=retry) + + def cancel_query(self) -> None: + """Cancel all started queries that have not yet completed""" + warnings.warn( + "This method is deprecated. Please use `BigQueryHook.cancel_job`.", + DeprecationWarning, + ) + if self.running_job_id: + self.cancel_job(job_id=self.running_job_id) + else: + self.log.info("No running BigQuery jobs to cancel.") + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_job( + self, + job_id: str, + project_id: Optional[str] = None, + location: Optional[str] = None, + ) -> None: + """ + Cancels a job an wait for cancellation to complete + + :param job_id: id of the job. + :type job_id: str + :param project_id: Google Cloud Project where the job is running + :type project_id: str + :param location: location the job is running + :type location: str + """ + location = location or self.location + + if self.poll_job_complete(job_id=job_id): + self.log.info("No running BigQuery jobs to cancel.") + return + + self.log.info("Attempting to cancel job : %s, %s", project_id, job_id) + self.get_client(location=location, project_id=project_id).cancel_job( + job_id=job_id + ) + + # Wait for all the calls to cancel to finish + max_polling_attempts = 12 + polling_attempts = 0 + + job_complete = False + while polling_attempts < max_polling_attempts and not job_complete: + polling_attempts += 1 + job_complete = self.poll_job_complete(job_id) + if job_complete: + self.log.info("Job successfully canceled: %s, %s", project_id, job_id) + elif polling_attempts == max_polling_attempts: + self.log.info( + "Stopping polling due to timeout. Job with id %s " + "has not completed cancel and may or may not finish.", + job_id, + ) + else: + self.log.info("Waiting for canceled job with id %s to finish.", job_id) + time.sleep(5) + + @GoogleBaseHook.fallback_to_default_project_id + def get_job( + self, + job_id: Optional[str] = None, + project_id: Optional[str] = None, + location: Optional[str] = None, + ) -> Union[CopyJob, QueryJob, LoadJob, ExtractJob]: + """ + Retrieves a BigQuery job. For more information see: + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + :param job_id: The ID of the job. 
The ID must contain only letters (a-z, A-Z), + numbers (0-9), underscores (_), or dashes (-). The maximum length is 1,024 + characters. If not provided then uuid will be generated. + :type job_id: str + :param project_id: Google Cloud Project where the job is running + :type project_id: str + :param location: location the job is running + :type location: str + """ + client = self.get_client(project_id=project_id, location=location) + job = client.get_job(job_id=job_id, project=project_id, location=location) + return job + + @staticmethod + def _custom_job_id(configuration: Dict[str, Any]) -> str: + hash_base = json.dumps(configuration, sort_keys=True) + uniqueness_suffix = hashlib.md5(hash_base.encode()).hexdigest() + microseconds_from_epoch = int( + (datetime.now() - datetime.fromtimestamp(0)) / timedelta(microseconds=1) + ) + return f"airflow_{microseconds_from_epoch}_{uniqueness_suffix}" + + @GoogleBaseHook.fallback_to_default_project_id + def insert_job( + self, + configuration: Dict, + job_id: Optional[str] = None, + project_id: Optional[str] = None, + location: Optional[str] = None, + ) -> BigQueryJob: + """ + Executes a BigQuery job. Waits for the job to complete and returns job id. + See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + :param configuration: The configuration parameter maps directly to + BigQuery's configuration field in the job object. See + https://cloud.google.com/bigquery/docs/reference/v2/jobs for + details. + :type configuration: Dict[str, Any] + :param job_id: The ID of the job. The ID must contain only letters (a-z, A-Z), + numbers (0-9), underscores (_), or dashes (-). The maximum length is 1,024 + characters. If not provided then uuid will be generated. + :type job_id: str + :param project_id: Google Cloud Project where the job is running + :type project_id: str + :param location: location the job is running + :type location: str + """ + location = location or self.location + job_id = job_id or self._custom_job_id(configuration) + + client = self.get_client(project_id=project_id, location=location) + job_data = { + "configuration": configuration, + "jobReference": { + "jobId": job_id, + "projectId": project_id, + "location": location, + }, + } + # pylint: disable=protected-access + supported_jobs = { + LoadJob._JOB_TYPE: LoadJob, + CopyJob._JOB_TYPE: CopyJob, + ExtractJob._JOB_TYPE: ExtractJob, + QueryJob._JOB_TYPE: QueryJob, + } + # pylint: enable=protected-access + job = None + for job_type, job_object in supported_jobs.items(): + if job_type in configuration: + job = job_object + break + + if not job: + raise AirflowException( + f"Unknown job type. Supported types: {supported_jobs.keys()}" + ) + job = job.from_api_repr(job_data, client) + self.log.info("Inserting job %s", job.job_id) + # Start the job and wait for it to complete and get the result. + job.result() + return job + + def run_with_configuration(self, configuration: dict) -> str: + """ + Executes a BigQuery SQL query. See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + For more details about the configuration parameter. + + :param configuration: The configuration parameter maps directly to + BigQuery's configuration field in the job object. See + https://cloud.google.com/bigquery/docs/reference/v2/jobs for + details. + """ + warnings.warn( + "This method is deprecated. 
Please use `BigQueryHook.insert_job`", + DeprecationWarning, + ) + job = self.insert_job(configuration=configuration, project_id=self.project_id) + self.running_job_id = job.job_id + return job.job_id + + def run_load( # pylint: disable=too-many-locals,too-many-arguments,invalid-name + self, + destination_project_dataset_table: str, + source_uris: List, + schema_fields: Optional[List] = None, + source_format: str = "CSV", + create_disposition: str = "CREATE_IF_NEEDED", + skip_leading_rows: int = 0, + write_disposition: str = "WRITE_EMPTY", + field_delimiter: str = ",", + max_bad_records: int = 0, + quote_character: Optional[str] = None, + ignore_unknown_values: bool = False, + allow_quoted_newlines: bool = False, + allow_jagged_rows: bool = False, + encoding: str = "UTF-8", + schema_update_options: Optional[Iterable] = None, + src_fmt_configs: Optional[Dict] = None, + time_partitioning: Optional[Dict] = None, + cluster_fields: Optional[List] = None, + autodetect: bool = False, + encryption_configuration: Optional[Dict] = None, + labels: Optional[Dict] = None, + description: Optional[str] = None, + ) -> str: + """ + Executes a BigQuery load command to load data from Google Cloud Storage + to BigQuery. See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + For more details about these parameters. + + :param destination_project_dataset_table: + The dotted ``(.|:).
($)`` BigQuery + table to load data into. If ```` is not included, project will be the + project defined in the connection json. If a partition is specified the + operator will automatically append the data, create a new partition or create + a new DAY partitioned table. + :type destination_project_dataset_table: str + :param schema_fields: The schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.load + Required if autodetect=False; optional if autodetect=True. + :type schema_fields: list + :param autodetect: Attempt to autodetect the schema for CSV and JSON + source files. + :type autodetect: bool + :param source_uris: The source Google Cloud + Storage URI (e.g. gs://some-bucket/some-file.txt). A single wild + per-object name can be used. + :type source_uris: list + :param source_format: File format to export. + :type source_format: str + :param create_disposition: The create disposition if the table doesn't exist. + :type create_disposition: str + :param skip_leading_rows: Number of rows to skip when loading from a CSV. + :type skip_leading_rows: int + :param write_disposition: The write disposition if the table already exists. + :type write_disposition: str + :param field_delimiter: The delimiter to use when loading from a CSV. + :type field_delimiter: str + :param max_bad_records: The maximum number of bad records that BigQuery can + ignore when running the job. + :type max_bad_records: int + :param quote_character: The value that is used to quote data sections in a CSV + file. + :type quote_character: str + :param ignore_unknown_values: [Optional] Indicates if BigQuery should allow + extra values that are not represented in the table schema. + If true, the extra values are ignored. If false, records with extra columns + are treated as bad records, and if there are too many bad records, an + invalid error is returned in the job result. + :type ignore_unknown_values: bool + :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not + (false). + :type allow_quoted_newlines: bool + :param allow_jagged_rows: Accept rows that are missing trailing optional columns. + The missing values are treated as nulls. If false, records with missing + trailing columns are treated as bad records, and if there are too many bad + records, an invalid error is returned in the job result. Only applicable when + source_format is CSV. + :type allow_jagged_rows: bool + :param encoding: The character encoding of the data. + + .. seealso:: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#externalDataConfiguration.csvOptions.encoding + :type encoding: str + :param schema_update_options: Allows the schema of the destination + table to be updated as a side effect of the load job. + :type schema_update_options: Union[list, tuple, set] + :param src_fmt_configs: configure optional fields specific to the source format + :type src_fmt_configs: dict + :param time_partitioning: configure optional time partitioning fields i.e. + partition by field, type and expiration as per API specifications. + :type time_partitioning: dict + :param cluster_fields: Request that the result of this load be stored sorted + by one or more columns. BigQuery supports clustering for both partitioned and + non-partitioned tables. The order of columns given determines the sort order. + :type cluster_fields: list[str] + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
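Since ``run_load`` is deprecated in favour of ``insert_job``, the same load can be expressed as a raw ``load`` configuration. A minimal sketch with placeholder bucket, project, dataset and table names::

    from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

    hook = BigQueryHook(gcp_conn_id="google_cloud_default")
    job = hook.insert_job(
        configuration={
            "load": {
                "sourceUris": ["gs://my-bucket/data.csv"],
                "destinationTable": {
                    "projectId": "my-project",
                    "datasetId": "my_dataset",
                    "tableId": "my_table",
                },
                "sourceFormat": "CSV",
                "skipLeadingRows": 1,
                "writeDisposition": "WRITE_TRUNCATE",
                "autodetect": True,
            }
        }
    )
    print(job.job_id)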
+ **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + :param labels: A dictionary containing labels for the BiqQuery table. + :type labels: dict + :param description: A string containing the description for the BigQuery table. + :type descriptin: str + """ + warnings.warn( + "This method is deprecated. Please use `BigQueryHook.insert_job` method.", + DeprecationWarning, + ) + + if not self.project_id: + raise ValueError("The project_id should be set") + + # To provide backward compatibility + schema_update_options = list(schema_update_options or []) + + # bigquery only allows certain source formats + # we check to make sure the passed source format is valid + # if it's not, we raise a ValueError + # Refer to this link for more details: + # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat # noqa # pylint: disable=line-too-long + + if schema_fields is None and not autodetect: + raise ValueError("You must either pass a schema or autodetect=True.") + + if src_fmt_configs is None: + src_fmt_configs = {} + + source_format = source_format.upper() + allowed_formats = [ + "CSV", + "NEWLINE_DELIMITED_JSON", + "AVRO", + "GOOGLE_SHEETS", + "DATASTORE_BACKUP", + "PARQUET", + ] + if source_format not in allowed_formats: + raise ValueError( + "{} is not a valid source format. " + "Please use one of the following types: {}".format( + source_format, allowed_formats + ) + ) + + # bigquery also allows you to define how you want a table's schema to change + # as a side effect of a load + # for more details: + # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schemaUpdateOptions + allowed_schema_update_options = [ + "ALLOW_FIELD_ADDITION", + "ALLOW_FIELD_RELAXATION", + ] + if not set(allowed_schema_update_options).issuperset( + set(schema_update_options) + ): + raise ValueError( + "{} contains invalid schema update options." + "Please only use one or more of the following options: {}".format( + schema_update_options, allowed_schema_update_options + ) + ) + + destination_project, destination_dataset, destination_table = _split_tablename( + table_input=destination_project_dataset_table, + default_project_id=self.project_id, + var_name="destination_project_dataset_table", + ) + + configuration = { + "load": { + "autodetect": autodetect, + "createDisposition": create_disposition, + "destinationTable": { + "projectId": destination_project, + "datasetId": destination_dataset, + "tableId": destination_table, + }, + "sourceFormat": source_format, + "sourceUris": source_uris, + "writeDisposition": write_disposition, + "ignoreUnknownValues": ignore_unknown_values, + } + } + + time_partitioning = _cleanse_time_partitioning( + destination_project_dataset_table, time_partitioning + ) + if time_partitioning: + configuration["load"].update({"timePartitioning": time_partitioning}) + + if cluster_fields: + configuration["load"].update({"clustering": {"fields": cluster_fields}}) + + if schema_fields: + configuration["load"]["schema"] = {"fields": schema_fields} + + if schema_update_options: + if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: + raise ValueError( + "schema_update_options is only " + "allowed if write_disposition is " + "'WRITE_APPEND' or 'WRITE_TRUNCATE'." 
+ ) + else: + self.log.info( + "Adding experimental 'schemaUpdateOptions': %s", + schema_update_options, + ) + configuration["load"]["schemaUpdateOptions"] = schema_update_options + + if max_bad_records: + configuration["load"]["maxBadRecords"] = max_bad_records + + if encryption_configuration: + configuration["load"][ + "destinationEncryptionConfiguration" + ] = encryption_configuration + + if labels or description: + configuration["load"].update({"destinationTableProperties": {}}) + + if labels: + configuration["load"]["destinationTableProperties"]["labels"] = labels + + if description: + configuration["load"]["destinationTableProperties"][ + "description" + ] = description + + src_fmt_to_configs_mapping = { + "CSV": [ + "allowJaggedRows", + "allowQuotedNewlines", + "autodetect", + "fieldDelimiter", + "skipLeadingRows", + "ignoreUnknownValues", + "nullMarker", + "quote", + "encoding", + ], + "DATASTORE_BACKUP": ["projectionFields"], + "NEWLINE_DELIMITED_JSON": ["autodetect", "ignoreUnknownValues"], + "PARQUET": ["autodetect", "ignoreUnknownValues"], + "AVRO": ["useAvroLogicalTypes"], + } + + valid_configs = src_fmt_to_configs_mapping[source_format] + + # if following fields are not specified in src_fmt_configs, + # honor the top-level params for backward-compatibility + backward_compatibility_configs = { + "skipLeadingRows": skip_leading_rows, + "fieldDelimiter": field_delimiter, + "ignoreUnknownValues": ignore_unknown_values, + "quote": quote_character, + "allowQuotedNewlines": allow_quoted_newlines, + "encoding": encoding, + } + + src_fmt_configs = _validate_src_fmt_configs( + source_format, + src_fmt_configs, + valid_configs, + backward_compatibility_configs, + ) + + configuration["load"].update(src_fmt_configs) + + if allow_jagged_rows: + configuration["load"]["allowJaggedRows"] = allow_jagged_rows + + job = self.insert_job(configuration=configuration, project_id=self.project_id) + self.running_job_id = job.job_id + return job.job_id + + def run_copy( # pylint: disable=invalid-name + self, + source_project_dataset_tables: Union[List, str], + destination_project_dataset_table: str, + write_disposition: str = "WRITE_EMPTY", + create_disposition: str = "CREATE_IF_NEEDED", + labels: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None, + ) -> str: + """ + Executes a BigQuery copy command to copy data from one BigQuery table + to another. See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.copy + + For more details about these parameters. + + :param source_project_dataset_tables: One or more dotted + ``(project:|project.).
`` + BigQuery tables to use as the source data. Use a list if there are + multiple source tables. + If ```` is not included, project will be the project defined + in the connection json. + :type source_project_dataset_tables: list|string + :param destination_project_dataset_table: The destination BigQuery + table. Format is: ``(project:|project.).
`` + :type destination_project_dataset_table: str + :param write_disposition: The write disposition if the table already exists. + :type write_disposition: str + :param create_disposition: The create disposition if the table doesn't exist. + :type create_disposition: str + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + """ + warnings.warn( + "This method is deprecated. Please use `BigQueryHook.insert_job` method.", + DeprecationWarning, + ) + if not self.project_id: + raise ValueError("The project_id should be set") + + source_project_dataset_tables = ( + [source_project_dataset_tables] + if not isinstance(source_project_dataset_tables, list) + else source_project_dataset_tables + ) + + source_project_dataset_tables_fixup = [] + for source_project_dataset_table in source_project_dataset_tables: + source_project, source_dataset, source_table = _split_tablename( + table_input=source_project_dataset_table, + default_project_id=self.project_id, + var_name="source_project_dataset_table", + ) + source_project_dataset_tables_fixup.append( + { + "projectId": source_project, + "datasetId": source_dataset, + "tableId": source_table, + } + ) + + destination_project, destination_dataset, destination_table = _split_tablename( + table_input=destination_project_dataset_table, + default_project_id=self.project_id, + ) + configuration = { + "copy": { + "createDisposition": create_disposition, + "writeDisposition": write_disposition, + "sourceTables": source_project_dataset_tables_fixup, + "destinationTable": { + "projectId": destination_project, + "datasetId": destination_dataset, + "tableId": destination_table, + }, + } + } + + if labels: + configuration["labels"] = labels + + if encryption_configuration: + configuration["copy"][ + "destinationEncryptionConfiguration" + ] = encryption_configuration + + job = self.insert_job(configuration=configuration, project_id=self.project_id) + self.running_job_id = job.job_id + return job.job_id + + def run_extract( + self, + source_project_dataset_table: str, + destination_cloud_storage_uris: str, + compression: str = "NONE", + export_format: str = "CSV", + field_delimiter: str = ",", + print_header: bool = True, + labels: Optional[Dict] = None, + ) -> str: + """ + Executes a BigQuery extract command to copy data from BigQuery to + Google Cloud Storage. See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + For more details about these parameters. + + :param source_project_dataset_table: The dotted ``.
`` + BigQuery table to use as the source data. + :type source_project_dataset_table: str + :param destination_cloud_storage_uris: The destination Google Cloud + Storage URI (e.g. gs://some-bucket/some-file.txt). Follows + convention defined here: + https://cloud.google.com/bigquery/exporting-data-from-bigquery#exportingmultiple + :type destination_cloud_storage_uris: list + :param compression: Type of compression to use. + :type compression: str + :param export_format: File format to export. + :type export_format: str + :param field_delimiter: The delimiter to use when extracting to a CSV. + :type field_delimiter: str + :param print_header: Whether to print a header for a CSV file extract. + :type print_header: bool + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + """ + warnings.warn( + "This method is deprecated. Please use `BigQueryHook.insert_job` method.", + DeprecationWarning, + ) + if not self.project_id: + raise ValueError("The project_id should be set") + + source_project, source_dataset, source_table = _split_tablename( + table_input=source_project_dataset_table, + default_project_id=self.project_id, + var_name="source_project_dataset_table", + ) + + configuration = { + "extract": { + "sourceTable": { + "projectId": source_project, + "datasetId": source_dataset, + "tableId": source_table, + }, + "compression": compression, + "destinationUris": destination_cloud_storage_uris, + "destinationFormat": export_format, + } + } # type: Dict[str, Any] + + if labels: + configuration["labels"] = labels + + if export_format == "CSV": + # Only set fieldDelimiter and printHeader fields if using CSV. + # Google does not like it if you set these fields for other export + # formats. + configuration["extract"]["fieldDelimiter"] = field_delimiter + configuration["extract"]["printHeader"] = print_header + + job = self.insert_job(configuration=configuration, project_id=self.project_id) + self.running_job_id = job.job_id + return job.job_id + + # pylint: disable=too-many-locals,too-many-arguments, too-many-branches + def run_query( + self, + sql: str, + destination_dataset_table: Optional[str] = None, + write_disposition: str = "WRITE_EMPTY", + allow_large_results: bool = False, + flatten_results: Optional[bool] = None, + udf_config: Optional[List] = None, + use_legacy_sql: Optional[bool] = None, + maximum_billing_tier: Optional[int] = None, + maximum_bytes_billed: Optional[float] = None, + create_disposition: str = "CREATE_IF_NEEDED", + query_params: Optional[List] = None, + labels: Optional[Dict] = None, + schema_update_options: Optional[Iterable] = None, + priority: str = "INTERACTIVE", + time_partitioning: Optional[Dict] = None, + api_resource_configs: Optional[Dict] = None, + cluster_fields: Optional[List[str]] = None, + location: Optional[str] = None, + encryption_configuration: Optional[Dict] = None, + ) -> str: + """ + Executes a BigQuery SQL query. Optionally persists results in a BigQuery + table. See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + For more details about these parameters. + + :param sql: The BigQuery SQL to execute. + :type sql: str + :param destination_dataset_table: The dotted ``.
`` + BigQuery table to save the query results. + :type destination_dataset_table: str + :param write_disposition: What to do if the table already exists in + BigQuery. + :type write_disposition: str + :param allow_large_results: Whether to allow large results. + :type allow_large_results: bool + :param flatten_results: If true and query uses legacy SQL dialect, flattens + all nested and repeated fields in the query results. ``allowLargeResults`` + must be true if this is set to false. For standard SQL queries, this + flag is ignored and results are never flattened. + :type flatten_results: bool + :param udf_config: The User Defined Function configuration for the query. + See https://cloud.google.com/bigquery/user-defined-functions for details. + :type udf_config: list + :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false). + If `None`, defaults to `self.use_legacy_sql`. + :type use_legacy_sql: bool + :param api_resource_configs: a dictionary that contain params + 'configuration' applied for Google BigQuery Jobs API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs + for example, {'query': {'useQueryCache': False}}. You could use it + if you need to provide some params that are not supported by the + BigQueryHook like args. + :type api_resource_configs: dict + :param maximum_billing_tier: Positive integer that serves as a + multiplier of the basic price. + :type maximum_billing_tier: int + :param maximum_bytes_billed: Limits the bytes billed for this job. + Queries that will have bytes billed beyond this limit will fail + (without incurring a charge). If unspecified, this will be + set to your project default. + :type maximum_bytes_billed: float + :param create_disposition: Specifies whether the job is allowed to + create new tables. + :type create_disposition: str + :param query_params: a list of dictionary containing query parameter types and + values, passed to BigQuery + :type query_params: list + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + :param schema_update_options: Allows the schema of the destination + table to be updated as a side effect of the query job. + :type schema_update_options: Union[list, tuple, set] + :param priority: Specifies a priority for the query. + Possible values include INTERACTIVE and BATCH. + The default value is INTERACTIVE. + :type priority: str + :param time_partitioning: configure optional time partitioning fields i.e. + partition by field, type and expiration as per API specifications. + :type time_partitioning: dict + :param cluster_fields: Request that the result of this query be stored sorted + by one or more columns. BigQuery supports clustering for both partitioned and + non-partitioned tables. The order of columns given determines the sort order. + :type cluster_fields: list[str] + :param location: The geographic location of the job. Required except for + US and EU. See details at + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :type location: str + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + """ + warnings.warn( + "This method is deprecated. 
Please use `BigQueryHook.insert_job` method.", + DeprecationWarning, + ) + if not self.project_id: + raise ValueError("The project_id should be set") + + labels = labels or self.labels + schema_update_options = list(schema_update_options or []) + + if time_partitioning is None: + time_partitioning = {} + + if location: + self.location = location + + if not api_resource_configs: + api_resource_configs = self.api_resource_configs + else: + _validate_value("api_resource_configs", api_resource_configs, dict) + configuration = deepcopy(api_resource_configs) + if "query" not in configuration: + configuration["query"] = {} + + else: + _validate_value( + "api_resource_configs['query']", configuration["query"], dict + ) + + if sql is None and not configuration["query"].get("query", None): + raise TypeError( + "`BigQueryBaseCursor.run_query` missing 1 required positional argument: `sql`" + ) + + # BigQuery also allows you to define how you want a table's schema to change + # as a side effect of a query job + # for more details: + # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions # noqa # pylint: disable=line-too-long + + allowed_schema_update_options = [ + "ALLOW_FIELD_ADDITION", + "ALLOW_FIELD_RELAXATION", + ] + + if not set(allowed_schema_update_options).issuperset( + set(schema_update_options) + ): + raise ValueError( + "{} contains invalid schema update options. " + "Please only use one or more of the following " + "options: {}".format( + schema_update_options, allowed_schema_update_options + ) + ) + + if schema_update_options: + if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: + raise ValueError( + "schema_update_options is only " + "allowed if write_disposition is " + "'WRITE_APPEND' or 'WRITE_TRUNCATE'." 
+ ) + + if destination_dataset_table: + ( + destination_project, + destination_dataset, + destination_table, + ) = _split_tablename( + table_input=destination_dataset_table, + default_project_id=self.project_id, + ) + + destination_dataset_table = { # type: ignore + "projectId": destination_project, + "datasetId": destination_dataset, + "tableId": destination_table, + } + + if cluster_fields: + cluster_fields = {"fields": cluster_fields} # type: ignore + + query_param_list = [ + (sql, "query", None, (str,)), + (priority, "priority", "INTERACTIVE", (str,)), + (use_legacy_sql, "useLegacySql", self.use_legacy_sql, bool), + (query_params, "queryParameters", None, list), + (udf_config, "userDefinedFunctionResources", None, list), + (maximum_billing_tier, "maximumBillingTier", None, int), + (maximum_bytes_billed, "maximumBytesBilled", None, float), + (time_partitioning, "timePartitioning", {}, dict), + (schema_update_options, "schemaUpdateOptions", None, list), + (destination_dataset_table, "destinationTable", None, dict), + (cluster_fields, "clustering", None, dict), + ] # type: List[Tuple] + + for param, param_name, param_default, param_type in query_param_list: + if param_name not in configuration["query"] and param in [None, {}, ()]: + if param_name == "timePartitioning": + param_default = _cleanse_time_partitioning( + destination_dataset_table, time_partitioning + ) + param = param_default + + if param in [None, {}, ()]: + continue + + _api_resource_configs_duplication_check( + param_name, param, configuration["query"] + ) + + configuration["query"][param_name] = param + + # check valid type of provided param, + # it last step because we can get param from 2 sources, + # and first of all need to find it + + _validate_value(param_name, configuration["query"][param_name], param_type) + + if param_name == "schemaUpdateOptions" and param: + self.log.info( + "Adding experimental 'schemaUpdateOptions': %s", + schema_update_options, + ) + + if param_name != "destinationTable": + continue + + for key in ["projectId", "datasetId", "tableId"]: + if key not in configuration["query"]["destinationTable"]: + raise ValueError( + "Not correct 'destinationTable' in " + "api_resource_configs. 'destinationTable' " + "must be a dict with {'projectId':'', " + "'datasetId':'', 'tableId':''}" + ) + + configuration["query"].update( + { + "allowLargeResults": allow_large_results, + "flattenResults": flatten_results, + "writeDisposition": write_disposition, + "createDisposition": create_disposition, + } + ) + + if ( + "useLegacySql" in configuration["query"] + and configuration["query"]["useLegacySql"] + and "queryParameters" in configuration["query"] + ): + raise ValueError("Query parameters are not allowed when using legacy SQL") + + if labels: + _api_resource_configs_duplication_check("labels", labels, configuration) + configuration["labels"] = labels + + if encryption_configuration: + configuration["query"][ + "destinationEncryptionConfiguration" + ] = encryption_configuration + + job = self.insert_job(configuration=configuration, project_id=self.project_id) + self.running_job_id = job.job_id + return job.job_id + + +class BigQueryPandasConnector(GbqConnector): + """ + This connector behaves identically to GbqConnector (from Pandas), except + that it allows the service to be injected, and disables a call to + self.get_credentials(). This allows Airflow to use BigQuery with Pandas + without forcing a three legged OAuth connection. Instead, we can inject + service account credentials into the binding. 
+ """ + + def __init__( + self, + project_id: str, + service: str, + reauth: bool = False, + verbose: bool = False, + dialect="legacy", + ) -> None: + super().__init__(project_id) + gbq_check_google_client_version() + gbq_test_google_api_imports() + self.project_id = project_id + self.reauth = reauth + self.service = service + self.verbose = verbose + self.dialect = dialect + + +class BigQueryConnection: + """ + BigQuery does not have a notion of a persistent connection. Thus, these + objects are small stateless factories for cursors, which do all the real + work. + """ + + def __init__(self, *args, **kwargs) -> None: + self._args = args + self._kwargs = kwargs + + def close(self) -> None: # noqa: D403 + """BigQueryConnection does not have anything to close""" + + def commit(self) -> None: # noqa: D403 + """BigQueryConnection does not support transactions""" + + def cursor(self) -> "BigQueryCursor": # noqa: D403 + """Return a new :py:class:`Cursor` object using the connection""" + return BigQueryCursor(*self._args, **self._kwargs) + + def rollback(self) -> NoReturn: # noqa: D403 + """BigQueryConnection does not have transactions""" + raise NotImplementedError("BigQueryConnection does not have transactions") + + +class BigQueryBaseCursor(LoggingMixin): + """ + The BigQuery base cursor contains helper methods to execute queries against + BigQuery. The methods can be used directly by operators, in cases where a + PEP 249 cursor isn't needed. + """ + + def __init__( + self, + service: Any, + project_id: str, + hook: BigQueryHook, + use_legacy_sql: bool = True, + api_resource_configs: Optional[Dict] = None, + location: Optional[str] = None, + num_retries: int = 5, + labels: Optional[Dict] = None, + ) -> None: + + super().__init__() + self.service = service + self.project_id = project_id + self.use_legacy_sql = use_legacy_sql + if api_resource_configs: + _validate_value("api_resource_configs", api_resource_configs, dict) + self.api_resource_configs = ( + api_resource_configs if api_resource_configs else {} + ) # type Dict + self.running_job_id = None # type: Optional[str] + self.location = location + self.num_retries = num_retries + self.labels = labels + self.hook = hook + + def create_empty_table(self, *args, **kwargs) -> None: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.create_empty_table(*args, **kwargs) + + def create_empty_dataset(self, *args, **kwargs) -> None: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_dataset` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_dataset`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.create_empty_dataset(*args, **kwargs) + + def get_dataset_tables(self, *args, **kwargs) -> List[Dict[str, Any]]: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset_tables` + """ + warnings.warn( + "This method is deprecated. 
" + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset_tables`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.get_dataset_tables(*args, **kwargs) + + def delete_dataset(self, *args, **kwargs) -> None: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.delete_dataset` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.delete_dataset`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.delete_dataset(*args, **kwargs) + + def create_external_table(self, *args, **kwargs) -> None: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_external_table` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_external_table`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.create_external_table(*args, **kwargs) + + def patch_table(self, *args, **kwargs) -> None: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.patch_table` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.patch_table`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.patch_table(*args, **kwargs) + + def insert_all(self, *args, **kwargs) -> None: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_all` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_all`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.insert_all(*args, **kwargs) + + def update_dataset(self, *args, **kwargs) -> Dict: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.update_dataset` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.update_dataset`", + DeprecationWarning, + stacklevel=3, + ) + return Dataset.to_api_repr(self.hook.update_dataset(*args, **kwargs)) + + def patch_dataset(self, *args, **kwargs) -> Dict: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.patch_dataset` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.patch_dataset`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.patch_dataset(*args, **kwargs) + + def get_dataset_tables_list(self, *args, **kwargs) -> List[Dict[str, Any]]: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset_tables_list` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset_tables_list`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.get_dataset_tables_list(*args, **kwargs) + + def get_datasets_list(self, *args, **kwargs) -> list: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_datasets_list` + """ + warnings.warn( + "This method is deprecated. 
" + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_datasets_list`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.get_datasets_list(*args, **kwargs) + + def get_dataset(self, *args, **kwargs) -> dict: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.get_dataset(*args, **kwargs) + + def run_grant_dataset_view_access(self, *args, **kwargs) -> dict: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_grant_dataset_view_access` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks" + ".bigquery.BigQueryHook.run_grant_dataset_view_access`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.run_grant_dataset_view_access(*args, **kwargs) + + def run_table_upsert(self, *args, **kwargs) -> dict: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_table_upsert` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_table_upsert`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.run_table_upsert(*args, **kwargs) + + def run_table_delete(self, *args, **kwargs) -> None: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_table_delete` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_table_delete`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.run_table_delete(*args, **kwargs) + + def get_tabledata(self, *args, **kwargs) -> List[dict]: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_tabledata` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_tabledata`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.get_tabledata(*args, **kwargs) + + def get_schema(self, *args, **kwargs) -> dict: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_schema` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_schema`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.get_schema(*args, **kwargs) + + def poll_job_complete(self, *args, **kwargs) -> bool: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.poll_job_complete` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.poll_job_complete`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.poll_job_complete(*args, **kwargs) + + def cancel_query(self, *args, **kwargs) -> None: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.cancel_query` + """ + warnings.warn( + "This method is deprecated. 
" + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.cancel_query`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.cancel_query(*args, **kwargs) # type: ignore # noqa + + def run_with_configuration(self, *args, **kwargs) -> str: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_with_configuration` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_with_configuration`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.run_with_configuration(*args, **kwargs) + + def run_load(self, *args, **kwargs) -> str: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_load` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_load`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.run_load(*args, **kwargs) + + def run_copy(self, *args, **kwargs) -> str: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_copy` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_copy`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.run_copy(*args, **kwargs) + + def run_extract(self, *args, **kwargs) -> str: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_extract` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_extract`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.run_extract(*args, **kwargs) + + def run_query(self, *args, **kwargs) -> str: + """ + This method is deprecated. + Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_query` + """ + warnings.warn( + "This method is deprecated. " + "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_query`", + DeprecationWarning, + stacklevel=3, + ) + return self.hook.run_query(*args, **kwargs) + + +class BigQueryCursor(BigQueryBaseCursor): + """ + A very basic BigQuery PEP 249 cursor implementation. 
The PyHive PEP 249 + implementation was used as a reference: + + https://github.com/dropbox/PyHive/blob/master/pyhive/presto.py + https://github.com/dropbox/PyHive/blob/master/pyhive/common.py + """ + + def __init__( + self, + service: Any, + project_id: str, + hook: BigQueryHook, + use_legacy_sql: bool = True, + location: Optional[str] = None, + num_retries: int = 5, + ) -> None: + super().__init__( + service=service, + project_id=project_id, + hook=hook, + use_legacy_sql=use_legacy_sql, + location=location, + num_retries=num_retries, + ) + self.buffersize = None # type: Optional[int] + self.page_token = None # type: Optional[str] + self.job_id = None # type: Optional[str] + self.buffer = [] # type: list + self.all_pages_loaded = False # type: bool + + @property + def description(self) -> None: + """The schema description method is not currently implemented""" + raise NotImplementedError + + def close(self) -> None: + """By default, do nothing""" + + @property + def rowcount(self) -> int: + """By default, return -1 to indicate that this is not supported""" + return -1 + + def execute(self, operation: str, parameters: Optional[dict] = None) -> None: + """ + Executes a BigQuery query, and returns the job ID. + + :param operation: The query to execute. + :type operation: str + :param parameters: Parameters to substitute into the query. + :type parameters: dict + """ + sql = _bind_parameters(operation, parameters) if parameters else operation + self.flush_results() + self.job_id = self.hook.run_query(sql) + + def executemany(self, operation: str, seq_of_parameters: list) -> None: + """ + Execute a BigQuery query multiple times with different parameters. + + :param operation: The query to execute. + :type operation: str + :param seq_of_parameters: List of dictionary parameters to substitute into the + query. + :type seq_of_parameters: list + """ + for parameters in seq_of_parameters: + self.execute(operation, parameters) + + def flush_results(self) -> None: + """Flush results related cursor attributes""" + self.page_token = None + self.job_id = None + self.all_pages_loaded = False + self.buffer = [] + + def fetchone(self) -> Union[List, None]: + """Fetch the next row of a query result set""" + # pylint: disable=not-callable + return self.next() + + def next(self) -> Union[List, None]: + """ + Helper method for fetchone, which returns the next row from a buffer. + If the buffer is empty, attempts to paginate through the result set for + the next page, and load it into the buffer. + """ + if not self.job_id: + return None + + if not self.buffer: + if self.all_pages_loaded: + return None + + query_results = ( + self.service.jobs() + .getQueryResults( + projectId=self.project_id, + jobId=self.job_id, + location=self.location, + pageToken=self.page_token, + ) + .execute(num_retries=self.num_retries) + ) + + if "rows" in query_results and query_results["rows"]: + self.page_token = query_results.get("pageToken") + fields = query_results["schema"]["fields"] + col_types = [field["type"] for field in fields] + rows = query_results["rows"] + + for dict_row in rows: + typed_row = [ + _bq_cast(vs["v"], col_types[idx]) + for idx, vs in enumerate(dict_row["f"]) + ] + self.buffer.append(typed_row) + + if not self.page_token: + self.all_pages_loaded = True + + else: + # Reset all state since we've exhausted the results. 
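# flush_results() clears the buffer, page token, job id and the
# all_pages_loaded flag, so a later execute() starts from a clean state.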
+ self.flush_results() + return None + + return self.buffer.pop(0) + + def fetchmany(self, size: Optional[int] = None) -> list: + """ + Fetch the next set of rows of a query result, returning a sequence of sequences + (e.g. a list of tuples). An empty sequence is returned when no more rows are + available. The number of rows to fetch per call is specified by the parameter. + If it is not given, the cursor's arraysize determines the number of rows to be + fetched. The method should try to fetch as many rows as indicated by the size + parameter. If this is not possible due to the specified number of rows not being + available, fewer rows may be returned. An :py:class:`~pyhive.exc.Error` + (or subclass) exception is raised if the previous call to + :py:meth:`execute` did not produce any result set or no call was issued yet. + """ + if size is None: + size = self.arraysize + result = [] + for _ in range(size): + one = self.fetchone() + if one is None: + break + result.append(one) + return result + + def fetchall(self) -> List[list]: + """ + Fetch all (remaining) rows of a query result, returning them as a sequence of + sequences (e.g. a list of tuples). + """ + result = [] + while True: + one = self.fetchone() + if one is None: + break + result.append(one) + return result + + def get_arraysize(self) -> int: + """Specifies the number of rows to fetch at a time with .fetchmany()""" + return self.buffersize or 1 + + def set_arraysize(self, arraysize: int) -> None: + """Specifies the number of rows to fetch at a time with .fetchmany()""" + self.buffersize = arraysize + + arraysize = property(get_arraysize, set_arraysize) + + def setinputsizes(self, sizes: Any) -> None: + """Does nothing by default""" + + def setoutputsize(self, size: Any, column: Any = None) -> None: + """Does nothing by default""" + + +def _bind_parameters(operation: str, parameters: dict) -> str: + """Helper method that binds parameters to a SQL query""" + # inspired by MySQL Python Connector (conversion.py) + string_parameters = {} # type Dict[str, str] + for (name, value) in parameters.items(): + if value is None: + string_parameters[name] = "NULL" + elif isinstance(value, str): + string_parameters[name] = "'" + _escape(value) + "'" + else: + string_parameters[name] = str(value) + return operation % string_parameters + + +def _escape(s: str) -> str: + """Helper method that escapes parameters to a SQL query""" + e = s + e = e.replace("\\", "\\\\") + e = e.replace("\n", "\\n") + e = e.replace("\r", "\\r") + e = e.replace("'", "\\'") + e = e.replace('"', '\\"') + return e + + +def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, str]: + """ + Helper method that casts a BigQuery row to the appropriate data types. + This is useful because BigQuery returns all fields as strings. + """ + if string_field is None: + return None + elif bq_type == "INTEGER": + return int(string_field) + elif bq_type in ("FLOAT", "TIMESTAMP"): + return float(string_field) + elif bq_type == "BOOLEAN": + if string_field not in ["true", "false"]: + raise ValueError(f"{string_field} must have value 'true' or 'false'") + return string_field == "true" + else: + return string_field + + +def _split_tablename( + table_input: str, default_project_id: str, var_name: Optional[str] = None +) -> Tuple[str, str, str]: + + if "." not in table_input: + raise ValueError( + f"Expected table name in the format of .
. Got: {table_input}" + ) + + if not default_project_id: + raise ValueError("INTERNAL: No default project is specified") + + def var_print(var_name): + if var_name is None: + return "" + else: + return f"Format exception for {var_name}: " + + if table_input.count(".") + table_input.count(":") > 3: + raise Exception( + "{var}Use either : or . to specify project " + "got {input}".format(var=var_print(var_name), input=table_input) + ) + cmpt = table_input.rsplit(":", 1) + project_id = None + rest = table_input + if len(cmpt) == 1: + project_id = None + rest = cmpt[0] + elif len(cmpt) == 2 and cmpt[0].count(":") <= 1: + if cmpt[-1].count(".") != 2: + project_id = cmpt[0] + rest = cmpt[1] + else: + raise Exception( + "{var}Expect format of (.
, " + "got {input}".format(var=var_print(var_name), input=table_input) + ) + + cmpt = rest.split(".") + if len(cmpt) == 3: + if project_id: + raise ValueError( + f"{var_print(var_name)}Use either : or . to specify project" + ) + project_id = cmpt[0] + dataset_id = cmpt[1] + table_id = cmpt[2] + + elif len(cmpt) == 2: + dataset_id = cmpt[0] + table_id = cmpt[1] + else: + raise Exception( + "{var}Expect format of (.
, " + "got {input}".format(var=var_print(var_name), input=table_input) + ) + + if project_id is None: + if var_name is not None: + log.info( + 'Project not included in %s: %s; using project "%s"', + var_name, + table_input, + default_project_id, + ) + project_id = default_project_id + + return project_id, dataset_id, table_id + + +def _cleanse_time_partitioning( + destination_dataset_table: Optional[str], time_partitioning_in: Optional[Dict] +) -> Dict: # if it is a partitioned table ($ is in the table name) add partition load option + + if time_partitioning_in is None: + time_partitioning_in = {} + + time_partitioning_out = {} + if destination_dataset_table and "$" in destination_dataset_table: + time_partitioning_out["type"] = "DAY" + time_partitioning_out.update(time_partitioning_in) + return time_partitioning_out + + +def _validate_value(key: Any, value: Any, expected_type: Type) -> None: + """Function to check expected type and raise error if type is not correct""" + if not isinstance(value, expected_type): + raise TypeError( + f"{key} argument must have a type {expected_type} not {type(value)}" + ) + + +def _api_resource_configs_duplication_check( + key: Any, value: Any, config_dict: dict, config_dict_name="api_resource_configs" +) -> None: + if key in config_dict and value != config_dict[key]: + raise ValueError( + "Values of {param_name} param are duplicated. " + "{dict_name} contained {param_name} param " + "in `query` config and {param_name} was also provided " + "with arg to run_query() method. Please remove duplicates.".format( + param_name=key, dict_name=config_dict_name + ) + ) + + +def _validate_src_fmt_configs( + source_format: str, + src_fmt_configs: dict, + valid_configs: List[str], + backward_compatibility_configs: Optional[Dict] = None, +) -> Dict: + """ + Validates the given src_fmt_configs against a valid configuration for the source format. + Adds the backward compatibility config to the src_fmt_configs. + + :param source_format: File format to export. + :type source_format: str + :param src_fmt_configs: Configure optional fields specific to the source format. + :type src_fmt_configs: dict + :param valid_configs: Valid configuration specific to the source format + :type valid_configs: List[str] + :param backward_compatibility_configs: The top-level params for backward-compatibility + :type backward_compatibility_configs: dict + """ + if backward_compatibility_configs is None: + backward_compatibility_configs = {} + + for k, v in backward_compatibility_configs.items(): + if k not in src_fmt_configs and k in valid_configs: + src_fmt_configs[k] = v + + for k, v in src_fmt_configs.items(): + if k not in valid_configs: + raise ValueError( + f"{k} is not a valid src_fmt_configs for type {source_format}." + ) + + return src_fmt_configs diff --git a/reference/providers/google/cloud/hooks/bigquery_dts.py b/reference/providers/google/cloud/hooks/bigquery_dts.py new file mode 100644 index 0000000..c5c5d4e --- /dev/null +++ b/reference/providers/google/cloud/hooks/bigquery_dts.py @@ -0,0 +1,287 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains a BigQuery Hook.""" +from copy import copy +from typing import Optional, Sequence, Tuple, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.retry import Retry +from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceClient +from google.cloud.bigquery_datatransfer_v1.types import ( + StartManualTransferRunsResponse, + TransferConfig, + TransferRun, +) +from googleapiclient.discovery import Resource + + +def get_object_id(obj: dict) -> str: + """Returns unique id of the object.""" + return obj["name"].rpartition("/")[-1] + + +class BiqQueryDataTransferServiceHook(GoogleBaseHook): + """ + Hook for Google Bigquery Transfer API. + + All the methods in the hook where ``project_id`` is used must be called with + keyword arguments rather than positional. + """ + + _conn = None # type: Optional[Resource] + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + + @staticmethod + def _disable_auto_scheduling(config: Union[dict, TransferConfig]) -> TransferConfig: + """ + In the case of Airflow, the customer needs to create a transfer config + with the automatic scheduling disabled (UI, CLI or an Airflow operator) and + then trigger a transfer run using a specialized Airflow operator that will + call start_manual_transfer_runs. + + :param config: Data transfer configuration to create. + :type config: Union[dict, google.cloud.bigquery_datatransfer_v1.types.TransferConfig] + """ + config = ( + TransferConfig.to_dict(config) + if isinstance(config, TransferConfig) + else config + ) + new_config = copy(config) + schedule_options = new_config.get("schedule_options") + if schedule_options: + disable_auto_scheduling = schedule_options.get( + "disable_auto_scheduling", None + ) + if disable_auto_scheduling is None: + schedule_options["disable_auto_scheduling"] = True + else: + new_config["schedule_options"] = {"disable_auto_scheduling": True} + # HACK: TransferConfig.to_dict returns invalid representation + # See: https://github.com/googleapis/python-bigquery-datatransfer/issues/90 + if isinstance(new_config.get("user_id"), str): + new_config["user_id"] = int(new_config["user_id"]) + return TransferConfig(**new_config) + + def get_conn(self) -> DataTransferServiceClient: + """ + Retrieves connection to Google Bigquery. 
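As a quick illustration of `_disable_auto_scheduling` above, a minimal sketch, assuming the upstream `airflow.providers.google` package and the BigQuery Data Transfer client library are importable (the config values are placeholders):

from airflow.providers.google.cloud.hooks.bigquery_dts import (
    BiqQueryDataTransferServiceHook,
)

# Hypothetical transfer config without explicit schedule options.
config = {
    "display_name": "nightly_copy",       # placeholder
    "data_source_id": "scheduled_query",  # placeholder
}

# The static helper returns a TransferConfig whose
# schedule_options.disable_auto_scheduling is forced to True, so runs are only
# started via start_manual_transfer_runs() and never by the service scheduler.
tc = BiqQueryDataTransferServiceHook._disable_auto_scheduling(config)
print(tc.schedule_options.disable_auto_scheduling)  # True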
+ + :return: Google Bigquery API client + :rtype: google.cloud.bigquery_datatransfer_v1.DataTransferServiceClient + """ + if not self._conn: + self._conn = DataTransferServiceClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._conn + + @GoogleBaseHook.fallback_to_default_project_id + def create_transfer_config( + self, + transfer_config: Union[dict, TransferConfig], + project_id: str, + authorization_code: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> TransferConfig: + """ + Creates a new data transfer configuration. + + :param transfer_config: Data transfer configuration to create. + :type transfer_config: Union[dict, google.cloud.bigquery_datatransfer_v1.types.TransferConfig] + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :param authorization_code: authorization code to use with this transfer configuration. + This is required if new credentials are needed. + :type authorization_code: Optional[str] + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :return: A ``google.cloud.bigquery_datatransfer_v1.types.TransferConfig`` instance. + """ + client = self.get_conn() + parent = f"projects/{project_id}" + return client.create_transfer_config( + request={ + "parent": parent, + "transfer_config": self._disable_auto_scheduling(transfer_config), + "authorization_code": authorization_code, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_transfer_config( + self, + transfer_config_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes transfer configuration. + + :param transfer_config_id: Id of transfer config to be used. + :type transfer_config_id: str + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. 
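For context, a hedged usage sketch of `create_transfer_config`; the connection id is the hook default, and the project and config values are assumptions:

from airflow.providers.google.cloud.hooks.bigquery_dts import (
    BiqQueryDataTransferServiceHook,
)

hook = BiqQueryDataTransferServiceHook(gcp_conn_id="google_cloud_default")

# project_id may be omitted thanks to fallback_to_default_project_id; when it
# is supplied explicitly it must be passed as a keyword argument.
transfer_config = hook.create_transfer_config(
    transfer_config={
        "display_name": "example_copy",       # placeholder
        "data_source_id": "scheduled_query",  # placeholder
    },
    project_id="my-project",                  # placeholder
)
print(transfer_config.name)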
+ :type metadata: Optional[Sequence[Tuple[str, str]]] + :return: None + """ + client = self.get_conn() + name = f"projects/{project_id}/transferConfigs/{transfer_config_id}" + return client.delete_transfer_config( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def start_manual_transfer_runs( + self, + transfer_config_id: str, + project_id: str, + requested_time_range: Optional[dict] = None, + requested_run_time: Optional[dict] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> StartManualTransferRunsResponse: + """ + Start manual transfer runs to be executed now with schedule_time equal + to current time. The transfer runs can be created for a time range where + the run_time is between start_time (inclusive) and end_time + (exclusive), or for a specific run_time. + + :param transfer_config_id: Id of transfer config to be used. + :type transfer_config_id: str + :param requested_time_range: Time range for the transfer runs that should be started. + If a dict is provided, it must be of the same form as the protobuf + message `~google.cloud.bigquery_datatransfer_v1.types.TimeRange` + :type requested_time_range: Union[dict, ~google.cloud.bigquery_datatransfer_v1.types.TimeRange] + :param requested_run_time: Specific run_time for a transfer run to be started. The + requested_run_time must not be in the future. If a dict is provided, it + must be of the same form as the protobuf message + `~google.cloud.bigquery_datatransfer_v1.types.Timestamp` + :type requested_run_time: Union[dict, ~google.cloud.bigquery_datatransfer_v1.types.Timestamp] + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :return: An ``google.cloud.bigquery_datatransfer_v1.types.StartManualTransferRunsResponse`` instance. + """ + client = self.get_conn() + parent = f"projects/{project_id}/transferConfigs/{transfer_config_id}" + return client.start_manual_transfer_runs( + request={ + "parent": parent, + "requested_time_range": requested_time_range, + "requested_run_time": requested_run_time, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_transfer_run( + self, + run_id: str, + transfer_config_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> TransferRun: + """ + Returns information about the particular transfer run. + + :param run_id: ID of the transfer run. + :type run_id: str + :param transfer_config_id: ID of transfer config to be used. + :type transfer_config_id: str + :param project_id: The BigQuery project id where the transfer configuration should be + created. 
If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :return: An ``google.cloud.bigquery_datatransfer_v1.types.TransferRun`` instance. + """ + client = self.get_conn() + name = ( + f"projects/{project_id}/transferConfigs/{transfer_config_id}/runs/{run_id}" + ) + return client.get_transfer_run( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) diff --git a/reference/providers/google/cloud/hooks/bigtable.py b/reference/providers/google/cloud/hooks/bigtable.py new file mode 100644 index 0000000..efa55e8 --- /dev/null +++ b/reference/providers/google/cloud/hooks/bigtable.py @@ -0,0 +1,354 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Bigtable Hook.""" +import enum +import warnings +from typing import Dict, List, Optional, Sequence, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.cloud.bigtable import Client +from google.cloud.bigtable.cluster import Cluster +from google.cloud.bigtable.column_family import ColumnFamily, GarbageCollectionRule +from google.cloud.bigtable.instance import Instance +from google.cloud.bigtable.table import ClusterState, Table +from google.cloud.bigtable_admin_v2 import enums + + +class BigtableHook(GoogleBaseHook): + """ + Hook for Google Cloud Bigtable APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. 
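Before moving on to the Bigtable hook, a sketch that ties the transfer-service methods above together, triggering a manual run and looking it up again (all ids are placeholders):

from airflow.providers.google.cloud.hooks.bigquery_dts import (
    BiqQueryDataTransferServiceHook,
    get_object_id,
)

hook = BiqQueryDataTransferServiceHook()

response = hook.start_manual_transfer_runs(
    transfer_config_id="1234abcd",  # placeholder
    project_id="my-project",        # placeholder
)
# StartManualTransferRunsResponse.runs holds the created TransferRun messages;
# get_object_id() strips the run id off the full resource name.
run_id = get_object_id({"name": response.runs[0].name})
run = hook.get_transfer_run(
    run_id=run_id,
    transfer_config_id="1234abcd",
    project_id="my-project",
)
print(run.state)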
+ """ + + # pylint: disable=too-many-arguments + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client = None + + def _get_client(self, project_id: str): + if not self._client: + self._client = Client( + project=project_id, + credentials=self._get_credentials(), + client_info=self.client_info, + admin=True, + ) + return self._client + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance(self, instance_id: str, project_id: str) -> Instance: + """ + Retrieves and returns the specified Cloud Bigtable instance if it exists. + Otherwise, returns None. + + :param instance_id: The ID of the Cloud Bigtable instance. + :type instance_id: str + :param project_id: Optional, Google Cloud project ID where the + BigTable exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + """ + instance = self._get_client(project_id=project_id).instance(instance_id) + if not instance.exists(): + return None + return instance + + @GoogleBaseHook.fallback_to_default_project_id + def delete_instance(self, instance_id: str, project_id: str) -> None: + """ + Deletes the specified Cloud Bigtable instance. + Raises google.api_core.exceptions.NotFound if the Cloud Bigtable instance does + not exist. + + :param project_id: Optional, Google Cloud project ID where the + BigTable exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param instance_id: The ID of the Cloud Bigtable instance. + :type instance_id: str + """ + instance = self.get_instance(instance_id=instance_id, project_id=project_id) + if instance: + instance.delete() + else: + self.log.warning( + "The instance '%s' does not exist in project '%s'. Exiting", + instance_id, + project_id, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_instance( + self, + instance_id: str, + main_cluster_id: str, + main_cluster_zone: str, + project_id: str, + replica_clusters: Optional[List[Dict[str, str]]] = None, + replica_cluster_id: Optional[str] = None, + replica_cluster_zone: Optional[str] = None, + instance_display_name: Optional[str] = None, + instance_type: enums.Instance.Type = enums.Instance.Type.TYPE_UNSPECIFIED, + instance_labels: Optional[Dict] = None, + cluster_nodes: Optional[int] = None, + cluster_storage_type: enums.StorageType = enums.StorageType.STORAGE_TYPE_UNSPECIFIED, + timeout: Optional[float] = None, + ) -> Instance: + """ + Creates new instance. + + :type instance_id: str + :param instance_id: The ID for the new instance. + :type main_cluster_id: str + :param main_cluster_id: The ID for main cluster for the new instance. + :type main_cluster_zone: str + :param main_cluster_zone: The zone for main cluster. + See https://cloud.google.com/bigtable/docs/locations for more details. + :type project_id: str + :param project_id: Optional, Google Cloud project ID where the + BigTable exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type replica_clusters: List[Dict[str, str]] + :param replica_clusters: (optional) A list of replica clusters for the new + instance. Each cluster dictionary contains an id and a zone. 
+ Example: [{"id": "replica-1", "zone": "us-west1-a"}] + :type replica_cluster_id: str + :param replica_cluster_id: (deprecated) The ID for replica cluster for the new + instance. + :type replica_cluster_zone: str + :param replica_cluster_zone: (deprecated) The zone for replica cluster. + :type instance_type: enums.Instance.Type + :param instance_type: (optional) The type of the instance. + :type instance_display_name: str + :param instance_display_name: (optional) Human-readable name of the instance. + Defaults to ``instance_id``. + :type instance_labels: dict + :param instance_labels: (optional) Dictionary of labels to associate with the + instance. + :type cluster_nodes: int + :param cluster_nodes: (optional) Number of nodes for cluster. + :type cluster_storage_type: enums.StorageType + :param cluster_storage_type: (optional) The type of storage. + :type timeout: int + :param timeout: (optional) timeout (in seconds) for instance creation. + If None is not specified, Operator will wait indefinitely. + """ + cluster_storage_type = enums.StorageType(cluster_storage_type) + instance_type = enums.Instance.Type(instance_type) + + instance = Instance( + instance_id, + self._get_client(project_id=project_id), + instance_display_name, + instance_type, + instance_labels, + ) + + cluster_kwargs = dict( + cluster_id=main_cluster_id, + location_id=main_cluster_zone, + default_storage_type=cluster_storage_type, + ) + if instance_type != enums.Instance.Type.DEVELOPMENT and cluster_nodes: + cluster_kwargs["serve_nodes"] = cluster_nodes + clusters = [instance.cluster(**cluster_kwargs)] + if replica_cluster_id and replica_cluster_zone: + warnings.warn( + "The replica_cluster_id and replica_cluster_zone parameter have been deprecated." + "You should pass the replica_clusters parameter.", + DeprecationWarning, + stacklevel=2, + ) + clusters.append( + instance.cluster( + replica_cluster_id, + replica_cluster_zone, + cluster_nodes, + cluster_storage_type, + ) + ) + if replica_clusters: + for replica_cluster in replica_clusters: + if "id" in replica_cluster and "zone" in replica_cluster: + clusters.append( + instance.cluster( + replica_cluster["id"], + replica_cluster["zone"], + cluster_nodes, + cluster_storage_type, + ) + ) + operation = instance.create(clusters=clusters) + operation.result(timeout) + return instance + + @GoogleBaseHook.fallback_to_default_project_id + def update_instance( + self, + instance_id: str, + project_id: str, + instance_display_name: Optional[str] = None, + instance_type: Optional[Union[enums.Instance.Type, enum.IntEnum]] = None, + instance_labels: Optional[Dict] = None, + timeout: Optional[float] = None, + ) -> Instance: + """ + Update an existing instance. + + :type instance_id: str + :param instance_id: The ID for the existing instance. + :type project_id: str + :param project_id: Optional, Google Cloud project ID where the + BigTable exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type instance_display_name: str + :param instance_display_name: (optional) Human-readable name of the instance. + :type instance_type: enums.Instance.Type or enum.IntEnum + :param instance_type: (optional) The type of the instance. + :type instance_labels: dict + :param instance_labels: (optional) Dictionary of labels to associate with the + instance. + :type timeout: int + :param timeout: (optional) timeout (in seconds) for instance update. + If None is not specified, Operator will wait indefinitely. 
+ """ + instance_type = enums.Instance.Type(instance_type) + + instance = Instance( + instance_id=instance_id, + client=self._get_client(project_id=project_id), + display_name=instance_display_name, + instance_type=instance_type, + labels=instance_labels, + ) + + operation = instance.update() + operation.result(timeout) + + return instance + + @staticmethod + def create_table( + instance: Instance, + table_id: str, + initial_split_keys: Optional[List] = None, + column_families: Optional[Dict[str, GarbageCollectionRule]] = None, + ) -> None: + """ + Creates the specified Cloud Bigtable table. + Raises ``google.api_core.exceptions.AlreadyExists`` if the table exists. + + :type instance: Instance + :param instance: The Cloud Bigtable instance that owns the table. + :type table_id: str + :param table_id: The ID of the table to create in Cloud Bigtable. + :type initial_split_keys: list + :param initial_split_keys: (Optional) A list of row keys in bytes to use to + initially split the table. + :type column_families: dict + :param column_families: (Optional) A map of columns to create. The key is the + column_id str, and the value is a + :class:`google.cloud.bigtable.column_family.GarbageCollectionRule`. + """ + if column_families is None: + column_families = {} + if initial_split_keys is None: + initial_split_keys = [] + table = Table(table_id, instance) + table.create(initial_split_keys, column_families) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_table(self, instance_id: str, table_id: str, project_id: str) -> None: + """ + Deletes the specified table in Cloud Bigtable. + Raises google.api_core.exceptions.NotFound if the table does not exist. + + :type instance_id: str + :param instance_id: The ID of the Cloud Bigtable instance. + :type table_id: str + :param table_id: The ID of the table in Cloud Bigtable. + :type project_id: str + :param project_id: Optional, Google Cloud project ID where the + BigTable exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + """ + table = self.get_instance(instance_id=instance_id, project_id=project_id).table( + table_id=table_id + ) + table.delete() + + @staticmethod + def update_cluster(instance: Instance, cluster_id: str, nodes: int) -> None: + """ + Updates number of nodes in the specified Cloud Bigtable cluster. + Raises google.api_core.exceptions.NotFound if the cluster does not exist. + + :type instance: Instance + :param instance: The Cloud Bigtable instance that owns the cluster. + :type cluster_id: str + :param cluster_id: The ID of the cluster. + :type nodes: int + :param nodes: The desired number of nodes. + """ + cluster = Cluster(cluster_id, instance) + cluster.serve_nodes = nodes + cluster.update() + + @staticmethod + def get_column_families_for_table( + instance: Instance, table_id: str + ) -> Dict[str, ColumnFamily]: + """ + Fetches Column Families for the specified table in Cloud Bigtable. + + :type instance: Instance + :param instance: The Cloud Bigtable instance that owns the table. + :type table_id: str + :param table_id: The ID of the table in Cloud Bigtable to fetch Column Families + from. + """ + table = Table(table_id, instance) + return table.list_column_families() + + @staticmethod + def get_cluster_states_for_table( + instance: Instance, table_id: str + ) -> Dict[str, ClusterState]: + """ + Fetches Cluster States for the specified table in Cloud Bigtable. + Raises google.api_core.exceptions.NotFound if the table does not exist. 
+ + :type instance: Instance + :param instance: The Cloud Bigtable instance that owns the table. + :type table_id: str + :param table_id: The ID of the table in Cloud Bigtable to fetch Cluster States + from. + """ + table = Table(table_id, instance) + return table.get_cluster_states() diff --git a/reference/providers/google/cloud/hooks/cloud_build.py b/reference/providers/google/cloud/hooks/cloud_build.py new file mode 100644 index 0000000..9437688 --- /dev/null +++ b/reference/providers/google/cloud/hooks/cloud_build.py @@ -0,0 +1,156 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Google Cloud Build service""" + +import time +from typing import Any, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import build + +# Time to sleep between active checks of the operation results +TIME_TO_SLEEP_IN_SECONDS = 5 + + +class CloudBuildHook(GoogleBaseHook): + """ + Hook for the Google Cloud Build APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param api_version: API version used (for example v1 or v1beta1). + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + _conn = None # type: Optional[Any] + + def __init__( + self, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + + self.api_version = api_version + + def get_conn(self) -> build: + """ + Retrieves the connection to Cloud Build. + + :return: Google Cloud Build services object. 
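Closing out the Bigtable hook above, a hedged sketch of inspecting replication state per cluster with `get_cluster_states_for_table` (ids are placeholders):

from airflow.providers.google.cloud.hooks.bigtable import BigtableHook

hook = BigtableHook()
instance = hook.get_instance(instance_id="events-instance", project_id="my-project")

# Returns a mapping of cluster id to ClusterState; the provider's replication
# sensor compares each entry against the READY replication state.
states = BigtableHook.get_cluster_states_for_table(instance=instance, table_id="clicks")
for cluster_id, state in states.items():
    print(cluster_id, state.replication_state)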
+ """ + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + "cloudbuild", + self.api_version, + http=http_authorized, + cache_discovery=False, + ) + return self._conn + + @GoogleBaseHook.fallback_to_default_project_id + def create_build(self, body: dict, project_id: str) -> dict: + """ + Starts a build with the specified configuration. + + :param body: The request body. + See: https://cloud.google.com/cloud-build/docs/api/reference/rest/v1/projects.builds + :type body: dict + :param project_id: Optional, Google Cloud Project project_id where the function belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: Dict + """ + service = self.get_conn() + + # Create build + response = ( + service.projects() # pylint: disable=no-member + .builds() + .create(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) + + # Wait + operation_name = response["name"] + self._wait_for_operation_to_complete(operation_name=operation_name) + + # Get result + build_id = response["metadata"]["build"]["id"] + + result = ( + service.projects() # pylint: disable=no-member + .builds() + .get(projectId=project_id, id=build_id) + .execute(num_retries=self.num_retries) + ) + + return result + + def _wait_for_operation_to_complete(self, operation_name: str) -> None: + """ + Waits for the named operation to complete - checks status of the + asynchronous call. + + :param operation_name: The name of the operation. + :type operation_name: str + :return: The response returned by the operation. + :rtype: dict + :exception: AirflowException in case error is returned. + """ + service = self.get_conn() + while True: + operation_response = ( + # pylint: disable=no-member + service.operations() + .get(name=operation_name) + .execute(num_retries=self.num_retries) + ) + if operation_response.get("done"): + response = operation_response.get("response") + error = operation_response.get("error") + # Note, according to documentation always either response or error is + # set when "done" == True + if error: + raise AirflowException(str(error)) + return response + time.sleep(TIME_TO_SLEEP_IN_SECONDS) diff --git a/reference/providers/google/cloud/hooks/cloud_memorystore.py b/reference/providers/google/cloud/hooks/cloud_memorystore.py new file mode 100644 index 0000000..667f5a0 --- /dev/null +++ b/reference/providers/google/cloud/hooks/cloud_memorystore.py @@ -0,0 +1,1031 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Hooks for Cloud Memorystore service""" +from typing import Dict, Optional, Sequence, Tuple, Union + +from airflow import version +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core import path_template +from google.api_core.exceptions import NotFound +from google.api_core.retry import Retry +from google.cloud.memcache_v1beta2 import CloudMemcacheClient +from google.cloud.memcache_v1beta2.types import cloud_memcache +from google.cloud.redis_v1 import ( + CloudRedisClient, + FailoverInstanceRequest, + InputConfig, + Instance, + OutputConfig, +) +from google.protobuf.field_mask_pb2 import FieldMask + + +class CloudMemorystoreHook(GoogleBaseHook): + """ + Hook for Google Cloud Memorystore APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client: Optional[CloudRedisClient] = None + + def get_conn(self) -> CloudRedisClient: + """Retrieves client library object that allow access to Cloud Memorystore service.""" + if not self._client: + self._client = CloudRedisClient(credentials=self._get_credentials()) + return self._client + + @staticmethod + def _append_label(instance: Instance, key: str, val: str) -> Instance: + """ + Append labels to provided Instance type + + Labels must fit the regex ``[a-z]([-a-z0-9]*[a-z0-9])?`` (current + airflow version string follows semantic versioning spec: x.y.z). + + :param instance: The proto to append resource_label airflow + version to + :type instance: google.cloud.container_v1.types.Cluster + :param key: The key label + :type key: str + :param val: + :type val: str + :return: The cluster proto updated with new label + """ + val = val.replace(".", "-").replace("+", "-") + instance.labels.update({key: val}) + return instance + + @GoogleBaseHook.fallback_to_default_project_id + def create_instance( + self, + location: str, + instance_id: str, + instance: Union[Dict, Instance], + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Creates a Redis instance based on the specified tier and memory size. 
+ + By default, the instance is accessible from the project's `default network + `__. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: Required. The logical name of the Redis instance in the customer project with the + following restrictions: + + - Must contain only lowercase letters, numbers, and hyphens. + - Must start with a letter. + - Must be between 1-40 characters. + - Must end with a number or a letter. + - Must be unique within the customer project / location + :type instance_id: str + :param instance: Required. A Redis [Instance] resource + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.Instance` + :type instance: Union[Dict, google.cloud.redis_v1.types.Instance] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + if isinstance(instance, dict): + instance = Instance(**instance) + elif not isinstance(instance, Instance): + raise AirflowException( + "instance is not instance of Instance type or python dict" + ) + + parent = f"projects/{project_id}/locations/{location}" + instance_name = ( + f"projects/{project_id}/locations/{location}/instances/{instance_id}" + ) + try: + self.log.info("Fetching instance: %s", instance_name) + instance = client.get_instance( + request={"name": instance_name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Instance exists. Skipping creation.") + return instance + except NotFound: + self.log.info("Instance not exists.") + + self._append_label(instance, "airflow-version", "v" + version.version) + + result = client.create_instance( + request={ + "parent": parent, + "instance_id": instance_id, + "instance": instance, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result.result() + self.log.info("Instance created.") + return client.get_instance( + request={"name": instance_name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_instance( + self, + location: str, + instance: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Deletes a specific Redis instance. Instance stops serving and data is deleted. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. 
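A hedged end-to-end sketch of creating a basic Redis instance through the hook (location, ids and sizes are assumptions):

from google.cloud.redis_v1 import Instance

from airflow.providers.google.cloud.hooks.cloud_memorystore import CloudMemorystoreHook

hook = CloudMemorystoreHook()
instance = hook.create_instance(
    location="europe-west1",      # placeholder
    instance_id="airflow-cache",  # placeholder
    instance=Instance(tier=Instance.Tier.BASIC, memory_size_gb=1),
    project_id="my-project",      # placeholder
)
# If the instance already exists the hook logs "Instance exists." and returns
# the existing resource instead of failing.
print(instance.host, instance.port)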
+ :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/instances/{instance}" + self.log.info("Fetching Instance: %s", name) + instance = client.get_instance( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + if not instance: + return + + self.log.info("Deleting Instance: %s", name) + result = client.delete_instance( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result.result() + self.log.info("Instance deleted: %s", name) + + @GoogleBaseHook.fallback_to_default_project_id + def export_instance( + self, + location: str, + instance: str, + output_config: Union[Dict, OutputConfig], + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Export Redis instance data into a Redis RDB format file in Cloud Storage. + + Redis will continue serving during this operation. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param output_config: Required. Specify data to be exported. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.OutputConfig` + :type output_config: Union[Dict, google.cloud.redis_v1.types.OutputConfig] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/instances/{instance}" + self.log.info("Exporting Instance: %s", name) + result = client.export_instance( + request={"name": name, "output_config": output_config}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result.result() + self.log.info("Instance exported: %s", name) + + @GoogleBaseHook.fallback_to_default_project_id + def failover_instance( + self, + location: str, + instance: str, + data_protection_mode: FailoverInstanceRequest.DataProtectionMode, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Initiates a failover of the master node to current replica node for a specific STANDARD tier Cloud + Memorystore for Redis instance. 
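A hedged sketch of `export_instance`, for example snapshotting to Cloud Storage before the failover documented next (bucket and ids are placeholders):

from airflow.providers.google.cloud.hooks.cloud_memorystore import CloudMemorystoreHook

hook = CloudMemorystoreHook()
hook.export_instance(
    location="europe-west1",   # placeholder
    instance="airflow-cache",  # placeholder
    output_config={"gcs_destination": {"uri": "gs://my-bucket/airflow-cache.rdb"}},
    project_id="my-project",   # placeholder
)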
+ + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param data_protection_mode: Optional. Available data protection modes that the user can choose. If + it's unspecified, data protection mode will be LIMITED_DATA_LOSS by default. + :type data_protection_mode: google.cloud.redis_v1.gapic.enums.FailoverInstanceRequest + .DataProtectionMode + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/instances/{instance}" + self.log.info("Failovering Instance: %s", name) + + result = client.failover_instance( + request={"name": name, "data_protection_mode": data_protection_mode}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result.result() + self.log.info("Instance failovered: %s", name) + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance( + self, + location: str, + instance: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Gets the details of a specific Redis instance. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/instances/{instance}" + result = client.get_instance( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Fetched Instance: %s", name) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def import_instance( + self, + location: str, + instance: str, + input_config: Union[Dict, InputConfig], + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Import a Redis RDB snapshot file from Cloud Storage into a Redis instance. + + Redis may stop serving during this operation. 
Instance state will be IMPORTING for entire operation. + When complete, the instance will contain only data from the imported file. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param input_config: Required. Specify data to be imported. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.InputConfig` + :type input_config: Union[Dict, google.cloud.redis_v1.types.InputConfig] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/instances/{instance}" + self.log.info("Importing Instance: %s", name) + result = client.import_instance( + request={"name": name, "input_config": input_config}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result.result() + self.log.info("Instance imported: %s", name) + + @GoogleBaseHook.fallback_to_default_project_id + def list_instances( + self, + location: str, + page_size: int, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Lists all Redis instances owned by a project in either the specified location (region) or all + locations. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + + If it is specified as ``-`` (wildcard), then all regions available to the project are + queried, and the results are aggregated. + :type location: str + :param page_size: The maximum number of resources contained in the underlying API response. If page + streaming is performed per- resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number of resources in a page. + :type page_size: int + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
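And the matching restore, a hedged sketch of `import_instance` (the GCS URI and ids are placeholders; note the serving interruption described above):

from airflow.providers.google.cloud.hooks.cloud_memorystore import CloudMemorystoreHook

hook = CloudMemorystoreHook()
hook.import_instance(
    location="europe-west1",   # placeholder
    instance="airflow-cache",  # placeholder
    input_config={"gcs_source": {"uri": "gs://my-bucket/airflow-cache.rdb"}},
    project_id="my-project",   # placeholder
)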
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}" + result = client.list_instances( + request={"parent": parent, "page_size": page_size}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Fetched instances") + return result + + @GoogleBaseHook.fallback_to_default_project_id + def update_instance( + self, + update_mask: Union[Dict, FieldMask], + instance: Union[Dict, Instance], + project_id: str, + location: Optional[str] = None, + instance_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Updates the metadata and configuration of a specific Redis instance. + + :param update_mask: Required. Mask of fields to update. At least one path must be supplied in this + field. The elements of the repeated paths field may only include these fields from ``Instance``: + + - ``displayName`` + - ``labels`` + - ``memorySizeGb`` + - ``redisConfig`` + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.FieldMask` + :type update_mask: Union[Dict, google.cloud.redis_v1.types.FieldMask] + :param instance: Required. Update description. Only fields specified in ``update_mask`` are updated. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.Instance` + :type instance: Union[Dict, google.cloud.redis_v1.types.Instance] + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Redis instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + + if isinstance(instance, dict): + instance = Instance(**instance) + elif not isinstance(instance, Instance): + raise AirflowException( + "instance is not instance of Instance type or python dict" + ) + + if location and instance_id: + name = f"projects/{project_id}/locations/{location}/instances/{instance_id}" + instance.name = name + + self.log.info("Updating instances: %s", instance.name) + result = client.update_instance( + request={"update_mask": update_mask, "instance": instance}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result.result() + self.log.info("Instance updated: %s", instance.name) + + +class CloudMemorystoreMemcachedHook(GoogleBaseHook): + """ + Hook for Google Cloud Memorystore for Memcached service APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param gcp_conn_id: The connection ID to use when fetching connection info. 
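A sketch of the partial-update pattern used by `update_instance`: the FieldMask names only the fields being changed (the new size is an assumption):

from google.cloud.redis_v1 import Instance
from google.protobuf.field_mask_pb2 import FieldMask

from airflow.providers.google.cloud.hooks.cloud_memorystore import CloudMemorystoreHook

hook = CloudMemorystoreHook()
hook.update_instance(
    update_mask=FieldMask(paths=["memory_size_gb"]),
    instance=Instance(memory_size_gb=2),
    location="europe-west1",      # placeholder
    instance_id="airflow-cache",  # placeholder
    project_id="my-project",      # placeholder
)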
+ :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client: Optional[CloudMemcacheClient] = None + + def get_conn( + self, + ): + """Retrieves client library object that allow access to Cloud Memorystore Memcached service.""" + if not self._client: + self._client = CloudMemcacheClient(credentials=self._get_credentials()) + return self._client + + @staticmethod + def _append_label( + instance: cloud_memcache.Instance, key: str, val: str + ) -> cloud_memcache.Instance: + """ + Append labels to provided Instance type + + Labels must fit the regex ``[a-z]([-a-z0-9]*[a-z0-9])?`` (current + airflow version string follows semantic versioning spec: x.y.z). + + :param instance: The proto to append resource_label airflow + version to + :type instance: google.cloud.memcache_v1beta2.types.cloud_memcache.Instance + :param key: The key label + :type key: str + :param val: + :type val: str + :return: The cluster proto updated with new label + """ + val = val.replace(".", "-").replace("+", "-") + instance.labels.update({key: val}) + return instance + + @GoogleBaseHook.fallback_to_default_project_id + def apply_parameters( + self, + node_ids: Sequence[str], + apply_all: bool, + project_id: str, + location: str, + instance_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Will update current set of Parameters to the set of specified nodes of the Memcached Instance. + + :param node_ids: Nodes to which we should apply the instance-level parameter group. + :type node_ids: Sequence[str] + :param apply_all: Whether to apply instance-level parameter group to all nodes. If set to true, + will explicitly restrict users from specifying any nodes, and apply parameter group updates + to all nodes within the instance. + :type apply_all: bool + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Memcached instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. 
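A hedged sketch of `apply_parameters`, pushing pending parameter changes to every node of a Memcached instance (ids are placeholders):

from airflow.providers.google.cloud.hooks.cloud_memorystore import (
    CloudMemorystoreMemcachedHook,
)

hook = CloudMemorystoreMemcachedHook()
hook.apply_parameters(
    node_ids=[],                      # ignored when apply_all=True
    apply_all=True,
    project_id="my-project",          # placeholder
    location="europe-west1",          # placeholder
    instance_id="airflow-memcached",  # placeholder
)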
+ :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + metadata = metadata or () + name = CloudMemcacheClient.instance_path(project_id, location, instance_id) + + self.log.info("Applying update to instance: %s", instance_id) + result = client.apply_parameters( + name=name, + node_ids=node_ids, + apply_all=apply_all, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result.result() + self.log.info("Instance updated: %s", instance_id) + + @GoogleBaseHook.fallback_to_default_project_id + def create_instance( + self, + location: str, + instance_id: str, + instance: Union[Dict, cloud_memcache.Instance], + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Creates a Memcached instance based on the specified tier and memory size. + + By default, the instance is accessible from the project's `default network + `__. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: Required. The logical name of the Memcached instance in the customer project + with the following restrictions: + + - Must contain only lowercase letters, numbers, and hyphens. + - Must start with a letter. + - Must be between 1-40 characters. + - Must end with a number or a letter. + - Must be unique within the customer project / location + :type instance_id: str + :param instance: Required. A Memcached [Instance] resource + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.Instance` + :type instance: Union[Dict, google.cloud.memcache_v1beta2.types.cloud_memcache.Instance] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + metadata = metadata or () + parent = path_template.expand( + "projects/{project}/locations/{location}", + project=project_id, + location=location, + ) + instance_name = CloudMemcacheClient.instance_path( + project_id, location, instance_id + ) + try: + instance = client.get_instance( + name=instance_name, retry=retry, timeout=timeout, metadata=metadata + ) + self.log.info("Instance exists. 
Skipping creation.") + return instance + except NotFound: + self.log.info("Instance not exists.") + + if isinstance(instance, dict): + instance = cloud_memcache.Instance(instance) + elif not isinstance(instance, cloud_memcache.Instance): + raise AirflowException( + "instance is not instance of Instance type or python dict" + ) + + self._append_label(instance, "airflow-version", "v" + version.version) + + result = client.create_instance( + parent=parent, + instance_id=instance_id, + resource=instance, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result.result() + self.log.info("Instance created.") + return client.get_instance( + name=instance_name, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_instance( + self, + location: str, + instance: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Deletes a specific Memcached instance. Instance stops serving and data is deleted. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Memcached instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + metadata = metadata or () + name = CloudMemcacheClient.instance_path(project_id, location, instance) + self.log.info("Fetching Instance: %s", name) + instance = client.get_instance( + name=name, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + if not instance: + return + + self.log.info("Deleting Instance: %s", name) + result = client.delete_instance( + name=name, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result.result() + self.log.info("Instance deleted: %s", name) + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance( + self, + location: str, + instance: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Gets the details of a specific Memcached instance. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Memcached instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + metadata = metadata or () + name = CloudMemcacheClient.instance_path(project_id, location, instance) + result = client.get_instance( + name=name, retry=retry, timeout=timeout, metadata=metadata or () + ) + self.log.info("Fetched Instance: %s", name) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_instances( + self, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Lists all Memcached instances owned by a project in either the specified location (region) or all + locations. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + + If it is specified as ``-`` (wildcard), then all regions available to the project are + queried, and the results are aggregated. + :type location: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + metadata = metadata or () + parent = path_template.expand( + "projects/{project}/locations/{location}", + project=project_id, + location=location, + ) + result = client.list_instances( + parent=parent, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Fetched instances") + return result + + @GoogleBaseHook.fallback_to_default_project_id + def update_instance( + self, + update_mask: Union[Dict, cloud_memcache.field_mask.FieldMask], + instance: Union[Dict, cloud_memcache.Instance], + project_id: str, + location: Optional[str] = None, + instance_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Updates the metadata and configuration of a specific Memcached instance. + + :param update_mask: Required. Mask of fields to update. At least one path must be supplied in this + field. The elements of the repeated paths field may only include these fields from ``Instance``: + + - ``displayName`` + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.field_mask.FieldMask` + :type update_mask: + Union[Dict, google.cloud.memcache_v1beta2.types.cloud_memcache.field_mask.FieldMask] + :param instance: Required. Update description. Only fields specified in ``update_mask`` are updated. 
+ + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.Instance` + :type instance: Union[Dict, google.cloud.memcache_v1beta2.types.cloud_memcache.Instance] + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Memcached instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + metadata = metadata or () + + if isinstance(instance, dict): + instance = cloud_memcache.Instance(instance) + elif not isinstance(instance, cloud_memcache.Instance): + raise AirflowException( + "instance is not instance of Instance type or python dict" + ) + + if location and instance_id: + name = CloudMemcacheClient.instance_path(project_id, location, instance_id) + instance.name = name + + self.log.info("Updating instances: %s", instance.name) + result = client.update_instance( + update_mask=update_mask, + resource=instance, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result.result() + self.log.info("Instance updated: %s", instance.name) + + @GoogleBaseHook.fallback_to_default_project_id + def update_parameters( + self, + update_mask: Union[Dict, cloud_memcache.field_mask.FieldMask], + parameters: Union[Dict, cloud_memcache.MemcacheParameters], + project_id: str, + location: str, + instance_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Updates the defined Memcached Parameters for an existing Instance. This method only stages the + parameters, it must be followed by apply_parameters to apply the parameters to nodes of + the Memcached Instance. + + :param update_mask: Required. Mask of fields to update. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.field_mask.FieldMask` + :type update_mask: + Union[Dict, google.cloud.memcache_v1beta2.types.cloud_memcache.field_mask.FieldMask] + :param parameters: The parameters to apply to the instance. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.MemcacheParameters` + :type parameters: Union[Dict, google.cloud.memcache_v1beta2.types.cloud_memcache.MemcacheParameters] + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Memcached instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. 
+        :type project_id: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        client = self.get_conn()
+        metadata = metadata or ()
+
+        if isinstance(parameters, dict):
+            parameters = cloud_memcache.MemcacheParameters(parameters)
+        elif not isinstance(parameters, cloud_memcache.MemcacheParameters):
+            raise AirflowException(
+                "parameters is not an instance of MemcacheParameters type or python dict"
+            )
+
+        name = CloudMemcacheClient.instance_path(project_id, location, instance_id)
+        self.log.info("Staging update to instance: %s", instance_id)
+        result = client.update_parameters(
+            name=name,
+            update_mask=update_mask,
+            parameters=parameters,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
+        result.result()
+        self.log.info("Update staged for instance: %s", instance_id)
diff --git a/reference/providers/google/cloud/hooks/cloud_sql.py b/reference/providers/google/cloud/hooks/cloud_sql.py
new file mode 100644
index 0000000..522593c
--- /dev/null
+++ b/reference/providers/google/cloud/hooks/cloud_sql.py
@@ -0,0 +1,1085 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
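# Editor's sketch of the two-step flow of the Memcached hook defined above:
# update_parameters only stages changes, apply_parameters pushes them to nodes.
# Assumptions: the class matches the upstream CloudMemorystoreMemcachedHook and
# import path; project, location, and instance values are placeholders.
from airflow.providers.google.cloud.hooks.cloud_memorystore import (
    CloudMemorystoreMemcachedHook,
)

hook = CloudMemorystoreMemcachedHook(gcp_conn_id="google_cloud_default")

# Stage the new Memcached parameters on the instance ...
hook.update_parameters(
    update_mask={"paths": ["params"]},
    parameters={"params": {"idle-timeout": "300"}},
    project_id="example-project",
    location="europe-west1",
    instance_id="example-memcached",
)

# ... then apply the staged parameters to every node.
hook.apply_parameters(
    node_ids=[],
    apply_all=True,
    project_id="example-project",
    location="europe-west1",
    instance_id="example-memcached",
)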
+# pylint: disable=too-many-lines
+"""This module contains a Google Cloud SQL Hook."""
+
+import errno
+import json
+import os
+import os.path
+import platform
+import random
+import re
+import shutil
+import socket
+import string
+import subprocess
+import time
+import uuid
+from pathlib import Path
+from subprocess import PIPE, Popen
+from typing import Any, Dict, List, Optional, Sequence, Union
+from urllib.parse import quote_plus
+
+import requests
+from airflow.exceptions import AirflowException
+
+# Number of retries - used by googleapiclient method calls to perform retries
+# For requests that are "retriable"
+from airflow.hooks.base import BaseHook
+from airflow.models import Connection
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.mysql.hooks.mysql import MySqlHook
+from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import provide_session
+from googleapiclient.discovery import Resource, build
+from googleapiclient.errors import HttpError
+from sqlalchemy.orm import Session
+
+UNIX_PATH_MAX = 108
+
+# Time to sleep between active checks of the operation results
+TIME_TO_SLEEP_IN_SECONDS = 20
+
+
+class CloudSqlOperationStatus:
+    """Helper class with operation statuses."""
+
+    PENDING = "PENDING"
+    RUNNING = "RUNNING"
+    DONE = "DONE"
+    UNKNOWN = "UNKNOWN"
+
+
+class CloudSQLHook(GoogleBaseHook):
+    """
+    Hook for Google Cloud SQL APIs.
+
+    All the methods in the hook where project_id is used must be called with
+    keyword arguments rather than positional.
+    """
+
+    conn_name_attr = "gcp_conn_id"
+    default_conn_name = "google_cloud_default"
+    conn_type = "gcpcloudsql"
+    hook_name = "Google Cloud SQL"
+
+    def __init__(
+        self,
+        api_version: str,
+        gcp_conn_id: str = default_conn_name,
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+    ) -> None:
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            delegate_to=delegate_to,
+            impersonation_chain=impersonation_chain,
+        )
+        self.api_version = api_version
+        self._conn = None
+
+    def get_conn(self) -> Resource:
+        """
+        Retrieves connection to Cloud SQL.
+
+        :return: Google Cloud SQL services object.
+        :rtype: dict
+        """
+        if not self._conn:
+            http_authorized = self._authorize()
+            self._conn = build(
+                "sqladmin",
+                self.api_version,
+                http=http_authorized,
+                cache_discovery=False,
+            )
+        return self._conn
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_instance(self, instance: str, project_id: str) -> dict:
+        """
+        Retrieves a resource containing information about a Cloud SQL instance.
+
+        :param instance: Database instance ID. This does not include the project ID.
+        :type instance: str
+        :param project_id: Project ID of the project that contains the instance. If set
+            to None or missing, the default project_id from the Google Cloud connection is used.
+        :type project_id: str
+        :return: A Cloud SQL instance resource.
+        :rtype: dict
+        """
+        return (
+            self.get_conn()  # noqa # pylint: disable=no-member
+            .instances()
+            .get(project=project_id, instance=instance)
+            .execute(num_retries=self.num_retries)
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    @GoogleBaseHook.operation_in_progress_retry()
+    def create_instance(self, body: Dict, project_id: str) -> None:
+        """
+        Creates a new Cloud SQL instance.
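As a rough illustration, a minimal request body for this call could look like the
following (field names follow the public Cloud SQL Admin API; every value here is a
placeholder, not a recommendation):

    body = {
        "name": "example-instance",
        "region": "europe-west1",
        "databaseVersion": "POSTGRES_13",
        "settings": {"tier": "db-custom-1-3840"},
    }
    hook = CloudSQLHook(api_version="v1beta4")
    hook.create_instance(body=body, project_id="example-project")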
+ + :param body: Body required by the Cloud SQL insert API, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances/insert#request-body. + :type body: dict + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + response = ( + self.get_conn() # noqa # pylint: disable=no-member + .instances() + .insert(project=project_id, body=body) + .execute(num_retries=self.num_retries) + ) + operation_name = response["name"] + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name + ) + + @GoogleBaseHook.fallback_to_default_project_id + @GoogleBaseHook.operation_in_progress_retry() + def patch_instance(self, body: dict, instance: str, project_id: str) -> None: + """ + Updates settings of a Cloud SQL instance. + + Caution: This is not a partial update, so you must include values for + all the settings that you want to retain. + + :param body: Body required by the Cloud SQL patch API, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances/patch#request-body. + :type body: dict + :param instance: Cloud SQL instance ID. This does not include the project ID. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + response = ( + self.get_conn() # noqa # pylint: disable=no-member + .instances() + .patch(project=project_id, instance=instance, body=body) + .execute(num_retries=self.num_retries) + ) + operation_name = response["name"] + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name + ) + + @GoogleBaseHook.fallback_to_default_project_id + @GoogleBaseHook.operation_in_progress_retry() + def delete_instance(self, instance: str, project_id: str) -> None: + """ + Deletes a Cloud SQL instance. + + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param instance: Cloud SQL instance ID. This does not include the project ID. + :type instance: str + :return: None + """ + response = ( + self.get_conn() # noqa # pylint: disable=no-member + .instances() + .delete(project=project_id, instance=instance) + .execute(num_retries=self.num_retries) + ) + operation_name = response["name"] + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_database(self, instance: str, database: str, project_id: str) -> dict: + """ + Retrieves a database resource from a Cloud SQL instance. + + :param instance: Database instance ID. This does not include the project ID. + :type instance: str + :param database: Name of the database in the instance. + :type database: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: A Cloud SQL database resource, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/databases#resource. 
+ :rtype: dict + """ + return ( + self.get_conn() # noqa # pylint: disable=no-member + .databases() + .get(project=project_id, instance=instance, database=database) + .execute(num_retries=self.num_retries) + ) + + @GoogleBaseHook.fallback_to_default_project_id + @GoogleBaseHook.operation_in_progress_retry() + def create_database(self, instance: str, body: Dict, project_id: str) -> None: + """ + Creates a new database inside a Cloud SQL instance. + + :param instance: Database instance ID. This does not include the project ID. + :type instance: str + :param body: The request body, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/databases/insert#request-body. + :type body: dict + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + response = ( + self.get_conn() # noqa # pylint: disable=no-member + .databases() + .insert(project=project_id, instance=instance, body=body) + .execute(num_retries=self.num_retries) + ) + operation_name = response["name"] + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name + ) + + @GoogleBaseHook.fallback_to_default_project_id + @GoogleBaseHook.operation_in_progress_retry() + def patch_database( + self, + instance: str, + database: str, + body: Dict, + project_id: str, + ) -> None: + """ + Updates a database resource inside a Cloud SQL instance. + + This method supports patch semantics. + See https://cloud.google.com/sql/docs/mysql/admin-api/how-tos/performance#patch. + + :param instance: Database instance ID. This does not include the project ID. + :type instance: str + :param database: Name of the database to be updated in the instance. + :type database: str + :param body: The request body, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/databases/insert#request-body. + :type body: dict + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + response = ( + self.get_conn() # noqa # pylint: disable=no-member + .databases() + .patch(project=project_id, instance=instance, database=database, body=body) + .execute(num_retries=self.num_retries) + ) + operation_name = response["name"] + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name + ) + + @GoogleBaseHook.fallback_to_default_project_id + @GoogleBaseHook.operation_in_progress_retry() + def delete_database(self, instance: str, database: str, project_id: str) -> None: + """ + Deletes a database from a Cloud SQL instance. + + :param instance: Database instance ID. This does not include the project ID. + :type instance: str + :param database: Name of the database to be deleted in the instance. + :type database: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. 
+ :type project_id: str + :return: None + """ + response = ( + self.get_conn() # noqa # pylint: disable=no-member + .databases() + .delete(project=project_id, instance=instance, database=database) + .execute(num_retries=self.num_retries) + ) + operation_name = response["name"] + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name + ) + + @GoogleBaseHook.fallback_to_default_project_id + @GoogleBaseHook.operation_in_progress_retry() + def export_instance(self, instance: str, body: Dict, project_id: str) -> None: + """ + Exports data from a Cloud SQL instance to a Cloud Storage bucket as a SQL dump + or CSV file. + + :param instance: Database instance ID of the Cloud SQL instance. This does not include the + project ID. + :type instance: str + :param body: The request body, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances/export#request-body + :type body: dict + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + response = ( + self.get_conn() # noqa # pylint: disable=no-member + .instances() + .export(project=project_id, instance=instance, body=body) + .execute(num_retries=self.num_retries) + ) + operation_name = response["name"] + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name + ) + + @GoogleBaseHook.fallback_to_default_project_id + def import_instance(self, instance: str, body: Dict, project_id: str) -> None: + """ + Imports data into a Cloud SQL instance from a SQL dump or CSV file in + Cloud Storage. + + :param instance: Database instance ID. This does not include the + project ID. + :type instance: str + :param body: The request body, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances/export#request-body + :type body: dict + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + try: + response = ( + self.get_conn() # noqa # pylint: disable=no-member + .instances() + .import_(project=project_id, instance=instance, body=body) + .execute(num_retries=self.num_retries) + ) + operation_name = response["name"] + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name + ) + except HttpError as ex: + raise AirflowException( + f"Importing instance {instance} failed: {ex.content}" + ) + + def _wait_for_operation_to_complete( + self, project_id: str, operation_name: str + ) -> None: + """ + Waits for the named operation to complete - checks status of the + asynchronous call. + + :param project_id: Project ID of the project that contains the instance. + :type project_id: str + :param operation_name: Name of the operation. 
+ :type operation_name: str + :return: None + """ + service = self.get_conn() + while True: + operation_response = ( + service.operations() # noqa # pylint: disable=no-member + .get(project=project_id, operation=operation_name) + .execute(num_retries=self.num_retries) + ) + if operation_response.get("status") == CloudSqlOperationStatus.DONE: + error = operation_response.get("error") + if error: + # Extracting the errors list as string and trimming square braces + error_msg = str(error.get("errors"))[1:-1] + raise AirflowException(error_msg) + # No meaningful info to return from the response in case of success + return + time.sleep(TIME_TO_SLEEP_IN_SECONDS) + + +CLOUD_SQL_PROXY_DOWNLOAD_URL = "https://dl.google.com/cloudsql/cloud_sql_proxy.{}.{}" +CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL = ( + "https://storage.googleapis.com/cloudsql-proxy/{}/cloud_sql_proxy.{}.{}" +) + +GCP_CREDENTIALS_KEY_PATH = "extra__google_cloud_platform__key_path" +GCP_CREDENTIALS_KEYFILE_DICT = "extra__google_cloud_platform__keyfile_dict" + + +class CloudSqlProxyRunner(LoggingMixin): + """ + Downloads and runs cloud-sql-proxy as subprocess of the Python process. + + The cloud-sql-proxy needs to be downloaded and started before we can connect + to the Google Cloud SQL instance via database connection. It establishes + secure tunnel connection to the database. It authorizes using the + Google Cloud credentials that are passed by the configuration. + + More details about the proxy can be found here: + https://cloud.google.com/sql/docs/mysql/sql-proxy + + :param path_prefix: Unique path prefix where proxy will be downloaded and + directories created for unix sockets. + :type path_prefix: str + :param instance_specification: Specification of the instance to connect the + proxy to. It should be specified in the form that is described in + https://cloud.google.com/sql/docs/mysql/sql-proxy#multiple-instances in + -instances parameter (typically in the form of ``::`` + for UNIX socket connections and in the form of + ``::=tcp:`` for TCP connections. + :type instance_specification: str + :param gcp_conn_id: Id of Google Cloud connection to use for + authentication + :type gcp_conn_id: str + :param project_id: Optional id of the Google Cloud project to connect to - it overwrites + default project id taken from the Google Cloud connection. + :type project_id: str + :param sql_proxy_version: Specific version of SQL proxy to download + (for example 'v1.13'). By default latest version is downloaded. + :type sql_proxy_version: str + :param sql_proxy_binary_path: If specified, then proxy will be + used from the path specified rather than dynamically generated. This means + that if the binary is not present in that path it will also be downloaded. 
+ :type sql_proxy_binary_path: str + """ + + def __init__( + self, + path_prefix: str, + instance_specification: str, + gcp_conn_id: str = "google_cloud_default", + project_id: Optional[str] = None, + sql_proxy_version: Optional[str] = None, + sql_proxy_binary_path: Optional[str] = None, + ) -> None: + super().__init__() + self.path_prefix = path_prefix + if not self.path_prefix: + raise AirflowException("The path_prefix must not be empty!") + self.sql_proxy_was_downloaded = False + self.sql_proxy_version = sql_proxy_version + self.download_sql_proxy_dir = None + self.sql_proxy_process = None # type: Optional[Popen] + self.instance_specification = instance_specification + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.command_line_parameters = [] # type: List[str] + self.cloud_sql_proxy_socket_directory = self.path_prefix + self.sql_proxy_path = ( + sql_proxy_binary_path + if sql_proxy_binary_path + else self.path_prefix + "_cloud_sql_proxy" + ) + self.credentials_path = self.path_prefix + "_credentials.json" + self._build_command_line_parameters() + + def _build_command_line_parameters(self) -> None: + self.command_line_parameters.extend( + ["-dir", self.cloud_sql_proxy_socket_directory] + ) + self.command_line_parameters.extend(["-instances", self.instance_specification]) + + @staticmethod + def _is_os_64bit() -> bool: + return platform.machine().endswith("64") + + def _download_sql_proxy_if_needed(self) -> None: + if os.path.isfile(self.sql_proxy_path): + self.log.info("cloud-sql-proxy is already present") + return + system = platform.system().lower() + processor = "amd64" if CloudSqlProxyRunner._is_os_64bit() else "386" + if not self.sql_proxy_version: + download_url = CLOUD_SQL_PROXY_DOWNLOAD_URL.format(system, processor) + else: + download_url = CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL.format( + self.sql_proxy_version, system, processor + ) + proxy_path_tmp = self.sql_proxy_path + ".tmp" + self.log.info( + "Downloading cloud_sql_proxy from %s to %s", download_url, proxy_path_tmp + ) + response = requests.get(download_url, allow_redirects=True) + # Downloading to .tmp file first to avoid case where partially downloaded + # binary is used by parallel operator which uses the same fixed binary path + with open(proxy_path_tmp, "wb") as file: + file.write(response.content) + if response.status_code != 200: + raise AirflowException( + "The cloud-sql-proxy could not be downloaded. Status code = {}. 
" + "Reason = {}".format(response.status_code, response.reason) + ) + + self.log.info( + "Moving sql_proxy binary from %s to %s", proxy_path_tmp, self.sql_proxy_path + ) + shutil.move(proxy_path_tmp, self.sql_proxy_path) + os.chmod(self.sql_proxy_path, 0o744) # Set executable bit + self.sql_proxy_was_downloaded = True + + @provide_session + def _get_credential_parameters(self, session: Session) -> List[str]: + connection = ( + session.query(Connection) + .filter(Connection.conn_id == self.gcp_conn_id) + .first() + ) + session.expunge_all() + if connection.extra_dejson.get(GCP_CREDENTIALS_KEY_PATH): + credential_params = [ + "-credential_file", + connection.extra_dejson[GCP_CREDENTIALS_KEY_PATH], + ] + elif connection.extra_dejson.get(GCP_CREDENTIALS_KEYFILE_DICT): + credential_file_content = json.loads( + connection.extra_dejson[GCP_CREDENTIALS_KEYFILE_DICT] + ) + self.log.info("Saving credentials to %s", self.credentials_path) + with open(self.credentials_path, "w") as file: + json.dump(credential_file_content, file) + credential_params = ["-credential_file", self.credentials_path] + else: + self.log.info( + "The credentials are not supplied by neither key_path nor " + "keyfile_dict of the gcp connection %s. Falling back to " + "default activated account", + self.gcp_conn_id, + ) + credential_params = [] + + if not self.instance_specification: + project_id = connection.extra_dejson.get( + "extra__google_cloud_platform__project" + ) + if self.project_id: + project_id = self.project_id + if not project_id: + raise AirflowException( + "For forwarding all instances, the project id " + "for Google Cloud should be provided either " + "by project_id extra in the Google Cloud connection or by " + "project_id provided in the operator." + ) + credential_params.extend(["-projects", project_id]) + return credential_params + + def start_proxy(self) -> None: + """ + Starts Cloud SQL Proxy. + + You have to remember to stop the proxy if you started it! + """ + self._download_sql_proxy_if_needed() + if self.sql_proxy_process: + raise AirflowException( + f"The sql proxy is already running: {self.sql_proxy_process}" + ) + else: + command_to_run = [self.sql_proxy_path] + command_to_run.extend(self.command_line_parameters) + self.log.info( + "Creating directory %s", self.cloud_sql_proxy_socket_directory + ) + Path(self.cloud_sql_proxy_socket_directory).mkdir( + parents=True, exist_ok=True + ) + command_to_run.extend( + self._get_credential_parameters() + ) # pylint: disable=no-value-for-parameter + self.log.info("Running the command: `%s`", " ".join(command_to_run)) + self.sql_proxy_process = Popen( + command_to_run, stdin=PIPE, stdout=PIPE, stderr=PIPE + ) + self.log.info("The pid of cloud_sql_proxy: %s", self.sql_proxy_process.pid) + while True: + line = ( + self.sql_proxy_process.stderr.readline().decode("utf-8") + if self.sql_proxy_process.stderr + else "" + ) + return_code = self.sql_proxy_process.poll() + if line == "" and return_code is not None: + self.sql_proxy_process = None + raise AirflowException( + f"The cloud_sql_proxy finished early with return code {return_code}!" + ) + if line != "": + self.log.info(line) + if "googleapi: Error" in line or "invalid instance name:" in line: + self.stop_proxy() + raise AirflowException( + f"Error when starting the cloud_sql_proxy {line}!" + ) + if "Ready for new connections" in line: + return + + def stop_proxy(self) -> None: + """ + Stops running proxy. + + You should stop the proxy after you stop using it. 
+ """ + if not self.sql_proxy_process: + raise AirflowException("The sql proxy is not started yet") + else: + self.log.info( + "Stopping the cloud_sql_proxy pid: %s", self.sql_proxy_process.pid + ) + self.sql_proxy_process.kill() + self.sql_proxy_process = None + # Cleanup! + self.log.info( + "Removing the socket directory: %s", self.cloud_sql_proxy_socket_directory + ) + shutil.rmtree(self.cloud_sql_proxy_socket_directory, ignore_errors=True) + if self.sql_proxy_was_downloaded: + self.log.info("Removing downloaded proxy: %s", self.sql_proxy_path) + # Silently ignore if the file has already been removed (concurrency) + try: + os.remove(self.sql_proxy_path) + except OSError as e: + if e.errno != errno.ENOENT: + raise + else: + self.log.info( + "Skipped removing proxy - it was not downloaded: %s", + self.sql_proxy_path, + ) + if os.path.isfile(self.credentials_path): + self.log.info( + "Removing generated credentials file %s", self.credentials_path + ) + # Here file cannot be delete by concurrent task (each task has its own copy) + os.remove(self.credentials_path) + + def get_proxy_version(self) -> Optional[str]: + """Returns version of the Cloud SQL Proxy.""" + self._download_sql_proxy_if_needed() + command_to_run = [self.sql_proxy_path] + command_to_run.extend(["--version"]) + command_to_run.extend( + self._get_credential_parameters() + ) # pylint: disable=no-value-for-parameter + result = subprocess.check_output(command_to_run).decode("utf-8") + pattern = re.compile("^.*[V|v]ersion ([^;]*);.*$") + matched = pattern.match(result) + if matched: + return matched.group(1) + else: + return None + + def get_socket_path(self) -> str: + """ + Retrieves UNIX socket path used by Cloud SQL Proxy. + + :return: The dynamically generated path for the socket created by the proxy. + :rtype: str + """ + return self.cloud_sql_proxy_socket_directory + "/" + self.instance_specification + + +CONNECTION_URIS = { + "postgres": { + "proxy": { + "tcp": "postgresql://{user}:{password}@127.0.0.1:{proxy_port}/{database}", + "socket": "postgresql://{user}:{password}@{socket_path}/{database}", + }, + "public": { + "ssl": "postgresql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "sslmode=verify-ca&" + "sslcert={client_cert_file}&" + "sslkey={client_key_file}&" + "sslrootcert={server_ca_file}", + "non-ssl": "postgresql://{user}:{password}@{public_ip}:{public_port}/{database}", + }, + }, + "mysql": { + "proxy": { + "tcp": "mysql://{user}:{password}@127.0.0.1:{proxy_port}/{database}", + "socket": "mysql://{user}:{password}@localhost/{database}?unix_socket={socket_path}", + }, + "public": { + "ssl": "mysql://{user}:{password}@{public_ip}:{public_port}/{database}?ssl={ssl_spec}", + "non-ssl": "mysql://{user}:{password}@{public_ip}:{public_port}/{database}", + }, + }, +} # type: Dict[str, Dict[str, Dict[str, str]]] + +CLOUD_SQL_VALID_DATABASE_TYPES = ["postgres", "mysql"] + + +class CloudSQLDatabaseHook(BaseHook): # noqa + # pylint: disable=too-many-instance-attributes + """ + Serves DB connection configuration for Google Cloud SQL (Connections + of *gcpcloudsqldb://* type). + + The hook is a "meta" one. It does not perform an actual connection. + It is there to retrieve all the parameters configured in gcpcloudsql:// connection, + start/stop Cloud SQL Proxy if needed, dynamically generate Postgres or MySQL + connection in the database and return an actual Postgres or MySQL hook. + The returned Postgres/MySQL hooks are using direct connection or Cloud SQL + Proxy socket/TCP as configured. 
+ + Main parameters of the hook are retrieved from the standard URI components: + + * **user** - User name to authenticate to the database (from login of the URI). + * **password** - Password to authenticate to the database (from password of the URI). + * **public_ip** - IP to connect to for public connection (from host of the URI). + * **public_port** - Port to connect to for public connection (from port of the URI). + * **database** - Database to connect to (from schema of the URI). + + Remaining parameters are retrieved from the extras (URI query parameters): + + * **project_id** - Optional, Google Cloud project where the Cloud SQL + instance exists. If missing, default project id passed is used. + * **instance** - Name of the instance of the Cloud SQL database instance. + * **location** - The location of the Cloud SQL instance (for example europe-west1). + * **database_type** - The type of the database instance (MySQL or Postgres). + * **use_proxy** - (default False) Whether SQL proxy should be used to connect to Cloud + SQL DB. + * **use_ssl** - (default False) Whether SSL should be used to connect to Cloud SQL DB. + You cannot use proxy and SSL together. + * **sql_proxy_use_tcp** - (default False) If set to true, TCP is used to connect via + proxy, otherwise UNIX sockets are used. + * **sql_proxy_binary_path** - Optional path to Cloud SQL Proxy binary. If the binary + is not specified or the binary is not present, it is automatically downloaded. + * **sql_proxy_version** - Specific version of the proxy to download (for example + v1.13). If not specified, the latest version is downloaded. + * **sslcert** - Path to client certificate to authenticate when SSL is used. + * **sslkey** - Path to client private key to authenticate when SSL is used. + * **sslrootcert** - Path to server's certificate to authenticate when SSL is used. + + :param gcp_cloudsql_conn_id: URL of the connection + :type gcp_cloudsql_conn_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud for + cloud-sql-proxy authentication. 
+ :type gcp_conn_id: str + :param default_gcp_project_id: Default project id used if project_id not specified + in the connection URL + :type default_gcp_project_id: str + """ + conn_name_attr = "gcp_cloudsql_conn_id" + default_conn_name = "google_cloud_sql_default" + conn_type = "gcpcloudsqldb" + hook_name = "Google Cloud SQL Database" + + _conn = None # type: Optional[Any] + + def __init__( + self, + gcp_cloudsql_conn_id: str = "google_cloud_sql_default", + gcp_conn_id: str = "google_cloud_default", + default_gcp_project_id: Optional[str] = None, + ) -> None: + super().__init__() + self.gcp_conn_id = gcp_conn_id + self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id + self.cloudsql_connection = self.get_connection(self.gcp_cloudsql_conn_id) + self.extras = self.cloudsql_connection.extra_dejson + self.project_id = self.extras.get( + "project_id", default_gcp_project_id + ) # type: Optional[str] + self.instance = self.extras.get("instance") # type: Optional[str] + self.database = self.cloudsql_connection.schema # type: Optional[str] + self.location = self.extras.get("location") # type: Optional[str] + self.database_type = self.extras.get("database_type") # type: Optional[str] + self.use_proxy = self._get_bool( + self.extras.get("use_proxy", "False") + ) # type: bool + self.use_ssl = self._get_bool(self.extras.get("use_ssl", "False")) # type: bool + self.sql_proxy_use_tcp = self._get_bool( + self.extras.get("sql_proxy_use_tcp", "False") + ) # type: bool + self.sql_proxy_version = self.extras.get( + "sql_proxy_version" + ) # type: Optional[str] + self.sql_proxy_binary_path = self.extras.get( + "sql_proxy_binary_path" + ) # type: Optional[str] + self.user = self.cloudsql_connection.login # type: Optional[str] + self.password = self.cloudsql_connection.password # type: Optional[str] + self.public_ip = self.cloudsql_connection.host # type: Optional[str] + self.public_port = self.cloudsql_connection.port # type: Optional[int] + self.sslcert = self.extras.get("sslcert") # type: Optional[str] + self.sslkey = self.extras.get("sslkey") # type: Optional[str] + self.sslrootcert = self.extras.get("sslrootcert") # type: Optional[str] + # Port and socket path and db_hook are automatically generated + self.sql_proxy_tcp_port = None + self.sql_proxy_unique_path = None # type: Optional[str] + self.db_hook = None # type: Optional[Union[PostgresHook, MySqlHook]] + self.reserved_tcp_socket = None # type: Optional[socket.socket] + # Generated based on clock + clock sequence. Unique per host (!). + # This is important as different hosts share the database + self.db_conn_id = str(uuid.uuid1()) + self._validate_inputs() + + @staticmethod + def _get_bool(val: Any) -> bool: + if val == "False": + return False + return True + + @staticmethod + def _check_ssl_file(file_to_check, name) -> None: + if not file_to_check: + raise AirflowException(f"SSL connections requires {name} to be set") + if not os.path.isfile(file_to_check): + raise AirflowException(f"The {file_to_check} must be a readable file") + + def _validate_inputs(self) -> None: + if self.project_id == "": + raise AirflowException("The required extra 'project_id' is empty") + if not self.location: + raise AirflowException("The required extra 'location' is empty or None") + if not self.instance: + raise AirflowException("The required extra 'instance' is empty or None") + if self.database_type not in CLOUD_SQL_VALID_DATABASE_TYPES: + raise AirflowException( + "Invalid database type '{}'. 
Must be one of {}".format( + self.database_type, CLOUD_SQL_VALID_DATABASE_TYPES + ) + ) + if self.use_proxy and self.use_ssl: + raise AirflowException( + "Cloud SQL Proxy does not support SSL connections." + " SSL is not needed as Cloud SQL Proxy " + "provides encryption on its own" + ) + + def validate_ssl_certs(self) -> None: + """ + SSL certificates validator. + + :return: None + """ + if self.use_ssl: + self._check_ssl_file(self.sslcert, "sslcert") + self._check_ssl_file(self.sslkey, "sslkey") + self._check_ssl_file(self.sslrootcert, "sslrootcert") + + def validate_socket_path_length(self) -> None: + """ + Validates sockets path length. + + :return: None or rises AirflowException + """ + if self.use_proxy and not self.sql_proxy_use_tcp: + if self.database_type == "postgres": + suffix = "/.s.PGSQL.5432" + else: + suffix = "" + expected_path = "{}/{}:{}:{}{}".format( + self._generate_unique_path(), + self.project_id, + self.instance, + self.database, + suffix, + ) + if len(expected_path) > UNIX_PATH_MAX: + self.log.info( + "Too long (%s) path: %s", len(expected_path), expected_path + ) + raise AirflowException( + "The UNIX socket path length cannot exceed {} characters " + "on Linux system. Either use shorter instance/database " + "name or switch to TCP connection. " + "The socket path for Cloud SQL proxy is now:" + "{}".format(UNIX_PATH_MAX, expected_path) + ) + + @staticmethod + def _generate_unique_path() -> str: + """ + We are not using mkdtemp here as the path generated with mkdtemp + can be close to 60 characters and there is a limitation in + length of socket path to around 100 characters in total. + We append project/location/instance to it later and postgres + appends its own prefix, so we chose a shorter "/tmp/[8 random characters]" + """ + random.seed() + while True: + candidate = "/tmp/" + "".join( + random.choice(string.ascii_lowercase + string.digits) for _ in range(8) + ) + if not os.path.exists(candidate): + return candidate + + @staticmethod + def _quote(value) -> Optional[str]: + return quote_plus(value) if value else None + + def _generate_connection_uri(self) -> str: + if self.use_proxy: + if self.sql_proxy_use_tcp: + if not self.sql_proxy_tcp_port: + self.reserve_free_tcp_port() + if not self.sql_proxy_unique_path: + self.sql_proxy_unique_path = self._generate_unique_path() + if not self.database_type: + raise ValueError("The database_type should be set") + + database_uris = CONNECTION_URIS[ + self.database_type + ] # type: Dict[str, Dict[str, str]] + ssl_spec = None + socket_path = None + if self.use_proxy: + proxy_uris = database_uris["proxy"] # type: Dict[str, str] + if self.sql_proxy_use_tcp: + format_string = proxy_uris["tcp"] + else: + format_string = proxy_uris["socket"] + socket_path = "{sql_proxy_socket_path}/{instance_socket_name}".format( + sql_proxy_socket_path=self.sql_proxy_unique_path, + instance_socket_name=self._get_instance_socket_name(), + ) + else: + public_uris = database_uris["public"] # type: Dict[str, str] + if self.use_ssl: + format_string = public_uris["ssl"] + ssl_spec = { + "cert": self.sslcert, + "key": self.sslkey, + "ca": self.sslrootcert, + } + else: + format_string = public_uris["non-ssl"] + if not self.user: + raise AirflowException("The login parameter needs to be set in connection") + if not self.public_ip: + raise AirflowException( + "The location parameter needs to be set in connection" + ) + if not self.password: + raise AirflowException( + "The password parameter needs to be set in connection" + ) + if not self.database: + 
raise AirflowException( + "The database parameter needs to be set in connection" + ) + + connection_uri = format_string.format( + user=quote_plus(self.user) if self.user else "", + password=quote_plus(self.password) if self.password else "", + database=quote_plus(self.database) if self.database else "", + public_ip=self.public_ip, + public_port=self.public_port, + proxy_port=self.sql_proxy_tcp_port, + socket_path=self._quote(socket_path), + ssl_spec=self._quote(json.dumps(ssl_spec)) if ssl_spec else "", + client_cert_file=self._quote(self.sslcert) if self.sslcert else "", + client_key_file=self._quote(self.sslkey) if self.sslcert else "", + server_ca_file=self._quote(self.sslrootcert if self.sslcert else ""), + ) + self.log.info( + "DB connection URI %s", + connection_uri.replace( + quote_plus(self.password) if self.password else "PASSWORD", + "XXXXXXXXXXXX", + ), + ) + return connection_uri + + def _get_instance_socket_name(self) -> str: + return self.project_id + ":" + self.location + ":" + self.instance # type: ignore + + def _get_sqlproxy_instance_specification(self) -> str: + instance_specification = self._get_instance_socket_name() + if self.sql_proxy_use_tcp: + instance_specification += "=tcp:" + str(self.sql_proxy_tcp_port) + return instance_specification + + def create_connection(self) -> Connection: + """ + Create Connection object, according to whether it uses proxy, TCP, UNIX sockets, SSL. + Connection ID will be randomly generated. + """ + uri = self._generate_connection_uri() + connection = Connection(conn_id=self.db_conn_id, uri=uri) + self.log.info("Creating connection %s", self.db_conn_id) + return connection + + def get_sqlproxy_runner(self) -> CloudSqlProxyRunner: + """ + Retrieve Cloud SQL Proxy runner. It is used to manage the proxy + lifecycle per task. + + :return: The Cloud SQL Proxy runner. + :rtype: CloudSqlProxyRunner + """ + if not self.use_proxy: + raise ValueError( + "Proxy runner can only be retrieved in case of use_proxy = True" + ) + if not self.sql_proxy_unique_path: + raise ValueError("The sql_proxy_unique_path should be set") + return CloudSqlProxyRunner( + path_prefix=self.sql_proxy_unique_path, + instance_specification=self._get_sqlproxy_instance_specification(), + project_id=self.project_id, + sql_proxy_version=self.sql_proxy_version, + sql_proxy_binary_path=self.sql_proxy_binary_path, + gcp_conn_id=self.gcp_conn_id, + ) + + def get_database_hook( + self, connection: Connection + ) -> Union[PostgresHook, MySqlHook]: + """ + Retrieve database hook. This is the actual Postgres or MySQL database hook + that uses proxy or connects directly to the Google Cloud SQL database. 
+ """ + if self.database_type == "postgres": + self.db_hook = PostgresHook(connection=connection, schema=self.database) + else: + self.db_hook = MySqlHook(connection=connection, schema=self.database) + return self.db_hook + + def cleanup_database_hook(self) -> None: + """Clean up database hook after it was used.""" + if self.database_type == "postgres": + if not self.db_hook: + raise ValueError("The db_hook should be set") + if not isinstance(self.db_hook, PostgresHook): + raise ValueError( + f"The db_hook should be PostgresHook and is {type(self.db_hook)}" + ) + conn = getattr(self.db_hook, "conn") + if conn and conn.notices: + for output in self.db_hook.conn.notices: + self.log.info(output) + + def reserve_free_tcp_port(self) -> None: + """Reserve free TCP port to be used by Cloud SQL Proxy""" + self.reserved_tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.reserved_tcp_socket.bind(("127.0.0.1", 0)) + self.sql_proxy_tcp_port = self.reserved_tcp_socket.getsockname()[1] + + def free_reserved_port(self) -> None: + """Free TCP port. Makes it immediately ready to be used by Cloud SQL Proxy.""" + if self.reserved_tcp_socket: + self.reserved_tcp_socket.close() + self.reserved_tcp_socket = None diff --git a/reference/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/reference/providers/google/cloud/hooks/cloud_storage_transfer_service.py new file mode 100644 index 0000000..7b43936 --- /dev/null +++ b/reference/providers/google/cloud/hooks/cloud_storage_transfer_service.py @@ -0,0 +1,591 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains a Google Storage Transfer Service Hook.""" + +import json +import logging +import time +import warnings +from copy import deepcopy +from datetime import timedelta +from typing import List, Optional, Sequence, Set, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import Resource, build +from googleapiclient.errors import HttpError + +log = logging.getLogger(__name__) + +# Time to sleep between active checks of the operation results +TIME_TO_SLEEP_IN_SECONDS = 10 + + +class GcpTransferJobsStatus: + """Class with Google Cloud Transfer jobs statuses.""" + + ENABLED = "ENABLED" + DISABLED = "DISABLED" + DELETED = "DELETED" + + +class GcpTransferOperationStatus: + """Class with Google Cloud Transfer operations statuses.""" + + IN_PROGRESS = "IN_PROGRESS" + PAUSED = "PAUSED" + SUCCESS = "SUCCESS" + FAILED = "FAILED" + ABORTED = "ABORTED" + + +# A list of keywords used to build a request or response +ACCESS_KEY_ID = "accessKeyId" +ALREADY_EXISTING_IN_SINK = "overwriteObjectsAlreadyExistingInSink" +AWS_ACCESS_KEY = "awsAccessKey" +AWS_S3_DATA_SOURCE = "awsS3DataSource" +BODY = "body" +BUCKET_NAME = "bucketName" +COUNTERS = "counters" +DAY = "day" +DESCRIPTION = "description" +FILTER = "filter" +FILTER_JOB_NAMES = "job_names" +FILTER_PROJECT_ID = "project_id" +GCS_DATA_SINK = "gcsDataSink" +GCS_DATA_SOURCE = "gcsDataSource" +HOURS = "hours" +HTTP_DATA_SOURCE = "httpDataSource" +JOB_NAME = "name" +LIST_URL = "list_url" +METADATA = "metadata" +MINUTES = "minutes" +MONTH = "month" +NAME = "name" +OBJECT_CONDITIONS = "object_conditions" +OPERATIONS = "operations" +PROJECT_ID = "projectId" +SCHEDULE = "schedule" +SCHEDULE_END_DATE = "scheduleEndDate" +SCHEDULE_START_DATE = "scheduleStartDate" +SECONDS = "seconds" +SECRET_ACCESS_KEY = "secretAccessKey" +START_TIME_OF_DAY = "startTimeOfDay" +STATUS = "status" +STATUS1 = "status" +TRANSFER_JOB = "transfer_job" +TRANSFER_JOBS = "transferJobs" +TRANSFER_JOB_FIELD_MASK = "update_transfer_job_field_mask" +TRANSFER_OPERATIONS = "transferOperations" +TRANSFER_OPTIONS = "transfer_options" +TRANSFER_SPEC = "transferSpec" +YEAR = "year" +ALREADY_EXIST_CODE = 409 + +NEGATIVE_STATUSES = { + GcpTransferOperationStatus.FAILED, + GcpTransferOperationStatus.ABORTED, +} + + +def gen_job_name(job_name: str) -> str: + """ + Adds unique suffix to job name. If suffix already exists, updates it. + Suffix — current timestamp + + :param job_name: + :rtype job_name: str + :return: job_name with suffix + :rtype: str + """ + uniq = int(time.time()) + return f"{job_name}_{uniq}" + + +class CloudDataTransferServiceHook(GoogleBaseHook): + """ + Hook for Google Storage Transfer Service. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + """ + + def __init__( + self, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.api_version = api_version + self._conn = None + + def get_conn(self) -> Re# + """ + Retrieves connection to Google Storage Transfer service. 
+ + :return: Google Storage Transfer service object + :rtype: dict + """ + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + "storagetransfer", + self.api_version, + http=http_authorized, + cache_discovery=False, + ) + return self._conn + + def create_transfer_job(self, body: dict) -> dict: + """ + Creates a transfer job that runs periodically. + + :param body: (Required) A request body, as described in + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/patch#request-body + :type body: dict + :return: transfer job. + See: + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs#TransferJob + :rtype: dict + """ + body = self._inject_project_id(body, BODY, PROJECT_ID) + try: + # pylint: disable=no-member + transfer_job = ( + self.get_conn() + .transferJobs() + .create(body=body) + .execute(num_retries=self.num_retries) + ) + except HttpError as e: + # If status code "Conflict" + # https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Code.ENUM_VALUES.ALREADY_EXISTS + # we should try to find this job + job_name = body.get(JOB_NAME, "") + if int(e.resp.status) == ALREADY_EXIST_CODE and job_name: + transfer_job = self.get_transfer_job( + job_name=job_name, project_id=body.get(PROJECT_ID) + ) + # Generate new job_name, if jobs status is deleted + # and try to create this job again + if transfer_job.get(STATUS) == GcpTransferJobsStatus.DELETED: + body[JOB_NAME] = gen_job_name(job_name) + self.log.info( + "Job `%s` has been soft deleted. Creating job with " + "new name `%s`", + job_name, + {body[JOB_NAME]}, + ) + # pylint: disable=no-member + return ( + self.get_conn() + .transferJobs() + .create(body=body) + .execute(num_retries=self.num_retries) + ) + elif transfer_job.get(STATUS) == GcpTransferJobsStatus.DISABLED: + return self.enable_transfer_job( + job_name=job_name, project_id=body.get(PROJECT_ID) + ) + else: + raise e + self.log.info("Created job %s", transfer_job[NAME]) + return transfer_job + + @GoogleBaseHook.fallback_to_default_project_id + def get_transfer_job(self, job_name: str, project_id: str) -> dict: + """ + Gets the latest state of a long-running operation in Google Storage + Transfer Service. + + :param job_name: (Required) Name of the job to be fetched + :type job_name: str + :param project_id: (Optional) the ID of the project that owns the Transfer + Job. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :return: Transfer Job + :rtype: dict + """ + return ( + self.get_conn() # pylint: disable=no-member + .transferJobs() + .get(jobName=job_name, projectId=project_id) + .execute(num_retries=self.num_retries) + ) + + def list_transfer_job( + self, request_filter: Optional[dict] = None, **kwargs + ) -> List[dict]: + """ + Lists long-running operations in Google Storage Transfer + Service that match the specified filter. 
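For orientation, a filter built from the key constants defined at the top of this
module might look like the sketch below (the project and job names are placeholders,
and the exact key casing expected by the service is an assumption here):

    hook = CloudDataTransferServiceHook()
    jobs = hook.list_transfer_job(
        request_filter={
            "project_id": "example-project",
            "job_names": ["transferJobs/123456789"],
        }
    )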
+ + :param request_filter: (Required) A request filter, as described in + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/list#body.QUERY_PARAMETERS.filter + :type request_filter: dict + :return: List of Transfer Jobs + :rtype: list[dict] + """ + # To preserve backward compatibility + # TODO: remove one day + if request_filter is None: + if "filter" in kwargs: + request_filter = kwargs["filter"] + if not isinstance(request_filter, dict): + raise ValueError( + f"The request_filter should be dict and is {type(request_filter)}" + ) + warnings.warn( + "Use 'request_filter' instead of 'filter'", DeprecationWarning + ) + else: + raise TypeError( + "list_transfer_job missing 1 required positional argument: 'request_filter'" + ) + + conn = self.get_conn() + request_filter = self._inject_project_id( + request_filter, FILTER, FILTER_PROJECT_ID + ) + request = conn.transferJobs().list( + filter=json.dumps(request_filter) + ) # pylint: disable=no-member + jobs: List[dict] = [] + + while request is not None: + response = request.execute(num_retries=self.num_retries) + jobs.extend(response[TRANSFER_JOBS]) + + # pylint: disable=no-member + request = conn.transferJobs().list_next( + previous_request=request, previous_response=response + ) + + return jobs + + @GoogleBaseHook.fallback_to_default_project_id + def enable_transfer_job(self, job_name: str, project_id: str) -> dict: + """ + New transfers will be performed based on the schedule. + + :param job_name: (Required) Name of the job to be updated + :type job_name: str + :param project_id: (Optional) the ID of the project that owns the Transfer + Job. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :return: If successful, TransferJob. + :rtype: dict + """ + return ( + self.get_conn() # pylint: disable=no-member + .transferJobs() + .patch( + jobName=job_name, + body={ + PROJECT_ID: project_id, + TRANSFER_JOB: {STATUS1: GcpTransferJobsStatus.ENABLED}, + TRANSFER_JOB_FIELD_MASK: STATUS1, + }, + ) + .execute(num_retries=self.num_retries) + ) + + def update_transfer_job(self, job_name: str, body: dict) -> dict: + """ + Updates a transfer job that runs periodically. + + :param job_name: (Required) Name of the job to be updated + :type job_name: str + :param body: A request body, as described in + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/patch#request-body + :type body: dict + :return: If successful, TransferJob. + :rtype: dict + """ + body = self._inject_project_id(body, BODY, PROJECT_ID) + return ( + self.get_conn() # pylint: disable=no-member + .transferJobs() + .patch(jobName=job_name, body=body) + .execute(num_retries=self.num_retries) + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_transfer_job(self, job_name: str, project_id: str) -> None: + """ + Deletes a transfer job. This is a soft delete. After a transfer job is + deleted, the job and all the transfer executions are subject to garbage + collection. Transfer jobs become eligible for garbage collection + 30 days after soft delete. + + :param job_name: (Required) Name of the job to be deleted + :type job_name: str + :param project_id: (Optional) the ID of the project that owns the Transfer + Job. If set to None or missing, the default project_id from the Google Cloud + connection is used. 
+ :type project_id: str + :rtype: None + """ + ( + self.get_conn() # pylint: disable=no-member + .transferJobs() + .patch( + jobName=job_name, + body={ + PROJECT_ID: project_id, + TRANSFER_JOB: {STATUS1: GcpTransferJobsStatus.DELETED}, + TRANSFER_JOB_FIELD_MASK: STATUS1, + }, + ) + .execute(num_retries=self.num_retries) + ) + + def cancel_transfer_operation(self, operation_name: str) -> None: + """ + Cancels an transfer operation in Google Storage Transfer Service. + + :param operation_name: Name of the transfer operation. + :type operation_name: str + :rtype: None + """ + self.get_conn().transferOperations().cancel( + name=operation_name + ).execute( # pylint: disable=no-member + num_retries=self.num_retries + ) + + def get_transfer_operation(self, operation_name: str) -> dict: + """ + Gets an transfer operation in Google Storage Transfer Service. + + :param operation_name: (Required) Name of the transfer operation. + :type operation_name: str + :return: transfer operation + See: + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/Operation + :rtype: dict + """ + return ( + self.get_conn() # pylint: disable=no-member + .transferOperations() + .get(name=operation_name) + .execute(num_retries=self.num_retries) + ) + + def list_transfer_operations( + self, request_filter: Optional[dict] = None, **kwargs + ) -> List[dict]: + """ + Gets an transfer operation in Google Storage Transfer Service. + + :param request_filter: (Required) A request filter, as described in + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/list#body.QUERY_PARAMETERS.filter + With one additional improvement: + + * project_id is optional if you have a project id defined + in the connection + See: :doc:`/connections/gcp` + + :type request_filter: dict + :return: transfer operation + :rtype: list[dict] + """ + # To preserve backward compatibility + # TODO: remove one day + if request_filter is None: + if "filter" in kwargs: + request_filter = kwargs["filter"] + if not isinstance(request_filter, dict): + raise ValueError( + f"The request_filter should be dict and is {type(request_filter)}" + ) + warnings.warn( + "Use 'request_filter' instead of 'filter'", DeprecationWarning + ) + else: + raise TypeError( + "list_transfer_operations missing 1 required positional argument: 'request_filter'" + ) + + conn = self.get_conn() + + request_filter = self._inject_project_id( + request_filter, FILTER, FILTER_PROJECT_ID + ) + + operations: List[dict] = [] + + request = conn.transferOperations().list( # pylint: disable=no-member + name=TRANSFER_OPERATIONS, filter=json.dumps(request_filter) + ) + + while request is not None: + response = request.execute(num_retries=self.num_retries) + if OPERATIONS in response: + operations.extend(response[OPERATIONS]) + + request = conn.transferOperations().list_next( # pylint: disable=no-member + previous_request=request, previous_response=response + ) + + return operations + + def pause_transfer_operation(self, operation_name: str) -> None: + """ + Pauses an transfer operation in Google Storage Transfer Service. + + :param operation_name: (Required) Name of the transfer operation. + :type operation_name: str + :rtype: None + """ + self.get_conn().transferOperations().pause( + name=operation_name + ).execute( # pylint: disable=no-member + num_retries=self.num_retries + ) + + def resume_transfer_operation(self, operation_name: str) -> None: + """ + Resumes an transfer operation in Google Storage Transfer Service. 
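+
+        Usage sketch; the operation name is a placeholder in the form returned
+        by ``list_transfer_operations``::
+
+            hook = CloudDataTransferServiceHook()
+            hook.pause_transfer_operation(operation_name="transferOperations/example-op")
+            hook.resume_transfer_operation(operation_name="transferOperations/example-op")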
+ + :param operation_name: (Required) Name of the transfer operation. + :type operation_name: str + :rtype: None + """ + self.get_conn().transferOperations().resume( + name=operation_name + ).execute( # pylint: disable=no-member + num_retries=self.num_retries + ) + + def wait_for_transfer_job( + self, + job: dict, + expected_statuses: Optional[Set[str]] = None, + timeout: Optional[Union[float, timedelta]] = None, + ) -> None: + """ + Waits until the job reaches the expected state. + + :param job: Transfer job + See: + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs#TransferJob + :type job: dict + :param expected_statuses: State that is expected + See: + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status + :type expected_statuses: set[str] + :param timeout: Time in which the operation must end in seconds. If not specified, defaults to 60 + seconds. + :type timeout: Optional[Union[float, timedelta]] + :rtype: None + """ + expected_statuses = ( + {GcpTransferOperationStatus.SUCCESS} + if not expected_statuses + else expected_statuses + ) + if timeout is None: + timeout = 60 + elif isinstance(timeout, timedelta): + timeout = timeout.total_seconds() + + start_time = time.monotonic() + while time.monotonic() - start_time < timeout: + request_filter = { + FILTER_PROJECT_ID: job[PROJECT_ID], + FILTER_JOB_NAMES: [job[NAME]], + } + operations = self.list_transfer_operations(request_filter=request_filter) + + for operation in operations: + self.log.info( + "Progress for operation %s: %s", + operation[NAME], + operation[METADATA][COUNTERS], + ) + + if self.operations_contain_expected_statuses(operations, expected_statuses): + return + time.sleep(TIME_TO_SLEEP_IN_SECONDS) + raise AirflowException( + "Timeout. The operation could not be completed within the allotted time." + ) + + def _inject_project_id(self, body: dict, param_name: str, target_key: str) -> dict: + body = deepcopy(body) + body[target_key] = body.get(target_key, self.project_id) + if not body.get(target_key): + raise AirflowException( + "The project id must be passed either as `{}` key in `{}` parameter or as project_id " + "extra in Google Cloud connection definition. Both are not set!".format( + target_key, param_name + ) + ) + return body + + @staticmethod + def operations_contain_expected_statuses( + operations: List[dict], expected_statuses: Union[Set[str], str] + ) -> bool: + """ + Checks whether the operation list has an operation with the + expected status, then returns true + If it encounters operations in FAILED or ABORTED state + throw :class:`airflow.exceptions.AirflowException`. + + :param operations: (Required) List of transfer operations to check. 
+ :type operations: list[dict] + :param expected_statuses: (Required) status that is expected + See: + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status + :type expected_statuses: set[str] + :return: If there is an operation with the expected state + in the operation list, returns true, + :raises: airflow.exceptions.AirflowException If it encounters operations + with a state in the list, + :rtype: bool + """ + expected_statuses_set = ( + {expected_statuses} + if isinstance(expected_statuses, str) + else set(expected_statuses) + ) + if not operations: + return False + + current_statuses = {operation[METADATA][STATUS] for operation in operations} + + if len(current_statuses - expected_statuses_set) != len(current_statuses): + return True + + if len(NEGATIVE_STATUSES - current_statuses) != len(NEGATIVE_STATUSES): + raise AirflowException( + "An unexpected operation status was encountered. Expected: {}".format( + ", ".join(expected_statuses_set) + ) + ) + return False diff --git a/reference/providers/google/cloud/hooks/compute.py b/reference/providers/google/cloud/hooks/compute.py new file mode 100644 index 0000000..c1a01ed --- /dev/null +++ b/reference/providers/google/cloud/hooks/compute.py @@ -0,0 +1,485 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Compute Engine Hook.""" + +import time +from typing import Any, Dict, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import build + +# Time to sleep between active checks of the operation results +TIME_TO_SLEEP_IN_SECONDS = 1 + + +class GceOperationStatus: + """Class with GCE operations statuses.""" + + PENDING = "PENDING" + RUNNING = "RUNNING" + DONE = "DONE" + + +class ComputeEngineHook(GoogleBaseHook): + """ + Hook for Google Compute Engine APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + """ + + _conn = None # type: Optional[Any] + + def __init__( + self, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.api_version = api_version + + def get_conn(self): + """ + Retrieves connection to Google Compute Engine. 
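+
+        Illustrative sketch of how the client is used by the methods in this
+        hook; the project, zone and instance name are placeholders::
+
+            hook = ComputeEngineHook(gcp_conn_id="google_cloud_default")
+            service = hook.get_conn()
+            instance = (
+                service.instances()
+                .get(project="example-project", zone="us-central1-a", instance="example-vm")
+                .execute(num_retries=hook.num_retries)
+            )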
+ + :return: Google Compute Engine services object + :rtype: dict + """ + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + "compute", self.api_version, http=http_authorized, cache_discovery=False + ) + return self._conn + + @GoogleBaseHook.fallback_to_default_project_id + def start_instance(self, zone: str, resource_id: str, project_id: str) -> None: + """ + Starts an existing instance defined by project_id, zone and resource_id. + Must be called with keyword arguments rather than positional. + + :param zone: Google Cloud zone where the instance exists + :type zone: str + :param resource_id: Name of the Compute Engine instance resource + :type resource_id: str + :param project_id: Optional, Google Cloud project ID where the + Compute Engine Instance exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instances() + .start(project=project_id, zone=zone, instance=resource_id) + .execute(num_retries=self.num_retries) + ) + try: + operation_name = response["name"] + except KeyError: + raise AirflowException( + f"Wrong response '{response}' returned - it should contain 'name' field" + ) + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name, zone=zone + ) + + @GoogleBaseHook.fallback_to_default_project_id + def stop_instance(self, zone: str, resource_id: str, project_id: str) -> None: + """ + Stops an instance defined by project_id, zone and resource_id + Must be called with keyword arguments rather than positional. + + :param zone: Google Cloud zone where the instance exists + :type zone: str + :param resource_id: Name of the Compute Engine instance resource + :type resource_id: str + :param project_id: Optional, Google Cloud project ID where the + Compute Engine Instance exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instances() + .stop(project=project_id, zone=zone, instance=resource_id) + .execute(num_retries=self.num_retries) + ) + try: + operation_name = response["name"] + except KeyError: + raise AirflowException( + f"Wrong response '{response}' returned - it should contain 'name' field" + ) + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name, zone=zone + ) + + @GoogleBaseHook.fallback_to_default_project_id + def set_machine_type( + self, zone: str, resource_id: str, body: dict, project_id: str + ) -> None: + """ + Sets machine type of an instance defined by project_id, zone and resource_id. + Must be called with keyword arguments rather than positional. + + :param zone: Google Cloud zone where the instance exists. + :type zone: str + :param resource_id: Name of the Compute Engine instance resource + :type resource_id: str + :param body: Body required by the Compute Engine setMachineType API, + as described in + https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMachineType + :type body: dict + :param project_id: Optional, Google Cloud project ID where the + Compute Engine Instance exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. 
+ :type project_id: str + :return: None + """ + response = self._execute_set_machine_type(zone, resource_id, body, project_id) + try: + operation_name = response["name"] + except KeyError: + raise AirflowException( + f"Wrong response '{response}' returned - it should contain 'name' field" + ) + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name, zone=zone + ) + + def _execute_set_machine_type( + self, zone: str, resource_id: str, body: dict, project_id: str + ) -> dict: + # noqa pylint: disable=no-member + return ( + self.get_conn() + .instances() + .setMachineType( + project=project_id, zone=zone, instance=resource_id, body=body + ) + .execute(num_retries=self.num_retries) + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance_template(self, resource_id: str, project_id: str) -> dict: + """ + Retrieves instance template by project_id and resource_id. + Must be called with keyword arguments rather than positional. + + :param resource_id: Name of the instance template + :type resource_id: str + :param project_id: Optional, Google Cloud project ID where the + Compute Engine Instance exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: Instance template representation as object according to + https://cloud.google.com/compute/docs/reference/rest/v1/instanceTemplates + :rtype: dict + """ + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instanceTemplates() + .get(project=project_id, instanceTemplate=resource_id) + .execute(num_retries=self.num_retries) + ) + return response + + @GoogleBaseHook.fallback_to_default_project_id + def insert_instance_template( + self, + body: dict, + project_id: str, + request_id: Optional[str] = None, + ) -> None: + """ + Inserts instance template using body specified + Must be called with keyword arguments rather than positional. + + :param body: Instance template representation as object according to + https://cloud.google.com/compute/docs/reference/rest/v1/instanceTemplates + :type body: dict + :param request_id: Optional, unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :type request_id: str + :param project_id: Optional, Google Cloud project ID where the + Compute Engine Instance exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instanceTemplates() + .insert(project=project_id, body=body, requestId=request_id) + .execute(num_retries=self.num_retries) + ) + try: + operation_name = response["name"] + except KeyError: + raise AirflowException( + f"Wrong response '{response}' returned - it should contain 'name' field" + ) + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance_group_manager( + self, + zone: str, + resource_id: str, + project_id: str, + ) -> dict: + """ + Retrieves Instance Group Manager by project_id, zone and resource_id. + Must be called with keyword arguments rather than positional. 
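+
+        Usage sketch; all identifiers are placeholders, and the field lookup at
+        the end assumes the standard API response shape::
+
+            hook = ComputeEngineHook()
+            igm = hook.get_instance_group_manager(
+                zone="us-central1-a",
+                resource_id="example-igm",
+                project_id="example-project",
+            )
+            template_url = igm["instanceTemplate"]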
+ + :param zone: Google Cloud zone where the Instance Group Manager exists + :type zone: str + :param resource_id: Name of the Instance Group Manager + :type resource_id: str + :param project_id: Optional, Google Cloud project ID where the + Compute Engine Instance exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: Instance group manager representation as object according to + https://cloud.google.com/compute/docs/reference/rest/beta/instanceGroupManagers + :rtype: dict + """ + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instanceGroupManagers() + .get(project=project_id, zone=zone, instanceGroupManager=resource_id) + .execute(num_retries=self.num_retries) + ) + return response + + @GoogleBaseHook.fallback_to_default_project_id + def patch_instance_group_manager( + self, + zone: str, + resource_id: str, + body: dict, + project_id: str, + request_id: Optional[str] = None, + ) -> None: + """ + Patches Instance Group Manager with the specified body. + Must be called with keyword arguments rather than positional. + + :param zone: Google Cloud zone where the Instance Group Manager exists + :type zone: str + :param resource_id: Name of the Instance Group Manager + :type resource_id: str + :param body: Instance Group Manager representation as json-merge-patch object + according to + https://cloud.google.com/compute/docs/reference/rest/beta/instanceTemplates/patch + :type body: dict + :param request_id: Optional, unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again). + It should be in UUID format as defined in RFC 4122 + :type request_id: str + :param project_id: Optional, Google Cloud project ID where the + Compute Engine Instance exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instanceGroupManagers() + .patch( + project=project_id, + zone=zone, + instanceGroupManager=resource_id, + body=body, + requestId=request_id, + ) + .execute(num_retries=self.num_retries) + ) + try: + operation_name = response["name"] + except KeyError: + raise AirflowException( + f"Wrong response '{response}' returned - it should contain 'name' field" + ) + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name, zone=zone + ) + + def _wait_for_operation_to_complete( + self, project_id: str, operation_name: str, zone: Optional[str] = None + ) -> None: + """ + Waits for the named operation to complete - checks status of the async call. 
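+
+        Sketch of the call pattern used by the public methods above; the
+        ``reset`` call and all identifiers are only illustrative::
+
+            response = (
+                self.get_conn()
+                .instances()
+                .reset(project="example-project", zone="us-central1-a", instance="example-vm")
+                .execute(num_retries=self.num_retries)
+            )
+            self._wait_for_operation_to_complete(
+                project_id="example-project",
+                operation_name=response["name"],
+                zone="us-central1-a",
+            )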
+ + :param operation_name: name of the operation + :type operation_name: str + :param zone: optional region of the request (might be None for global operations) + :type zone: str + :return: None + """ + service = self.get_conn() + while True: + if zone is None: + operation_response = self._check_global_operation_status( + service=service, + operation_name=operation_name, + project_id=project_id, + num_retries=self.num_retries, + ) + else: + operation_response = self._check_zone_operation_status( + service, operation_name, project_id, zone, self.num_retries + ) + if operation_response.get("status") == GceOperationStatus.DONE: + error = operation_response.get("error") + if error: + code = operation_response.get("httpErrorStatusCode") + msg = operation_response.get("httpErrorMessage") + # Extracting the errors list as string and trimming square braces + error_msg = str(error.get("errors"))[1:-1] + raise AirflowException(f"{code} {msg}: " + error_msg) + break + time.sleep(TIME_TO_SLEEP_IN_SECONDS) + + @staticmethod + def _check_zone_operation_status( + service: Any, operation_name: str, project_id: str, zone: str, num_retries: int + ) -> dict: + return ( + service.zoneOperations() + .get(project=project_id, zone=zone, operation=operation_name) + .execute(num_retries=num_retries) + ) + + @staticmethod + def _check_global_operation_status( + service: Any, operation_name: str, project_id: str, num_retries: int + ) -> dict: + return ( + service.globalOperations() + .get(project=project_id, operation=operation_name) + .execute(num_retries=num_retries) + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance_info( + self, zone: str, resource_id: str, project_id: str + ) -> Dict[str, Any]: + """ + Gets instance information. + + :param zone: Google Cloud zone where the Instance Group Manager exists + :type zone: str + :param resource_id: Name of the Instance Group Manager + :type resource_id: str + :param project_id: Optional, Google Cloud project ID where the + Compute Engine Instance exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + """ + instance_info = ( + self.get_conn() # pylint: disable=no-member + .instances() + .get(project=project_id, instance=resource_id, zone=zone) + .execute(num_retries=self.num_retries) + ) + return instance_info + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance_address( + self, + zone: str, + resource_id: str, + project_id: str, + use_internal_ip: bool = False, + ) -> str: + """ + Return network address associated to instance. + + :param zone: Google Cloud zone where the Instance Group Manager exists + :type zone: str + :param resource_id: Name of the Instance Group Manager + :type resource_id: str + :param project_id: Optional, Google Cloud project ID where the + Compute Engine Instance exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param use_internal_ip: If true, return private IP address. 
+ :type use_internal_ip: bool + """ + instance_info = self.get_instance_info( + project_id=project_id, resource_id=resource_id, zone=zone + ) + if use_internal_ip: + return instance_info["networkInterfaces"][0].get("networkIP") + + access_config = instance_info.get("networkInterfaces")[0].get("accessConfigs") + if access_config: + return access_config[0].get("natIP") + raise AirflowException("The target instance does not have external IP") + + @GoogleBaseHook.fallback_to_default_project_id + def set_instance_metadata( + self, zone: str, resource_id: str, metadata: Dict[str, str], project_id: str + ) -> None: + """ + Set instance metadata. + + :param zone: Google Cloud zone where the Instance Group Manager exists + :type zone: str + :param resource_id: Name of the Instance Group Manager + :type resource_id: str + :param metadata: The new instance metadata. + :type metadata: Dict + :param project_id: Optional, Google Cloud project ID where the + Compute Engine Instance exists. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + """ + response = ( + self.get_conn() # pylint: disable=no-member + .instances() + .setMetadata( # pylint: disable=no-member + project=project_id, zone=zone, instance=resource_id, body=metadata + ) + .execute(num_retries=self.num_retries) + ) + operation_name = response["name"] + self._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name, zone=zone + ) diff --git a/reference/providers/google/cloud/hooks/compute_ssh.py b/reference/providers/google/cloud/hooks/compute_ssh.py new file mode 100644 index 0000000..bbcdd40 --- /dev/null +++ b/reference/providers/google/cloud/hooks/compute_ssh.py @@ -0,0 +1,354 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
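+#
+# Illustrative usage only (not part of the upstream module): the hook defined
+# below is typically passed to an SSHOperator from the separate ``ssh``
+# provider; the instance details here are placeholders.
+#
+#     from airflow.providers.ssh.operators.ssh import SSHOperator
+#
+#     run_ls = SSHOperator(
+#         task_id="run_ls",
+#         ssh_hook=ComputeEngineSSHHook(
+#             instance_name="example-vm",
+#             zone="us-central1-a",
+#             project_id="example-project",
+#             use_oslogin=True,
+#             use_iap_tunnel=False,
+#         ),
+#         command="ls -la",
+#     )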
+import shlex +import time +from io import StringIO +from typing import Any, Dict, Optional + +import paramiko + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow import AirflowException +from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook +from airflow.providers.google.cloud.hooks.os_login import OSLoginHook +from airflow.providers.ssh.hooks.ssh import SSHHook +from google.api_core.retry import exponential_sleep_generator + + +class _GCloudAuthorizedSSHClient(paramiko.SSHClient): + """SSH Client that maintains the context for gcloud authorization during the connection""" + + def __init__(self, google_hook, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ssh_client = paramiko.SSHClient() + self.google_hook = google_hook + self.decorator = None + + def connect(self, *args, **kwargs): # pylint: disable=signature-differs + self.decorator = self.google_hook.provide_authorized_gcloud() + self.decorator.__enter__() + return super().connect(*args, **kwargs) + + def close(self): + if self.decorator: + self.decorator.__exit__(None, None, None) + self.decorator = None + return super().close() + + def __exit__(self, type_, value, traceback): + if self.decorator: + self.decorator.__exit__(type_, value, traceback) + self.decorator = None + return super().__exit__(type_, value, traceback) + + +class ComputeEngineSSHHook(SSHHook): + """ + Hook to connect to a remote instance in compute engine + + :param instance_name: The name of the Compute Engine instance + :type instance_name: str + :param zone: The zone of the Compute Engine instance + :type zone: str + :param user: The name of the user on which the login attempt will be made + :type user: str + :param project_id: The project ID of the remote instance + :type project_id: str + :param gcp_conn_id: The connection id to use when fetching connection info + :type gcp_conn_id: str + :param hostname: The hostname of the target instance. If it is not passed, it will be detected + automatically. + :type hostname: str + :param use_iap_tunnel: Whether to connect through IAP tunnel + :type use_iap_tunnel: bool + :param use_internal_ip: Whether to connect using internal IP + :type use_internal_ip: bool + :param use_oslogin: Whether to manage keys using OsLogin API. If false, + keys are managed using instance metadata + :param expire_time: The maximum amount of time in seconds before the private key expires + :type expire_time: int + :param gcp_conn_id: The connection id to use when fetching connection information + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. 
+ :type delegate_to: str + """ + + conn_name_attr = "gcp_conn_id" + default_conn_name = "google_cloud_default" + conn_type = "gcpssh" + hook_name = "Google Cloud SSH" + + @staticmethod + def get_ui_field_behaviour() -> Dict: + return { + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], + "relabeling": {}, + } + + def __init__( # pylint: disable=too-many-arguments + self, + gcp_conn_id: str = "google_cloud_default", + instance_name: Optional[str] = None, + zone: Optional[str] = None, + user: Optional[str] = "root", + project_id: Optional[str] = None, + hostname: Optional[str] = None, + use_internal_ip: bool = False, + use_iap_tunnel: bool = False, + use_oslogin: bool = True, + expire_time: int = 300, + delegate_to: Optional[str] = None, + ) -> None: + # Ignore original constructor + # super().__init__() # pylint: disable=super-init-not-called + self.instance_name = instance_name + self.zone = zone + self.user = user + self.project_id = project_id + self.hostname = hostname + self.use_internal_ip = use_internal_ip + self.use_iap_tunnel = use_iap_tunnel + self.use_oslogin = use_oslogin + self.expire_time = expire_time + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self._conn: Optional[Any] = None + + @cached_property + def _oslogin_hook(self) -> OSLoginHook: + return OSLoginHook(gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to) + + @cached_property + def _compute_hook(self) -> ComputeEngineHook: + return ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to + ) + + def _load_connection_config(self): + def _boolify(value): + if isinstance(value, bool): + return value + if isinstance(value, str): + if value.lower() == "false": + return False + elif value.lower() == "true": + return True + return False + + def intify(key, value, default): + if value is None: + return default + if isinstance(value, str) and value.strip() == "": + return default + try: + return int(value) + except ValueError: + raise AirflowException( + f"The {key} field should be a integer. " + f'Current value: "{value}" (type: {type(value)}). ' + f"Please check the connection configuration." 
+ ) + + conn = self.get_connection(self.gcp_conn_id) + if conn and conn.conn_type == "gcpssh": + self.instance_name = ( + self._compute_hook._get_field( # pylint: disable=protected-access + "instance_name", self.instance_name + ) + ) + self.zone = self._compute_hook._get_field( + "zone", self.zone + ) # pylint: disable=protected-access + self.user = conn.login if conn.login else self.user + # self.project_id is skipped intentionally + self.hostname = conn.host if conn.host else self.hostname + self.use_internal_ip = _boolify( + self._compute_hook._get_field( + "use_internal_ip" + ) # pylint: disable=protected-access + ) + self.use_iap_tunnel = _boolify( + self._compute_hook._get_field( + "use_iap_tunnel" + ) # pylint: disable=protected-access + ) + self.use_oslogin = _boolify( + self._compute_hook._get_field( + "use_oslogin" + ) # pylint: disable=protected-access + ) + self.expire_time = intify( + "expire_time", + self._compute_hook._get_field( + "expire_time" + ), # pylint: disable=protected-access + self.expire_time, + ) + + def get_conn(self) -> paramiko.SSHClient: + """Return SSH connection.""" + self._load_connection_config() + if not self.project_id: + self.project_id = self._compute_hook.project_id + + missing_fields = [ + k for k in ["instance_name", "zone", "project_id"] if not getattr(self, k) + ] + if not self.instance_name or not self.zone or not self.project_id: + raise AirflowException( + f"Required parameters are missing: {missing_fields}. These parameters be passed either as " + "keyword parameter or as extra field in Airfow connection definition. Both are not set!" + ) + + self.log.info( + "Connecting to instance: instance_name=%s, user=%s, zone=%s, " + "use_internal_ip=%s, use_iap_tunnel=%s, use_os_login=%s", + self.instance_name, + self.user, + self.zone, + self.use_internal_ip, + self.use_iap_tunnel, + self.use_oslogin, + ) + if not self.hostname: + hostname = self._compute_hook.get_instance_address( + zone=self.zone, + resource_id=self.instance_name, + project_id=self.project_id, + use_internal_ip=self.use_internal_ip or self.use_iap_tunnel, + ) + else: + hostname = self.hostname + + privkey, pubkey = self._generate_ssh_key(self.user) + if self.use_oslogin: + user = self._authorize_os_login(pubkey) + else: + user = self.user + self._authorize_compute_engine_instance_metadata(pubkey) + + proxy_command = None + if self.use_iap_tunnel: + proxy_command_args = [ + "gcloud", + "compute", + "start-iap-tunnel", + str(self.instance_name), + "22", + "--listen-on-stdin", + f"--project={self.project_id}", + f"--zone={self.zone}", + "--verbosity=warning", + ] + proxy_command = " ".join(shlex.quote(arg) for arg in proxy_command_args) + + sshclient = self._connect_to_instance(user, hostname, privkey, proxy_command) + return sshclient + + def _connect_to_instance( + self, user, hostname, pkey, proxy_command + ) -> paramiko.SSHClient: + self.log.info( + "Opening remote connection to host: username=%s, hostname=%s", + user, + hostname, + ) + max_time_to_wait = 10 + for time_to_wait in exponential_sleep_generator( + initial=1, maximum=max_time_to_wait + ): + try: + client = _GCloudAuthorizedSSHClient(self._compute_hook) + # Default is RejectPolicy + # No known host checking since we are not storing privatekey + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + client.connect( + hostname=hostname, + username=user, + pkey=pkey, + sock=paramiko.ProxyCommand(proxy_command) + if proxy_command + else None, + look_for_keys=False, + ) + return client + except 
paramiko.SSHException:
+                # exponential_sleep_generator is an infinite generator, so we need to
+                # check the end condition.
+                if time_to_wait == max_time_to_wait:
+                    raise
+                self.log.info("Failed to connect. Waiting %ds to retry", time_to_wait)
+                time.sleep(time_to_wait)
+        raise AirflowException("Can not connect to instance")
+
+    def _authorize_compute_engine_instance_metadata(self, pubkey):
+        self.log.info("Appending SSH public key to instance metadata")
+        instance_info = self._compute_hook.get_instance_info(
+            zone=self.zone, resource_id=self.instance_name, project_id=self.project_id
+        )
+
+        keys = self.user + ":" + pubkey + "\n"
+        metadata = instance_info["metadata"]
+        items = metadata.get("items", [])
+        for item in items:
+            if item.get("key") == "ssh-keys":
+                keys += item["value"]
+                item["value"] = keys
+                break
+        else:
+            new_dict = dict(key="ssh-keys", value=keys)
+            metadata["items"] = [new_dict]
+
+        self._compute_hook.set_instance_metadata(
+            zone=self.zone,
+            resource_id=self.instance_name,
+            metadata=metadata,
+            project_id=self.project_id,
+        )
+
+    def _authorize_os_login(self, pubkey):
+        username = (
+            self._oslogin_hook._get_credentials_email()
+        )  # pylint: disable=protected-access
+        self.log.info("Importing SSH public key using OSLogin: user=%s", username)
+        expiration = int((time.time() + self.expire_time) * 1000000)
+        ssh_public_key = {"key": pubkey, "expiration_time_usec": expiration}
+        response = self._oslogin_hook.import_ssh_public_key(
+            user=username, ssh_public_key=ssh_public_key, project_id=self.project_id
+        )
+        profile = response.login_profile
+        account = profile.posix_accounts[0]
+        user = account.username
+        return user
+
+    def _generate_ssh_key(self, user):
+        try:
+            self.log.info("Generating ssh keys...")
+            pkey_file = StringIO()
+            pkey_obj = paramiko.RSAKey.generate(2048)
+            pkey_obj.write_private_key(pkey_file)
+            pubkey = f"{pkey_obj.get_name()} {pkey_obj.get_base64()} {user}"
+            return pkey_obj, pubkey
+        except (OSError, paramiko.SSHException) as err:
+            raise AirflowException(f"Error encountered creating ssh keys, {err}")
diff --git a/reference/providers/google/cloud/hooks/datacatalog.py b/reference/providers/google/cloud/hooks/datacatalog.py
new file mode 100644
index 0000000..3dcf081
--- /dev/null
+++ b/reference/providers/google/cloud/hooks/datacatalog.py
@@ -0,0 +1,1379 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
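+#
+# Illustrative usage only; the project, location and ids below are placeholders:
+#
+#     hook = CloudDataCatalogHook(gcp_conn_id="google_cloud_default")
+#     entry_group = hook.create_entry_group(
+#         location="us-central1",
+#         entry_group_id="example_group",
+#         entry_group={"display_name": "Example group"},
+#         project_id="example-project",
+#     )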
+ +from typing import Dict, Optional, Sequence, Tuple, Union + +from airflow import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.retry import Retry +from google.cloud import datacatalog +from google.cloud.datacatalog_v1beta1 import ( + CreateTagRequest, + DataCatalogClient, + Entry, + EntryGroup, + SearchCatalogRequest, + Tag, + TagTemplate, + TagTemplateField, +) +from google.protobuf.field_mask_pb2 import FieldMask + + +class CloudDataCatalogHook(GoogleBaseHook): + """ + Hook for Google Cloud Data Catalog Service. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client: Optional[DataCatalogClient] = None + + def get_conn(self) -> DataCatalogClient: + """Retrieves client library object that allow access to Cloud Data Catalog service.""" + if not self._client: + self._client = DataCatalogClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._client + + @GoogleBaseHook.fallback_to_default_project_id + def create_entry( + self, + location: str, + entry_group: str, + entry_id: str, + entry: Union[dict, Entry], + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Entry: + """ + Creates an entry. + + Currently only entries of 'FILESET' type can be created. + + :param location: Required. The location of the entry to create. + :type location: str + :param entry_group: Required. Entry group ID under which the entry is created. + :type entry_group: str + :param entry_id: Required. The id of the entry to create. + :type entry_id: str + :param entry: Required. The entry to create. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.Entry` + :type entry: Union[Dict, google.cloud.datacatalog_v1beta1.types.Entry] + :param project_id: The ID of the Google Cloud project that owns the entry. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If set to ``None`` or missing, requests will be + retried using a default configuration. 
+ :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}" + self.log.info("Creating a new entry: parent=%s", parent) + result = client.create_entry( + request={"parent": parent, "entry_id": entry_id, "entry": entry}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Created a entry: name=%s", result.name) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_entry_group( + self, + location: str, + entry_group_id: str, + entry_group: Union[Dict, EntryGroup], + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> EntryGroup: + """ + Creates an EntryGroup. + + :param location: Required. The location of the entry group to create. + :type location: str + :param entry_group_id: Required. The id of the entry group to create. The id must begin with a letter + or underscore, contain only English letters, numbers and underscores, and be at most 64 + characters. + :type entry_group_id: str + :param entry_group: The entry group to create. Defaults to an empty entry group. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.EntryGroup` + :type entry_group: Union[Dict, google.cloud.datacatalog_v1beta1.types.EntryGroup] + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}" + self.log.info("Creating a new entry group: parent=%s", parent) + + result = client.create_entry_group( + request={ + "parent": parent, + "entry_group_id": entry_group_id, + "entry_group": entry_group, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Created a entry group: name=%s", result.name) + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_tag( + self, + location: str, + entry_group: str, + entry: str, + tag: Union[dict, Tag], + project_id: str, + template_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Tag: + """ + Creates a tag on an entry. + + :param location: Required. The location of the tag to create. + :type location: str + :param entry_group: Required. Entry group ID under which the tag is created. + :type entry_group: str + :param entry: Required. Entry group ID under which the tag is created. 
+ :type entry: str + :param tag: Required. The tag to create. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.Tag` + :type tag: Union[Dict, google.cloud.datacatalog_v1beta1.types.Tag] + :param template_id: Required. Template ID used to create tag + :type template_id: Optional[str] + :param project_id: The ID of the Google Cloud project that owns the tag. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + if template_id: + template_path = ( + f"projects/{project_id}/locations/{location}/tagTemplates/{template_id}" + ) + if isinstance(tag, Tag): + tag.template = template_path + else: + tag["template"] = template_path + parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" + + self.log.info("Creating a new tag: parent=%s", parent) + # HACK: google-cloud-datacatalog has problems with mapping messages where the value is not a + # primitive type, so we need to convert it manually. + # See: https://github.com/googleapis/python-datacatalog/issues/84 + if isinstance(tag, dict): + tag = Tag( + name=tag.get("name"), + template=tag.get("template"), + template_display_name=tag.get("template_display_name"), + column=tag.get("column"), + fields={ + k: datacatalog.TagField(**v) if isinstance(v, dict) else v + for k, v in tag.get("fields", {}).items() + }, + ) + request = CreateTagRequest( + parent=parent, + tag=tag, + ) + + result = client.create_tag( + request=request, retry=retry, timeout=timeout, metadata=metadata or () + ) + self.log.info("Created a tag: name=%s", result.name) + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_tag_template( + self, + location, + tag_template_id: str, + tag_template: Union[dict, TagTemplate], + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> TagTemplate: + """ + Creates a tag template. + + :param location: Required. The location of the tag template to create. + :type location: str + :param tag_template_id: Required. The id of the tag template to create. + :type tag_template_id: str + :param tag_template: Required. The tag template to create. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.TagTemplate` + :type tag_template: Union[Dict, google.cloud.datacatalog_v1beta1.types.TagTemplate] + :param project_id: The ID of the Google Cloud project that owns the tag template. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}" + + self.log.info("Creating a new tag template: parent=%s", parent) + # HACK: google-cloud-datacatalog has problems with mapping messages where the value is not a + # primitive type, so we need to convert it manually. + # See: https://github.com/googleapis/python-datacatalog/issues/84 + if isinstance(tag_template, dict): + tag_template = datacatalog.TagTemplate( + name=tag_template.get("name"), + display_name=tag_template.get("display_name"), + fields={ + k: datacatalog.TagTemplateField(**v) if isinstance(v, dict) else v + for k, v in tag_template.get("fields", {}).items() + }, + ) + + request = datacatalog.CreateTagTemplateRequest( + parent=parent, tag_template_id=tag_template_id, tag_template=tag_template + ) + result = client.create_tag_template( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Created a tag template: name=%s", result.name) + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_tag_template_field( + self, + location: str, + tag_template: str, + tag_template_field_id: str, + tag_template_field: Union[dict, TagTemplateField], + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> TagTemplateField: + r""" + Creates a field in a tag template. + + :param location: Required. The location of the tag template field to create. + :type location: str + :param tag_template: Required. The id of the tag template to create. + :type tag_template: str + :param tag_template_field_id: Required. The ID of the tag template field to create. Field ids can + contain letters (both uppercase and lowercase), numbers (0-9), underscores (\_) and dashes (-). + Field IDs must be at least 1 character long and at most 128 characters long. Field IDs must also + be unique within their template. + :type tag_template_field_id: str + :param tag_template_field: Required. The tag template field to create. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.TagTemplateField` + :type tag_template_field: Union[Dict, google.cloud.datacatalog_v1beta1.types.TagTemplateField] + :param project_id: The ID of the Google Cloud project that owns the tag template field. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + parent = ( + f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}" + ) + + self.log.info("Creating a new tag template field: parent=%s", parent) + + result = client.create_tag_template_field( + request={ + "parent": parent, + "tag_template_field_id": tag_template_field_id, + "tag_template_field": tag_template_field, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + self.log.info("Created a tag template field: name=%s", result.name) + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_entry( + self, + location: str, + entry_group: str, + entry: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes an existing entry. + + :param location: Required. The location of the entry to delete. + :type location: str + :param entry_group: Required. Entry group ID for entries that is deleted. + :type entry_group: str + :param entry: Entry ID that is deleted. + :type entry: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" + self.log.info("Deleting a entry: name=%s", name) + client.delete_entry( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Deleted a entry: name=%s", name) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_entry_group( + self, + location, + entry_group, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes an EntryGroup. + + Only entry groups that do not contain entries can be deleted. + + :param location: Required. The location of the entry group to delete. + :type location: str + :param entry_group: Entry group ID that is deleted. + :type entry_group: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}" + + self.log.info("Deleting a entry group: name=%s", name) + client.delete_entry_group( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Deleted a entry group: name=%s", name) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_tag( + self, + location: str, + entry_group: str, + entry: str, + tag: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes a tag. + + :param location: Required. The location of the tag to delete. + :type location: str + :param entry_group: Entry group ID for tag that is deleted. + :type entry_group: str + :param entry: Entry ID for tag that is deleted. + :type entry: str + :param tag: Identifier for TAG that is deleted. + :type tag: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}/tags/{tag}" + + self.log.info("Deleting a tag: name=%s", name) + client.delete_tag( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Deleted a tag: name=%s", name) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_tag_template( + self, + location, + tag_template, + force: bool, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes a tag template and all tags using the template. + + :param location: Required. The location of the tag template to delete. + :type location: str + :param tag_template: ID for tag template that is deleted. + :type tag_template: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param force: Required. Currently, this field must always be set to ``true``. This confirms the + deletion of any possible tags using this template. ``force = false`` will be supported in the + future. + :type force: bool + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}" + + self.log.info("Deleting a tag template: name=%s", name) + client.delete_tag_template( + request={"name": name, "force": force}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Deleted a tag template: name=%s", name) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_tag_template_field( + self, + location: str, + tag_template: str, + field: str, + force: bool, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes a field in a tag template and all uses of that field. + + :param location: Required. The location of the tag template to delete. + :type location: str + :param tag_template: Tag Template ID for tag template field that is deleted. + :type tag_template: str + :param field: Name of field that is deleted. + :type field: str + :param force: Required. This confirms the deletion of this field from any tags using this field. + :type force: bool + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}/fields/{field}" + + self.log.info("Deleting a tag template field: name=%s", name) + client.delete_tag_template_field( + request={"name": name, "force": force}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Deleted a tag template field: name=%s", name) + + @GoogleBaseHook.fallback_to_default_project_id + def get_entry( + self, + location: str, + entry_group: str, + entry: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Entry: + """ + Gets an entry. + + :param location: Required. The location of the entry to get. + :type location: str + :param entry_group: Required. The entry group of the entry to get. + :type entry_group: str + :param entry: The ID of the entry to get. + :type entry: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" + + self.log.info("Getting a entry: name=%s", name) + result = client.get_entry( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Received a entry: name=%s", result.name) + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_entry_group( + self, + location: str, + entry_group: str, + project_id: str, + read_mask: Union[Dict, FieldMask] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> EntryGroup: + """ + Gets an entry group. + + :param location: Required. The location of the entry group to get. + :type location: str + :param entry_group: The ID of the entry group to get. + :type entry_group: str + :param read_mask: The fields to return. If not set or empty, all fields are returned. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type read_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}" + + self.log.info("Getting a entry group: name=%s", name) + + result = client.get_entry_group( + request={"name": name, "read_mask": read_mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + self.log.info("Received a entry group: name=%s", result.name) + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_tag_template( + self, + location: str, + tag_template: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> TagTemplate: + """ + Gets a tag template. + + :param location: Required. The location of the tag template to get. + :type location: str + :param tag_template: Required. The ID of the tag template to get. + :type tag_template: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}" + + self.log.info("Getting a tag template: name=%s", name) + + result = client.get_tag_template( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + self.log.info("Received a tag template: name=%s", result.name) + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_tags( + self, + location: str, + entry_group: str, + entry: str, + project_id: str, + page_size: int = 100, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Lists the tags on an Entry. + + :param location: Required. The location of the tags to get. + :type location: str + :param entry_group: Required. The entry group of the tags to get. + :type entry_group: str + :param entry_group: Required. The entry of the tags to get. + :type entry: str + :param page_size: The maximum number of resources contained in the underlying API response. If page + streaming is performed per- resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number of resources in a page. + :type page_size: int + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" + + self.log.info("Listing tag on entry: entry_name=%s", parent) + + result = client.list_tags( + request={"parent": parent, "page_size": page_size}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + self.log.info("Received tags.") + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_tag_for_template_name( + self, + location: str, + entry_group: str, + entry: str, + template_name: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Tag: + """ + Gets for a tag with a specific template for a specific entry. + + :param location: Required. The location which contains the entry to search for. + :type location: str + :param entry_group: The entry group ID which contains the entry to search for. + :type entry_group: str + :param entry: The name of the entry to search for. + :type entry: str + :param template_name: The name of the template that will be the search criterion. + :type template_name: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. 
If ``None`` is specified, requests will be
+            retried using a default configuration.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        tags_list = self.list_tags(
+            location=location,
+            entry_group=entry_group,
+            entry=entry,
+            project_id=project_id,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        tag = next(t for t in tags_list if t.template == template_name)
+        return tag
+
+    def lookup_entry(
+        self,
+        linked_resource: Optional[str] = None,
+        sql_resource: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> Entry:
+        r"""
+        Get an entry by target resource name.
+
+        This method allows clients to use the resource name from the source Google Cloud service
+        to get the Data Catalog Entry.
+
+        :param linked_resource: The full name of the Google Cloud resource the Data Catalog entry
+            represents. See: https://cloud.google.com/apis/design/resource\_names#full\_resource\_name. Full
+            names are case-sensitive.
+
+        :type linked_resource: str
+        :param sql_resource: The SQL name of the entry. SQL names are case-sensitive.
+        :type sql_resource: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be
+            retried using a default configuration.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        client = self.get_conn()
+        if linked_resource and sql_resource:
+            raise AirflowException(
+                "Only one of linked_resource, sql_resource should be set."
+            )
+
+        if not linked_resource and not sql_resource:
+            raise AirflowException(
+                "At least one of linked_resource, sql_resource should be set."
+            )
+
+        if linked_resource:
+            self.log.info("Getting entry: linked_resource=%s", linked_resource)
+            result = client.lookup_entry(
+                request={"linked_resource": linked_resource},
+                retry=retry,
+                timeout=timeout,
+                metadata=metadata or (),
+            )
+        else:
+            self.log.info("Getting entry: sql_resource=%s", sql_resource)
+            result = client.lookup_entry(
+                request={"sql_resource": sql_resource},
+                retry=retry,
+                timeout=timeout,
+                metadata=metadata or (),
+            )
+        self.log.info("Received entry. name=%s", result.name)
+
+        return result
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def rename_tag_template_field(
+        self,
+        location: str,
+        tag_template: str,
+        field: str,
+        new_tag_template_field_id: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> TagTemplateField:
+        """
+        Renames a field in a tag template.
+
+        :param location: Required. The location of the tag template field to rename.
+        :type location: str
+        :param tag_template: The tag template ID for field that is renamed.
+        :type tag_template: str
+        :param field: Required. The old ID of this tag template field. For example,
+            ``my_old_field``.
+        :type field: str
+        :param new_tag_template_field_id: Required. The new ID of this tag template field.
For example, + ``my_new_field``. + :type new_tag_template_field_id: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}/fields/{field}" + + self.log.info( + "Renaming field: old_name=%s, new_tag_template_field_id=%s", + name, + new_tag_template_field_id, + ) + + result = client.rename_tag_template_field( + request={ + "name": name, + "new_tag_template_field_id": new_tag_template_field_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + self.log.info("Renamed tag template field.") + + return result + + def search_catalog( + self, + scope: Union[Dict, SearchCatalogRequest.Scope], + query: str, + page_size: int = 100, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + r""" + Searches Data Catalog for multiple resources like entries, tags that match a query. + + This does not return the complete resource, only the resource identifier and high level fields. + Clients can subsequently call ``Get`` methods. + + Note that searches do not have full recall. There may be results that match your query but are not + returned, even in subsequent pages of results. These missing results may vary across repeated calls to + search. Do not rely on this method if you need to guarantee full recall. + + :param scope: Required. The scope of this search request. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.Scope` + :type scope: Union[Dict, google.cloud.datacatalog_v1beta1.types.SearchCatalogRequest.Scope] + :param query: Required. The query string in search query syntax. The query must be non-empty. + + Query strings can be simple as "x" or more qualified as: + + - name:x + - column:x + - description:y + + Note: Query tokens need to have a minimum of 3 characters for substring matching to work + correctly. See `Data Catalog Search Syntax `__ for more information. + :type query: str + :param page_size: The maximum number of resources contained in the underlying API response. If page + streaming is performed per-resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number of resources in a page. + :type page_size: int + :param order_by: Specifies the ordering of results, currently supported case-sensitive choices are: + + - ``relevance``, only supports descending + - ``last_access_timestamp [asc|desc]``, defaults to descending if not specified + - ``last_modified_timestamp [asc|desc]``, defaults to descending if not specified + + If not specified, defaults to ``relevance`` descending. + :type order_by: str + :param retry: A retry object used to retry requests. 
If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + + self.log.info( + "Searching catalog: scope=%s, query=%s, page_size=%s, order_by=%s", + scope, + query, + page_size, + order_by, + ) + result = client.search_catalog( + request={ + "scope": scope, + "query": query, + "page_size": page_size, + "order_by": order_by, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + self.log.info("Received items.") + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def update_entry( + self, + entry: Union[Dict, Entry], + update_mask: Union[dict, FieldMask], + project_id: str, + location: Optional[str] = None, + entry_group: Optional[str] = None, + entry_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Entry: + """ + Updates an existing entry. + + :param entry: Required. The updated entry. The "name" field must be set. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.Entry` + :type entry: Union[Dict, google.cloud.datacatalog_v1beta1.types.Entry] + :param update_mask: The fields to update on the entry. If absent or empty, all modifiable fields are + updated. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param location: Required. The location of the entry to update. + :type location: str + :param entry_group: The entry group ID for the entry that is being updated. + :type entry_group: str + :param entry_id: The entry ID that is being updated. + :type entry_id: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + if project_id and location and entry_group and entry_id: + full_entry_name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry_id}" + if isinstance(entry, Entry): + entry.name = full_entry_name + elif isinstance(entry, dict): + entry["name"] = full_entry_name + else: + raise AirflowException("Unable to set entry's name.") + elif location and entry_group and entry_id: + raise AirflowException( + "You must provide all the parameters (project_id, location, entry_group, entry_id) " + "contained in the name, or do not specify any parameters and pass the name on the object " + ) + name = entry.name if isinstance(entry, Entry) else entry["name"] + self.log.info("Updating entry: name=%s", name) + + # HACK: google-cloud-datacatalog has a problem with dictionaries for update methods. + if isinstance(entry, dict): + entry = Entry(**entry) + result = client.update_entry( + request={"entry": entry, "update_mask": update_mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + self.log.info("Updated entry.") + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def update_tag( # pylint: disable=too-many-arguments + self, + tag: Union[Dict, Tag], + update_mask: Union[Dict, FieldMask], + project_id: str, + location: Optional[str] = None, + entry_group: Optional[str] = None, + entry: Optional[str] = None, + tag_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Tag: + """ + Updates an existing tag. + + :param tag: Required. The updated tag. The "name" field must be set. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.Tag` + :type tag: Union[Dict, google.cloud.datacatalog_v1beta1.types.Tag] + :param update_mask: The fields to update on the Tag. If absent or empty, all modifiable fields are + updated. Currently the only modifiable field is the field ``fields``. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param location: Required. The location of the tag to rename. + :type location: str + :param entry_group: The entry group ID for the tag that is being updated. + :type entry_group: str + :param entry: The entry ID for the tag that is being updated. + :type entry: str + :param tag_id: The tag ID that is being updated. + :type tag_id: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + if project_id and location and entry_group and entry and tag_id: + full_tag_name = ( + f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" + f"/tags/{tag_id}" + ) + if isinstance(tag, Tag): + tag.name = full_tag_name + elif isinstance(tag, dict): + tag["name"] = full_tag_name + else: + raise AirflowException("Unable to set tag's name.") + elif location and entry_group and entry and tag_id: + raise AirflowException( + "You must provide all the parameters (project_id, location, entry_group, entry, tag_id) " + "contained in the name, or do not specify any parameters and pass the name on the object " + ) + + name = tag.name if isinstance(tag, Tag) else tag["name"] + self.log.info("Updating tag: name=%s", name) + + # HACK: google-cloud-datacatalog has a problem with dictionaries for update methods. + if isinstance(tag, dict): + tag = Tag(**tag) + result = client.update_tag( + request={"tag": tag, "update_mask": update_mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Updated tag.") + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def update_tag_template( + self, + tag_template: Union[dict, TagTemplate], + update_mask: Union[dict, FieldMask], + project_id: str, + location: Optional[str] = None, + tag_template_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> TagTemplate: + """ + Updates a tag template. + + This method cannot be used to update the fields of a template. The tag + template fields are represented as separate resources and should be updated using their own + create/update/delete methods. + + :param tag_template: Required. The template to update. The "name" field must be set. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.TagTemplate` + :type tag_template: Union[Dict, google.cloud.datacatalog_v1beta1.types.TagTemplate] + :param update_mask: The field mask specifies the parts of the template to overwrite. + + If absent or empty, all of the allowed fields above will be updated. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param location: Required. The location of the tag template to rename. + :type location: str + :param tag_template_id: Optional. The tag template ID for the entry that is being updated. + :type tag_template_id: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + if project_id and location and tag_template: + full_tag_template_name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template_id}" + if isinstance(tag_template, TagTemplate): + tag_template.name = full_tag_template_name + elif isinstance(tag_template, dict): + tag_template["name"] = full_tag_template_name + else: + raise AirflowException("Unable to set name of tag template.") + elif location and tag_template: + raise AirflowException( + "You must provide all the parameters (project_id, location, tag_template_id) " + "contained in the name, or do not specify any parameters and pass the name on the object " + ) + + name = ( + tag_template.name + if isinstance(tag_template, TagTemplate) + else tag_template["name"] + ) + self.log.info("Updating tag template: name=%s", name) + + # HACK: google-cloud-datacatalog has a problem with dictionaries for update methods. + if isinstance(tag_template, dict): + tag_template = TagTemplate(**tag_template) + result = client.update_tag_template( + request={"tag_template": tag_template, "update_mask": update_mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + self.log.info("Updated tag template.") + + return result + + @GoogleBaseHook.fallback_to_default_project_id + def update_tag_template_field( # pylint: disable=too-many-arguments + self, + tag_template_field: Union[dict, TagTemplateField], + update_mask: Union[dict, FieldMask], + project_id: str, + tag_template_field_name: Optional[str] = None, + location: Optional[str] = None, + tag_template: Optional[str] = None, + tag_template_field_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Updates a field in a tag template. This method cannot be used to update the field type. + + :param tag_template_field: Required. The template to update. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.TagTemplateField` + :type tag_template_field: Union[Dict, google.cloud.datacatalog_v1beta1.types.TagTemplateField] + :param update_mask: The field mask specifies the parts of the template to be updated. Allowed fields: + + - ``display_name`` + - ``type.enum_type`` + + If ``update_mask`` is not set or empty, all of the allowed fields above will be updated. + + When updating an enum type, the provided values will be merged with the existing values. + Therefore, enum values can only be added, existing enum values cannot be deleted nor renamed. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param tag_template_field_name: Optional. The name of the tag template field to rename. + :type tag_template_field_name: str + :param location: Optional. The location of the tag to rename. + :type location: str + :param tag_template: Optional. The tag template ID for tag template field to rename. + :type tag_template: str + :param tag_template_field_id: Optional. The ID of tag template field to rename. + :type tag_template_field_id: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. 
+        :type project_id: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be
+            retried using a default configuration.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        client = self.get_conn()
+        if project_id and location and tag_template and tag_template_field_id:
+            tag_template_field_name = (
+                f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}"
+                f"/fields/{tag_template_field_id}"
+            )
+
+        self.log.info("Updating tag template field: name=%s", tag_template_field_name)
+
+        result = client.update_tag_template_field(
+            request={
+                "name": tag_template_field_name,
+                "tag_template_field": tag_template_field,
+                "update_mask": update_mask,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
+        self.log.info("Updated tag template field.")
+
+        return result
diff --git a/reference/providers/google/cloud/hooks/dataflow.py b/reference/providers/google/cloud/hooks/dataflow.py
new file mode 100644
index 0000000..8aa0635
--- /dev/null
+++ b/reference/providers/google/cloud/hooks/dataflow.py
@@ -0,0 +1,1247 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Google Dataflow Hook."""
+import functools
+import json
+import re
+import shlex
+import subprocess
+import time
+import uuid
+import warnings
+from copy import deepcopy
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    TypeVar,
+    Union,
+    cast,
+)
+
+from airflow.exceptions import AirflowException
+from airflow.providers.apache.beam.hooks.beam import (
+    BeamHook,
+    BeamRunnerType,
+    beam_options_to_args,
+)
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.timeout import timeout
+from googleapiclient.discovery import build
+
+# This is the default location
+# https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+DEFAULT_DATAFLOW_LOCATION = "us-central1"
+
+
+JOB_ID_PATTERN = re.compile(
+    r"Submitted job: (?P<job_id_java>.*)|Created job with id: \[(?P<job_id_python>.*)\]"
+)
+
+T = TypeVar("T", bound=Callable)  # pylint: disable=invalid-name
+
+
+def process_line_and_extract_dataflow_job_id_callback(
+    on_new_job_id_callback: Optional[Callable[[str], None]]
+) -> Callable[[str], None]:
+    """
+    Returns callback which triggers function passed as `on_new_job_id_callback` when Dataflow job_id is found.
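+
+    A minimal usage sketch (editor's illustration; the callback name and the log line
+    below are placeholders, not part of the upstream docstring)::
+
+        def remember_job_id(job_id: str) -> None:
+            print("Dataflow job id:", job_id)
+
+        callback = process_line_and_extract_dataflow_job_id_callback(remember_job_id)
+        # Every log line of the Beam process is fed to the callback; lines matching
+        # JOB_ID_PATTERN trigger ``remember_job_id`` with the extracted job id.
+        callback("Submitted job: 2021-01-01_00_00_00-1234567890")
+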
+ To be used for `process_line_callback` in + :py:class:`~airflow.providers.apache.beam.hooks.beam.BeamCommandRunner` + + :param on_new_job_id_callback: Callback called when the job ID is known + :type on_new_job_id_callback: callback + """ + + def _process_line_and_extract_job_id( + line: str, + # on_new_job_id_callback: Optional[Callable[[str], None]] + ) -> None: + # Job id info: https://goo.gl/SE29y9. + matched_job = JOB_ID_PATTERN.search(line) + if matched_job: + job_id = matched_job.group("job_id_java") or matched_job.group( + "job_id_python" + ) + if on_new_job_id_callback: + on_new_job_id_callback(job_id) + + def wrap(line: str): + return _process_line_and_extract_job_id(line) + + return wrap + + +def _fallback_variable_parameter( + parameter_name: str, variable_key_name: str +) -> Callable[[T], T]: + def _wrapper(func: T) -> T: + """ + Decorator that provides fallback for location from `region` key in `variables` parameters. + + :param func: function to wrap + :return: result of the function call + """ + + @functools.wraps(func) + def inner_wrapper(self: "DataflowHook", *args, **kwargs): + if args: + raise AirflowException( + "You must use keyword arguments in this methods rather than positional" + ) + + parameter_location = kwargs.get(parameter_name) + variables_location = kwargs.get("variables", {}).get(variable_key_name) + + if parameter_location and variables_location: + raise AirflowException( + f"The mutually exclusive parameter `{parameter_name}` and `{variable_key_name}` key " + f"in `variables` parameter are both present. Please remove one." + ) + if parameter_location or variables_location: + kwargs[parameter_name] = parameter_location or variables_location + if variables_location: + copy_variables = deepcopy(kwargs["variables"]) + del copy_variables[variable_key_name] + kwargs["variables"] = copy_variables + + return func(self, *args, **kwargs) + + return cast(T, inner_wrapper) + + return _wrapper + + +_fallback_to_location_from_variables = _fallback_variable_parameter( + "location", "region" +) +_fallback_to_project_id_from_variables = _fallback_variable_parameter( + "project_id", "project" +) + + +class DataflowJobStatus: + """ + Helper class with Dataflow job statuses. + Reference: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs#Job.JobState + """ + + JOB_STATE_DONE = "JOB_STATE_DONE" + JOB_STATE_UNKNOWN = "JOB_STATE_UNKNOWN" + JOB_STATE_STOPPED = "JOB_STATE_STOPPED" + JOB_STATE_RUNNING = "JOB_STATE_RUNNING" + JOB_STATE_FAILED = "JOB_STATE_FAILED" + JOB_STATE_CANCELLED = "JOB_STATE_CANCELLED" + JOB_STATE_UPDATED = "JOB_STATE_UPDATED" + JOB_STATE_DRAINING = "JOB_STATE_DRAINING" + JOB_STATE_DRAINED = "JOB_STATE_DRAINED" + JOB_STATE_PENDING = "JOB_STATE_PENDING" + JOB_STATE_CANCELLING = "JOB_STATE_CANCELLING" + JOB_STATE_QUEUED = "JOB_STATE_QUEUED" + FAILED_END_STATES = {JOB_STATE_FAILED, JOB_STATE_CANCELLED} + SUCCEEDED_END_STATES = {JOB_STATE_DONE, JOB_STATE_UPDATED, JOB_STATE_DRAINED} + TERMINAL_STATES = SUCCEEDED_END_STATES | FAILED_END_STATES + AWAITING_STATES = { + JOB_STATE_RUNNING, + JOB_STATE_PENDING, + JOB_STATE_QUEUED, + JOB_STATE_CANCELLING, + JOB_STATE_DRAINING, + JOB_STATE_STOPPED, + } + + +class DataflowJobType: + """Helper class with Dataflow job types.""" + + JOB_TYPE_UNKNOWN = "JOB_TYPE_UNKNOWN" + JOB_TYPE_BATCH = "JOB_TYPE_BATCH" + JOB_TYPE_STREAMING = "JOB_TYPE_STREAMING" + + +class _DataflowJobsController(LoggingMixin): + """ + Interface for communication with Google API. 
+ + It's not use Apache Beam, but only Google Dataflow API. + + :param dataflow: Discovery resource + :param project_number: The Google Cloud Project ID. + :param location: Job location. + :param poll_sleep: The status refresh rate for pending operations. + :param name: The Job ID prefix used when the multiple_jobs option is passed is set to True. + :param job_id: ID of a single job. + :param num_retries: Maximum number of retries in case of connection problems. + :param multiple_jobs: If set to true this task will be searched by name prefix (``name`` parameter), + not by specific job ID, then actions will be performed on all matching jobs. + :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it + instead of canceling. + :param cancel_timeout: wait time in seconds for successful job canceling + :param wait_until_finished: If True, wait for the end of pipeline execution before exiting. If False, + it only submits job and check once is job not in terminal state. + + The default behavior depends on the type of pipeline: + + * for the streaming pipeline, wait for jobs to start, + * for the batch pipeline, wait for the jobs to complete. + """ + + def __init__( # pylint: disable=too-many-arguments + self, + dataflow: Any, + project_number: str, + location: str, + poll_sleep: int = 10, + name: Optional[str] = None, + job_id: Optional[str] = None, + num_retries: int = 0, + multiple_jobs: bool = False, + drain_pipeline: bool = False, + cancel_timeout: Optional[int] = 5 * 60, + wait_until_finished: Optional[bool] = None, + ) -> None: + + super().__init__() + self._dataflow = dataflow + self._project_number = project_number + self._job_name = name + self._job_location = location + self._multiple_jobs = multiple_jobs + self._job_id = job_id + self._num_retries = num_retries + self._poll_sleep = poll_sleep + self._cancel_timeout = cancel_timeout + self._jobs: Optional[List[dict]] = None + self.drain_pipeline = drain_pipeline + self._wait_until_finished = wait_until_finished + self._jobs: Optional[List[dict]] = None + + def is_job_running(self) -> bool: + """ + Helper method to check if jos is still running in dataflow + + :return: True if job is running. + :rtype: bool + """ + self._refresh_jobs() + if not self._jobs: + return False + + for job in self._jobs: + if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES: + return True + return False + + # pylint: disable=too-many-nested-blocks + def _get_current_jobs(self) -> List[dict]: + """ + Helper method to get list of jobs that start with job name or id + + :return: list of jobs including id's + :rtype: list + """ + if not self._multiple_jobs and self._job_id: + return [self.fetch_job_by_id(self._job_id)] + elif self._job_name: + jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower()) + if len(jobs) == 1: + self._job_id = jobs[0]["id"] + return jobs + else: + raise Exception("Missing both dataflow job ID and name.") + + def fetch_job_by_id(self, job_id: str) -> dict: + """ + Helper method to fetch the job with the specified Job ID. + + :param job_id: Job ID to get. + :type job_id: str + :return: the Job + :rtype: dict + """ + return ( + self._dataflow.projects() + .locations() + .jobs() + .get( + projectId=self._project_number, + location=self._job_location, + jobId=job_id, + ) + .execute(num_retries=self._num_retries) + ) + + def fetch_job_metrics_by_id(self, job_id: str) -> dict: + """ + Helper method to fetch the job metrics with the specified Job ID. + + :param job_id: Job ID to get. 
+ :type job_id: str + :return: the JobMetrics. See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/JobMetrics + :rtype: dict + """ + result = ( + self._dataflow.projects() + .locations() + .jobs() + .getMetrics( + projectId=self._project_number, + location=self._job_location, + jobId=job_id, + ) + .execute(num_retries=self._num_retries) + ) + + self.log.debug("fetch_job_metrics_by_id %s:\n%s", job_id, result) + return result + + def _fetch_list_job_messages_responses( + self, job_id: str + ) -> Generator[dict, None, None]: + """ + Helper method to fetch ListJobMessagesResponse with the specified Job ID. + + :param job_id: Job ID to get. + :type job_id: str + :return: yields the ListJobMessagesResponse. See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/ListJobMessagesResponse + :rtype: Generator[dict, None, None] + """ + request = ( + self._dataflow.projects() + .locations() + .jobs() + .messages() + .list( + projectId=self._project_number, + location=self._job_location, + jobId=job_id, + ) + ) + + while request is not None: + response = request.execute(num_retries=self._num_retries) + yield response + + request = ( + self._dataflow.projects() + .locations() + .jobs() + .messages() + .list_next(previous_request=request, previous_response=response) + ) + + def fetch_job_messages_by_id(self, job_id: str) -> List[dict]: + """ + Helper method to fetch the job messages with the specified Job ID. + + :param job_id: Job ID to get. + :type job_id: str + :return: the list of JobMessages. See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/ListJobMessagesResponse#JobMessage + :rtype: List[dict] + """ + messages: List[dict] = [] + for response in self._fetch_list_job_messages_responses(job_id=job_id): + messages.extend(response.get("jobMessages", [])) + return messages + + def fetch_job_autoscaling_events_by_id(self, job_id: str) -> List[dict]: + """ + Helper method to fetch the job autoscaling events with the specified Job ID. + + :param job_id: Job ID to get. + :type job_id: str + :return: the list of AutoscalingEvents. 
See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/ListJobMessagesResponse#autoscalingevent + :rtype: List[dict] + """ + autoscaling_events: List[dict] = [] + for response in self._fetch_list_job_messages_responses(job_id=job_id): + autoscaling_events.extend(response.get("autoscalingEvents", [])) + return autoscaling_events + + def _fetch_all_jobs(self) -> List[dict]: + request = ( + self._dataflow.projects() + .locations() + .jobs() + .list(projectId=self._project_number, location=self._job_location) + ) + jobs: List[dict] = [] + while request is not None: + response = request.execute(num_retries=self._num_retries) + jobs.extend(response["jobs"]) + + request = ( + self._dataflow.projects() + .locations() + .jobs() + .list_next(previous_request=request, previous_response=response) + ) + return jobs + + def _fetch_jobs_by_prefix_name(self, prefix_name: str) -> List[dict]: + jobs = self._fetch_all_jobs() + jobs = [job for job in jobs if job["name"].startswith(prefix_name)] + return jobs + + def _refresh_jobs(self) -> None: + """ + Helper method to get all jobs by name + + :return: jobs + :rtype: list + """ + self._jobs = self._get_current_jobs() + + if self._jobs: + for job in self._jobs: + self.log.info( + "Google Cloud DataFlow job %s is state: %s", + job["name"], + job["currentState"], + ) + else: + self.log.info("Google Cloud DataFlow job not available yet..") + + def _check_dataflow_job_state(self, job) -> bool: + """ + Helper method to check the state of one job in dataflow for this task + if job failed raise exception + + :return: True if job is done. + :rtype: bool + :raise: Exception + """ + if self._wait_until_finished is None: + wait_for_running = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING + else: + wait_for_running = not self._wait_until_finished + + if job["currentState"] == DataflowJobStatus.JOB_STATE_DONE: + return True + elif job["currentState"] == DataflowJobStatus.JOB_STATE_FAILED: + raise Exception(f"Google Cloud Dataflow job {job['name']} has failed.") + elif job["currentState"] == DataflowJobStatus.JOB_STATE_CANCELLED: + raise Exception(f"Google Cloud Dataflow job {job['name']} was cancelled.") + elif job["currentState"] == DataflowJobStatus.JOB_STATE_DRAINED: + raise Exception(f"Google Cloud Dataflow job {job['name']} was drained.") + elif job["currentState"] == DataflowJobStatus.JOB_STATE_UPDATED: + raise Exception(f"Google Cloud Dataflow job {job['name']} was updated.") + elif ( + job["currentState"] == DataflowJobStatus.JOB_STATE_RUNNING + and wait_for_running + ): + return True + elif job["currentState"] in DataflowJobStatus.AWAITING_STATES: + return self._wait_until_finished is False + self.log.debug("Current job: %s", str(job)) + raise Exception( + f"Google Cloud Dataflow job {job['name']} was unknown state: {job['currentState']}" + ) + + def wait_for_done(self) -> None: + """Helper method to wait for result of submitted job.""" + self.log.info("Start waiting for done.") + self._refresh_jobs() + while self._jobs and not all( + self._check_dataflow_job_state(job) for job in self._jobs + ): + self.log.info("Waiting for done. Sleep %s s", self._poll_sleep) + time.sleep(self._poll_sleep) + self._refresh_jobs() + + def get_jobs(self, refresh: bool = False) -> List[dict]: + """ + Returns Dataflow jobs. + + :param refresh: Forces the latest data to be fetched. 
+ :type refresh: bool + :return: list of jobs + :rtype: list + """ + if not self._jobs or refresh: + self._refresh_jobs() + if not self._jobs: + raise ValueError("Could not read _jobs") + + return self._jobs + + def _wait_for_states(self, expected_states: Set[str]): + """Waiting for the jobs to reach a certain state.""" + if not self._jobs: + raise ValueError("The _jobs should be set") + while True: + self._refresh_jobs() + job_states = {job["currentState"] for job in self._jobs} + if not job_states.difference(expected_states): + return + unexpected_failed_end_states = ( + expected_states - DataflowJobStatus.FAILED_END_STATES + ) + if unexpected_failed_end_states.intersection(job_states): + unexpected_failed_jobs = { + job + for job in self._jobs + if job["currentState"] in unexpected_failed_end_states + } + raise AirflowException( + "Jobs failed: " + + ", ".join( + f"ID: {job['id']} name: {job['name']} state: {job['currentState']}" + for job in unexpected_failed_jobs + ) + ) + time.sleep(self._poll_sleep) + + def cancel(self) -> None: + """Cancels or drains current job""" + jobs = self.get_jobs() + job_ids = [ + job["id"] + for job in jobs + if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES + ] + if job_ids: + batch = self._dataflow.new_batch_http_request() + self.log.info("Canceling jobs: %s", ", ".join(job_ids)) + for job in jobs: + requested_state = ( + DataflowJobStatus.JOB_STATE_DRAINED + if self.drain_pipeline + and job["type"] == DataflowJobType.JOB_TYPE_STREAMING + else DataflowJobStatus.JOB_STATE_CANCELLED + ) + batch.add( + self._dataflow.projects() + .locations() + .jobs() + .update( + projectId=self._project_number, + location=self._job_location, + jobId=job["id"], + body={"requestedState": requested_state}, + ) + ) + batch.execute() + if self._cancel_timeout and isinstance(self._cancel_timeout, int): + timeout_error_message = ( + "Canceling jobs failed due to timeout ({}s): {}".format( + self._cancel_timeout, ", ".join(job_ids) + ) + ) + with timeout( + seconds=self._cancel_timeout, error_message=timeout_error_message + ): + self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED}) + else: + self.log.info("No jobs to cancel") + + +class DataflowHook(GoogleBaseHook): + """ + Hook for Google Dataflow. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. 
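+
+    A usage sketch (editor's illustration; the connection id, bucket, template path and
+    parameter values are placeholders, and ``project_id`` falls back to the default from
+    the Google Cloud connection)::
+
+        hook = DataflowHook(gcp_conn_id="google_cloud_default", poll_sleep=10)
+        job = hook.start_template_dataflow(
+            job_name="example-wordcount",
+            variables={"tempLocation": "gs://example-bucket/tmp"},
+            parameters={"inputFile": "gs://example-bucket/input.txt",
+                        "output": "gs://example-bucket/output"},
+            dataflow_template="gs://dataflow-templates/latest/Word_Count",
+            location="us-central1",
+        )
+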
+ """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + poll_sleep: int = 10, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + drain_pipeline: bool = False, + cancel_timeout: Optional[int] = 5 * 60, + wait_until_finished: Optional[bool] = None, + ) -> None: + self.poll_sleep = poll_sleep + self.drain_pipeline = drain_pipeline + self.cancel_timeout = cancel_timeout + self.wait_until_finished = wait_until_finished + self.job_id: Optional[str] = None + self.beam_hook = BeamHook(BeamRunnerType.DataflowRunner) + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + + def get_conn(self) -> build: + """Returns a Google Cloud Dataflow service object.""" + http_authorized = self._authorize() + return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False) + + @_fallback_to_location_from_variables + @_fallback_to_project_id_from_variables + @GoogleBaseHook.fallback_to_default_project_id + def start_java_dataflow( + self, + job_name: str, + variables: dict, + jar: str, + project_id: str, + job_class: Optional[str] = None, + append_job_name: bool = True, + multiple_jobs: bool = False, + on_new_job_id_callback: Optional[Callable[[str], None]] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + ) -> None: + """ + Starts Dataflow java job. + + :param job_name: The name of the job. + :type job_name: str + :param variables: Variables passed to the job. + :type variables: dict + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param jar: Name of the jar for the job + :type job_class: str + :param job_class: Name of the java class for the job. + :type job_class: str + :param append_job_name: True if unique suffix has to be appended to job name. + :type append_job_name: bool + :param multiple_jobs: True if to check for multiple job in dataflow + :type multiple_jobs: bool + :param on_new_job_id_callback: Callback called when the job ID is known. + :type on_new_job_id_callback: callable + :param location: Job location. + :type location: str + """ + warnings.warn( + """"This method is deprecated. + Please use `airflow.providers.apache.beam.hooks.beam.start.start_java_pipeline` + to start pipeline and `providers.google.cloud.hooks.dataflow.DataflowHook.wait_for_done` + to wait for the required pipeline state. 
+ """, + DeprecationWarning, + stacklevel=3, + ) + + name = self.build_dataflow_job_name(job_name, append_job_name) + + variables["jobName"] = name + variables["region"] = location + variables["project"] = project_id + + if "labels" in variables: + variables["labels"] = json.dumps(variables["labels"], separators=(",", ":")) + + self.beam_hook.start_java_pipeline( + variables=variables, + jar=jar, + job_class=job_class, + process_line_callback=process_line_and_extract_dataflow_job_id_callback( + on_new_job_id_callback + ), + ) + self.wait_for_done( # pylint: disable=no-value-for-parameter + job_name=name, + location=location, + job_id=self.job_id, + multiple_jobs=multiple_jobs, + ) + + @_fallback_to_location_from_variables + @_fallback_to_project_id_from_variables + @GoogleBaseHook.fallback_to_default_project_id + def start_template_dataflow( + self, + job_name: str, + variables: dict, + parameters: dict, + dataflow_template: str, + project_id: str, + append_job_name: bool = True, + on_new_job_id_callback: Optional[Callable[[str], None]] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + environment: Optional[dict] = None, + ) -> dict: + """ + Starts Dataflow template job. + + :param job_name: The name of the job. + :type job_name: str + :param variables: Map of job runtime environment options. + It will update environment argument if passed. + + .. seealso:: + For more information on possible configurations, look at the API documentation + `https://cloud.google.com/dataflow/pipelines/specifying-exec-params + `__ + + :type variables: dict + :param parameters: Parameters fot the template + :type parameters: dict + :param dataflow_template: GCS path to the template. + :type dataflow_template: str + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param append_job_name: True if unique suffix has to be appended to job name. + :type append_job_name: bool + :param on_new_job_id_callback: Callback called when the job ID is known. + :type on_new_job_id_callback: callable + :param location: Job location. + :type location: str + :type environment: Optional, Map of job runtime environment options. + + .. 
seealso:: + For more information on possible configurations, look at the API documentation + `https://cloud.google.com/dataflow/pipelines/specifying-exec-params + `__ + + :type environment: Optional[dict] + """ + name = self.build_dataflow_job_name(job_name, append_job_name) + + environment = environment or {} + # available keys for runtime environment are listed here: + # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment + environment_keys = [ + "numWorkers", + "maxWorkers", + "zone", + "serviceAccountEmail", + "tempLocation", + "bypassTempDirValidation", + "machineType", + "additionalExperiments", + "network", + "subnetwork", + "additionalUserLabels", + "kmsKeyName", + "ipConfiguration", + "workerRegion", + "workerZone", + ] + + for key in variables: + if key in environment_keys: + if key in environment: + self.log.warning( + "'%s' parameter in 'variables' will override of " + "the same one passed in 'environment'!", + key, + ) + environment.update({key: variables[key]}) + + service = self.get_conn() + # pylint: disable=no-member + request = ( + service.projects() + .locations() + .templates() + .launch( + projectId=project_id, + location=location, + gcsPath=dataflow_template, + body={ + "jobName": name, + "parameters": parameters, + "environment": environment, + }, + ) + ) + response = request.execute(num_retries=self.num_retries) + + job_id = response["job"]["id"] + if on_new_job_id_callback: + on_new_job_id_callback(job_id) + + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + name=name, + job_id=job_id, + location=location, + poll_sleep=self.poll_sleep, + num_retries=self.num_retries, + drain_pipeline=self.drain_pipeline, + cancel_timeout=self.cancel_timeout, + ) + jobs_controller.wait_for_done() + return response["job"] + + @GoogleBaseHook.fallback_to_default_project_id + def start_flex_template( + self, + body: dict, + location: str, + project_id: str, + on_new_job_id_callback: Optional[Callable[[str], None]] = None, + ): + """ + Starts flex templates with the Dataflow pipeline. + + :param body: The request body. See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body + :param location: The location of the Dataflow job (for example europe-west1) + :type location: str + :param project_id: The ID of the GCP project that owns the job. + If set to ``None`` or missing, the default project_id from the GCP connection is used. + :type project_id: Optional[str] + :param on_new_job_id_callback: A callback that is called when a Job ID is detected. 
+ :return: the Job + """ + service = self.get_conn() + request = ( + service.projects() # pylint: disable=no-member + .locations() + .flexTemplates() + .launch(projectId=project_id, body=body, location=location) + ) + response = request.execute(num_retries=self.num_retries) + job_id = response["job"]["id"] + + if on_new_job_id_callback: + on_new_job_id_callback(job_id) + + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + job_id=job_id, + location=location, + poll_sleep=self.poll_sleep, + num_retries=self.num_retries, + cancel_timeout=self.cancel_timeout, + ) + jobs_controller.wait_for_done() + + return jobs_controller.get_jobs(refresh=True)[0] + + @_fallback_to_location_from_variables + @_fallback_to_project_id_from_variables + @GoogleBaseHook.fallback_to_default_project_id + def start_python_dataflow( # pylint: disable=too-many-arguments + self, + job_name: str, + variables: dict, + dataflow: str, + py_options: List[str], + project_id: str, + py_interpreter: str = "python3", + py_requirements: Optional[List[str]] = None, + py_system_site_packages: bool = False, + append_job_name: bool = True, + on_new_job_id_callback: Optional[Callable[[str], None]] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + ): + """ + Starts Dataflow job. + + :param job_name: The name of the job. + :type job_name: str + :param variables: Variables passed to the job. + :type variables: Dict + :param dataflow: Name of the Dataflow process. + :type dataflow: str + :param py_options: Additional options. + :type py_options: List[str] + :param project_id: The ID of the GCP project that owns the job. + If set to ``None`` or missing, the default project_id from the GCP connection is used. + :type project_id: Optional[str] + :param py_interpreter: Python version of the beam pipeline. + If None, this defaults to the python3. + To track python versions supported by beam and related + issues check: https://issues.apache.org/jira/browse/BEAM-1251 + :param py_requirements: Additional python package(s) to install. + If a value is passed to this parameter, a new virtual environment has been created with + additional packages installed. + + You could also install the apache-beam package if it is not installed on your system or you want + to use a different version. + :type py_requirements: List[str] + :param py_system_site_packages: Whether to include system_site_packages in your virtualenv. + See virtualenv documentation for more information. + + This option is only relevant if the ``py_requirements`` parameter is not None. + :type py_interpreter: str + :param append_job_name: True if unique suffix has to be appended to job name. + :type append_job_name: bool + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param on_new_job_id_callback: Callback called when the job ID is known. + :type on_new_job_id_callback: callable + :param location: Job location. + :type location: str + """ + warnings.warn( + """This method is deprecated. + Please use `airflow.providers.apache.beam.hooks.beam.start.start_python_pipeline` + to start pipeline and `providers.google.cloud.hooks.dataflow.DataflowHook.wait_for_done` + to wait for the required pipeline state. 
+ """, + DeprecationWarning, + stacklevel=3, + ) + + name = self.build_dataflow_job_name(job_name, append_job_name) + variables["job_name"] = name + variables["region"] = location + variables["project"] = project_id + + self.beam_hook.start_python_pipeline( + variables=variables, + py_file=dataflow, + py_options=py_options, + py_interpreter=py_interpreter, + py_requirements=py_requirements, + py_system_site_packages=py_system_site_packages, + process_line_callback=process_line_and_extract_dataflow_job_id_callback( + on_new_job_id_callback + ), + ) + + self.wait_for_done( # pylint: disable=no-value-for-parameter + job_name=name, + location=location, + job_id=self.job_id, + ) + + @staticmethod + def build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str: + """Builds Dataflow job name.""" + base_job_name = str(job_name).replace("_", "-") + + if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", base_job_name): + raise ValueError( + "Invalid job_name ({}); the name must consist of" + "only the characters [-a-z0-9], starting with a " + "letter and ending with a letter or number ".format(base_job_name) + ) + + if append_job_name: + safe_job_name = base_job_name + "-" + str(uuid.uuid4())[:8] + else: + safe_job_name = base_job_name + + return safe_job_name + + @_fallback_to_location_from_variables + @_fallback_to_project_id_from_variables + @GoogleBaseHook.fallback_to_default_project_id + def is_job_dataflow_running( + self, + name: str, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + variables: Optional[dict] = None, + ) -> bool: + """ + Helper method to check if jos is still running in dataflow + + :param name: The name of the job. + :type name: str + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param location: Job location. + :type location: str + :return: True if job is running. + :rtype: bool + """ + if variables: + warnings.warn( + "The variables parameter has been deprecated. You should pass location using " + "the location parameter.", + DeprecationWarning, + stacklevel=4, + ) + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + name=name, + location=location, + poll_sleep=self.poll_sleep, + drain_pipeline=self.drain_pipeline, + num_retries=self.num_retries, + cancel_timeout=self.cancel_timeout, + ) + return jobs_controller.is_job_running() + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_job( + self, + project_id: str, + job_name: Optional[str] = None, + job_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + ) -> None: + """ + Cancels the job with the specified name prefix or Job ID. + + Parameter ``name`` and ``job_id`` are mutually exclusive. + + :param job_name: Name prefix specifying which jobs are to be canceled. + :type job_name: str + :param job_id: Job ID specifying which jobs are to be canceled. + :type job_id: str + :param location: Job location. + :type location: str + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. 
+ :type project_id: + """ + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + name=job_name, + job_id=job_id, + location=location, + poll_sleep=self.poll_sleep, + drain_pipeline=self.drain_pipeline, + num_retries=self.num_retries, + cancel_timeout=self.cancel_timeout, + ) + jobs_controller.cancel() + + @GoogleBaseHook.fallback_to_default_project_id + def start_sql_job( + self, + job_name: str, + query: str, + options: Dict[str, Any], + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + on_new_job_id_callback: Optional[Callable[[str], None]] = None, + ): + """ + Starts Dataflow SQL query. + + :param job_name: The unique name to assign to the Cloud Dataflow job. + :type job_name: str + :param query: The SQL query to execute. + :type query: str + :param options: Job parameters to be executed. + For more information, look at: + `https://cloud.google.com/sdk/gcloud/reference/beta/dataflow/sql/query + `__ + command reference + :param location: The location of the Dataflow job (for example europe-west1) + :type location: str + :param project_id: The ID of the GCP project that owns the job. + If set to ``None`` or missing, the default project_id from the GCP connection is used. + :type project_id: Optional[str] + :param on_new_job_id_callback: Callback called when the job ID is known. + :type on_new_job_id_callback: callable + :return: the new job object + """ + cmd = [ + "gcloud", + "dataflow", + "sql", + "query", + query, + f"--project={project_id}", + "--format=value(job.id)", + f"--job-name={job_name}", + f"--region={location}", + *(beam_options_to_args(options)), + ] + self.log.info("Executing command: %s", " ".join([shlex.quote(c) for c in cmd])) + with self.provide_authorized_gcloud(): + proc = subprocess.run( # pylint: disable=subprocess-run-check + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + self.log.info("Output: %s", proc.stdout.decode()) + self.log.warning("Stderr: %s", proc.stderr.decode()) + self.log.info("Exit code %d", proc.returncode) + if proc.returncode != 0: + raise AirflowException( + f"Process exit with non-zero exit code. Exit code: {proc.returncode}" + ) + job_id = proc.stdout.decode().strip() + + self.log.info("Created job ID: %s", job_id) + if on_new_job_id_callback: + on_new_job_id_callback(job_id) + + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + job_id=job_id, + location=location, + poll_sleep=self.poll_sleep, + num_retries=self.num_retries, + drain_pipeline=self.drain_pipeline, + ) + jobs_controller.wait_for_done() + + return jobs_controller.get_jobs(refresh=True)[0] + + @GoogleBaseHook.fallback_to_default_project_id + def get_job( + self, + job_id: str, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + ) -> dict: + """ + Gets the job with the specified Job ID. + + :param job_id: Job ID to get. + :type job_id: str + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: + :param location: The location of the Dataflow job (for example europe-west1). 
See: + https://cloud.google.com/dataflow/docs/concepts/regional-endpoints + :return: the Job + :rtype: dict + """ + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + location=location, + ) + return jobs_controller.fetch_job_by_id(job_id) + + @GoogleBaseHook.fallback_to_default_project_id + def fetch_job_metrics_by_id( + self, + job_id: str, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + ) -> dict: + """ + Gets the job metrics with the specified Job ID. + + :param job_id: Job ID to get. + :type job_id: str + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: + :param location: The location of the Dataflow job (for example europe-west1). See: + https://cloud.google.com/dataflow/docs/concepts/regional-endpoints + :return: the JobMetrics. See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/JobMetrics + :rtype: dict + """ + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + location=location, + ) + return jobs_controller.fetch_job_metrics_by_id(job_id) + + @GoogleBaseHook.fallback_to_default_project_id + def fetch_job_messages_by_id( + self, + job_id: str, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + ) -> List[dict]: + """ + Gets the job messages with the specified Job ID. + + :param job_id: Job ID to get. + :type job_id: str + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: + :param location: Job location. + :type location: str + :return: the list of JobMessages. See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/ListJobMessagesResponse#JobMessage + :rtype: List[dict] + """ + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + location=location, + ) + return jobs_controller.fetch_job_messages_by_id(job_id) + + @GoogleBaseHook.fallback_to_default_project_id + def fetch_job_autoscaling_events_by_id( + self, + job_id: str, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + ) -> List[dict]: + """ + Gets the job autoscaling events with the specified Job ID. + + :param job_id: Job ID to get. + :type job_id: str + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: + :param location: Job location. + :type location: str + :return: the list of AutoscalingEvents. See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/ListJobMessagesResponse#autoscalingevent + :rtype: List[dict] + """ + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + location=location, + ) + return jobs_controller.fetch_job_autoscaling_events_by_id(job_id) + + @GoogleBaseHook.fallback_to_default_project_id + def wait_for_done( + self, + job_name: str, + location: str, + project_id: str, + job_id: Optional[str] = None, + multiple_jobs: bool = False, + ) -> None: + """ + Wait for Dataflow job. + + :param job_name: The 'jobName' to use when executing the DataFlow job + (templated). This ends up being set in the pipeline options, so any entry + with key ``'jobName'`` in ``options`` will be overwritten. 
+ :type job_name: str + :param location: location the job is running + :type location: str + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: + :param job_id: a Dataflow job ID + :type job_id: str + :param multiple_jobs: If pipeline creates multiple jobs then monitor all jobs + :type multiple_jobs: boolean + """ + job_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + name=job_name, + location=location, + poll_sleep=self.poll_sleep, + job_id=job_id or self.job_id, + num_retries=self.num_retries, + multiple_jobs=multiple_jobs, + drain_pipeline=self.drain_pipeline, + cancel_timeout=self.cancel_timeout, + wait_until_finished=self.wait_until_finished, + ) + job_controller.wait_for_done() diff --git a/reference/providers/google/cloud/hooks/datafusion.py b/reference/providers/google/cloud/hooks/datafusion.py new file mode 100644 index 0000000..c3050cc --- /dev/null +++ b/reference/providers/google/cloud/hooks/datafusion.py @@ -0,0 +1,530 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
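Before the Data Fusion hook below, a minimal usage sketch of the DataflowHook defined in dataflow.py above, as it might be called from a task callable. This is illustrative only and not taken from the vendored sources: it assumes the provider is importable as `airflow.providers.google.cloud.hooks.dataflow`, that a `google_cloud_default` connection is configured, and the project, bucket, and template path are placeholders.

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    def launch_wordcount_template() -> dict:
        # Hypothetical values; swap in a real project, bucket and template path.
        hook = DataflowHook(gcp_conn_id="google_cloud_default", poll_sleep=15)
        # start_template_dataflow launches the template and then blocks in
        # _DataflowJobsController.wait_for_done() until a terminal state.
        job = hook.start_template_dataflow(
            job_name="example-wordcount",
            variables={"tempLocation": "gs://example-bucket/tmp"},
            parameters={
                "inputFile": "gs://example-bucket/input.txt",
                "output": "gs://example-bucket/output",
            },
            dataflow_template="gs://dataflow-templates/latest/Word_Count",
            project_id="example-project",
            location="us-central1",
        )
        return job

Because `variables` keys that match the runtime-environment keys (here `tempLocation`) are folded into the launch environment, the same call could instead pass them via the `environment` argument.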
+ +"""This module contains Google DataFusion hook.""" +import json +import os +from time import monotonic, sleep +from typing import Any, Dict, List, Optional, Sequence, Union +from urllib.parse import quote, urlencode + +import google.auth +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.retry import exponential_sleep_generator +from googleapiclient.discovery import Resource, build + +Operation = Dict[str, Any] + + +class PipelineStates: + """Data Fusion pipeline states""" + + PENDING = "PENDING" + STARTING = "STARTING" + RUNNING = "RUNNING" + SUSPENDED = "SUSPENDED" + RESUMING = "RESUMING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + KILLED = "KILLED" + REJECTED = "REJECTED" + + +FAILURE_STATES = [PipelineStates.FAILED, PipelineStates.KILLED, PipelineStates.REJECTED] +SUCCESS_STATES = [PipelineStates.COMPLETED] + + +class DataFusionHook(GoogleBaseHook): + """Hook for Google DataFusion.""" + + _conn = None # type: Optional[Resource] + + def __init__( + self, + api_version: str = "v1beta1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.api_version = api_version + + def wait_for_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]: + """Waits for long-lasting operation to complete.""" + for time_to_wait in exponential_sleep_generator(initial=10, maximum=120): + sleep(time_to_wait) + operation = ( + self.get_conn() # pylint: disable=no-member + .projects() + .locations() + .operations() + .get(name=operation.get("name")) + .execute() + ) + if operation.get("done"): + break + if "error" in operation: + raise AirflowException(operation["error"]) + return operation["response"] + + def wait_for_pipeline_state( + self, + pipeline_name: str, + pipeline_id: str, + instance_url: str, + namespace: str = "default", + success_states: Optional[List[str]] = None, + failure_states: Optional[List[str]] = None, + timeout: int = 5 * 60, + ) -> None: + """ + Polls pipeline state and raises an exception if the state is one of + `failure_states` or the operation timed_out. + """ + failure_states = failure_states or FAILURE_STATES + success_states = success_states or SUCCESS_STATES + start_time = monotonic() + current_state = None + while monotonic() - start_time < timeout: + try: + current_state = self._get_workflow_state( + pipeline_name=pipeline_name, + pipeline_id=pipeline_id, + instance_url=instance_url, + namespace=namespace, + ) + except AirflowException: + pass # Because the pipeline may not be visible in system yet + if current_state in success_states: + return + if current_state in failure_states: + raise AirflowException( + f"Pipeline {pipeline_name} state {current_state} is not " + f"one of {success_states}" + ) + sleep(30) + + # Time is up! 
+        raise AirflowException(
+            f"Pipeline {pipeline_name} state {current_state} is not "
+            f"one of {success_states} after {timeout}s"
+        )
+
+    @staticmethod
+    def _name(project_id: str, location: str, instance_name: str) -> str:
+        return f"projects/{project_id}/locations/{location}/instances/{instance_name}"
+
+    @staticmethod
+    def _parent(project_id: str, location: str) -> str:
+        return f"projects/{project_id}/locations/{location}"
+
+    @staticmethod
+    def _base_url(instance_url: str, namespace: str) -> str:
+        return os.path.join(instance_url, "v3", "namespaces", quote(namespace), "apps")
+
+    def _cdap_request(
+        self, url: str, method: str, body: Optional[Union[List, Dict]] = None
+    ) -> google.auth.transport.Response:
+        headers: Dict[str, str] = {"Content-Type": "application/json"}
+        request = google.auth.transport.requests.Request()
+
+        credentials = self._get_credentials()
+        credentials.before_request(
+            request=request, method=method, url=url, headers=headers
+        )
+
+        payload = json.dumps(body) if body else None
+
+        response = request(method=method, url=url, headers=headers, body=payload)
+        return response
+
+    def get_conn(self) -> Resource:
+        """Retrieves connection to DataFusion."""
+        if not self._conn:
+            http_authorized = self._authorize()
+            self._conn = build(
+                "datafusion",
+                self.api_version,
+                http=http_authorized,
+                cache_discovery=False,
+            )
+        return self._conn
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def restart_instance(
+        self, instance_name: str, location: str, project_id: str
+    ) -> Operation:
+        """
+        Restart a single Data Fusion instance.
+        At the end of the operation the instance is fully restarted.
+
+        :param instance_name: The name of the instance to restart.
+        :type instance_name: str
+        :param location: The Cloud Data Fusion location in which to handle the request.
+        :type location: str
+        :param project_id: The ID of the Google Cloud project that the instance belongs to.
+        :type project_id: str
+        """
+        operation = (
+            self.get_conn()  # pylint: disable=no-member
+            .projects()
+            .locations()
+            .instances()
+            .restart(name=self._name(project_id, location, instance_name))
+            .execute(num_retries=self.num_retries)
+        )
+        return operation
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_instance(
+        self, instance_name: str, location: str, project_id: str
+    ) -> Operation:
+        """
+        Deletes a single Data Fusion instance.
+
+        :param instance_name: The name of the instance to delete.
+        :type instance_name: str
+        :param location: The Cloud Data Fusion location in which to handle the request.
+        :type location: str
+        :param project_id: The ID of the Google Cloud project that the instance belongs to.
+        :type project_id: str
+        """
+        operation = (
+            self.get_conn()  # pylint: disable=no-member
+            .projects()
+            .locations()
+            .instances()
+            .delete(name=self._name(project_id, location, instance_name))
+            .execute(num_retries=self.num_retries)
+        )
+        return operation
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_instance(
+        self,
+        instance_name: str,
+        instance: Dict[str, Any],
+        location: str,
+        project_id: str,
+    ) -> Operation:
+        """
+        Creates a new Data Fusion instance in the specified project and location.
+
+        :param instance_name: The name of the instance to create.
+        :type instance_name: str
+        :param instance: An instance of Instance.
+            https://cloud.google.com/data-fusion/docs/reference/rest/v1beta1/projects.locations.instances#Instance
+        :type instance: Dict[str, Any]
+        :param location: The Cloud Data Fusion location in which to handle the request.
+ :type location: str + :param project_id: The ID of the Google Cloud project that the instance belongs to. + :type project_id: str + """ + operation = ( + self.get_conn() # pylint: disable=no-member + .projects() + .locations() + .instances() + .create( + parent=self._parent(project_id, location), + body=instance, + instanceId=instance_name, + ) + .execute(num_retries=self.num_retries) + ) + return operation + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance( + self, instance_name: str, location: str, project_id: str + ) -> Dict[str, Any]: + """ + Gets details of a single Data Fusion instance. + + :param instance_name: The name of the instance. + :type instance_name: str + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param project_id: The ID of the Google Cloud project that the instance belongs to. + :type project_id: str + """ + instance = ( + self.get_conn() # pylint: disable=no-member + .projects() + .locations() + .instances() + .get(name=self._name(project_id, location, instance_name)) + .execute(num_retries=self.num_retries) + ) + return instance + + @GoogleBaseHook.fallback_to_default_project_id + def patch_instance( + self, + instance_name: str, + instance: Dict[str, Any], + update_mask: str, + location: str, + project_id: str, + ) -> Operation: + """ + Updates a single Data Fusion instance. + + :param instance_name: The name of the instance to create. + :type instance_name: str + :param instance: An instance of Instance. + https://cloud.google.com/data-fusion/docs/reference/rest/v1beta1/projects.locations.instances#Instance + :type instance: Dict[str, Any] + :param update_mask: Field mask is used to specify the fields that the update will overwrite + in an instance resource. The fields specified in the updateMask are relative to the resource, + not the full request. A field will be overwritten if it is in the mask. If the user does not + provide a mask, all the supported fields (labels and options currently) will be overwritten. + A comma-separated list of fully qualified names of fields. Example: "user.displayName,photo". + https://developers.google.com/protocol-buffers/docs/reference/google.protobuf?_ga=2.205612571.-968688242.1573564810#google.protobuf.FieldMask + :type update_mask: str + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param project_id: The ID of the Google Cloud project that the instance belongs to. + :type project_id: str + """ + operation = ( + self.get_conn() # pylint: disable=no-member + .projects() + .locations() + .instances() + .patch( + name=self._name(project_id, location, instance_name), + updateMask=update_mask, + body=instance, + ) + .execute(num_retries=self.num_retries) + ) + return operation + + def create_pipeline( + self, + pipeline_name: str, + pipeline: Dict[str, Any], + instance_url: str, + namespace: str = "default", + ) -> None: + """ + Creates a Cloud Data Fusion pipeline. + + :param pipeline_name: Your pipeline name. + :type pipeline_name: str + :param pipeline: The pipeline definition. For more information check: + https://docs.cdap.io/cdap/current/en/developer-manual/pipelines/developing-pipelines.html#pipeline-configuration-file-format + :type pipeline: Dict[str, Any] + :param instance_url: Endpoint on which the REST APIs is accessible for the instance. + :type instance_url: str + :param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID + is always default. 
If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. + :type namespace: str + """ + url = os.path.join( + self._base_url(instance_url, namespace), quote(pipeline_name) + ) + response = self._cdap_request(url=url, method="PUT", body=pipeline) + if response.status != 200: + raise AirflowException( + f"Creating a pipeline failed with code {response.status}" + ) + + def delete_pipeline( + self, + pipeline_name: str, + instance_url: str, + version_id: Optional[str] = None, + namespace: str = "default", + ) -> None: + """ + Deletes a Cloud Data Fusion pipeline. + + :param pipeline_name: Your pipeline name. + :type pipeline_name: str + :param version_id: Version of pipeline to delete + :type version_id: Optional[str] + :param instance_url: Endpoint on which the REST APIs is accessible for the instance. + :type instance_url: str + :param namespace: f your pipeline belongs to a Basic edition instance, the namespace ID + is always default. If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. + :type namespace: str + """ + url = os.path.join( + self._base_url(instance_url, namespace), quote(pipeline_name) + ) + if version_id: + url = os.path.join(url, "versions", version_id) + + response = self._cdap_request(url=url, method="DELETE", body=None) + if response.status != 200: + raise AirflowException( + f"Deleting a pipeline failed with code {response.status}" + ) + + def list_pipelines( + self, + instance_url: str, + artifact_name: Optional[str] = None, + artifact_version: Optional[str] = None, + namespace: str = "default", + ) -> dict: + """ + Lists Cloud Data Fusion pipelines. + + :param artifact_version: Artifact version to filter instances + :type artifact_version: Optional[str] + :param artifact_name: Artifact name to filter instances + :type artifact_name: Optional[str] + :param instance_url: Endpoint on which the REST APIs is accessible for the instance. + :type instance_url: str + :param namespace: f your pipeline belongs to a Basic edition instance, the namespace ID + is always default. If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. + :type namespace: str + """ + url = self._base_url(instance_url, namespace) + query: Dict[str, str] = {} + if artifact_name: + query = {"artifactName": artifact_name} + if artifact_version: + query = {"artifactVersion": artifact_version} + if query: + url = os.path.join(url, urlencode(query)) + + response = self._cdap_request(url=url, method="GET", body=None) + if response.status != 200: + raise AirflowException( + f"Listing pipelines failed with code {response.status}" + ) + return json.loads(response.data) + + def _get_workflow_state( + self, + pipeline_name: str, + instance_url: str, + pipeline_id: str, + namespace: str = "default", + ) -> str: + url = os.path.join( + self._base_url(instance_url, namespace), + quote(pipeline_name), + "workflows", + "DataPipelineWorkflow", + "runs", + quote(pipeline_id), + ) + response = self._cdap_request(url=url, method="GET") + if response.status != 200: + raise AirflowException( + f"Retrieving a pipeline state failed with code {response.status}" + ) + workflow = json.loads(response.data) + return workflow["status"] + + def start_pipeline( + self, + pipeline_name: str, + instance_url: str, + namespace: str = "default", + runtime_args: Optional[Dict[str, Any]] = None, + ) -> str: + """ + Starts a Cloud Data Fusion pipeline. Works for both batch and stream pipelines. + + :param pipeline_name: Your pipeline name. 
+ :type pipeline_name: str + :param instance_url: Endpoint on which the REST APIs is accessible for the instance. + :type instance_url: str + :param runtime_args: Optional runtime JSON args to be passed to the pipeline + :type runtime_args: Optional[Dict[str, Any]] + :param namespace: f your pipeline belongs to a Basic edition instance, the namespace ID + is always default. If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. + :type namespace: str + """ + # TODO: This API endpoint starts multiple pipelines. There will eventually be a fix + # return the run Id as part of the API request to run a single pipeline. + # https://github.com/apache/airflow/pull/8954#discussion_r438223116 + url = os.path.join( + instance_url, + "v3", + "namespaces", + quote(namespace), + "start", + ) + runtime_args = runtime_args or {} + body = [ + { + "appId": pipeline_name, + "programType": "workflow", + "programId": "DataPipelineWorkflow", + "runtimeargs": runtime_args, + } + ] + response = self._cdap_request(url=url, method="POST", body=body) + if response.status != 200: + raise AirflowException( + f"Starting a pipeline failed with code {response.status}" + ) + + response_json = json.loads(response.data) + pipeline_id = response_json[0]["runId"] + self.wait_for_pipeline_state( + success_states=SUCCESS_STATES + [PipelineStates.RUNNING], + pipeline_name=pipeline_name, + pipeline_id=pipeline_id, + namespace=namespace, + instance_url=instance_url, + ) + return pipeline_id + + def stop_pipeline( + self, pipeline_name: str, instance_url: str, namespace: str = "default" + ) -> None: + """ + Stops a Cloud Data Fusion pipeline. Works for both batch and stream pipelines. + + :param pipeline_name: Your pipeline name. + :type pipeline_name: str + :param instance_url: Endpoint on which the REST APIs is accessible for the instance. + :type instance_url: str + :param namespace: f your pipeline belongs to a Basic edition instance, the namespace ID + is always default. If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. + :type namespace: str + """ + url = os.path.join( + self._base_url(instance_url, namespace), + quote(pipeline_name), + "workflows", + "DataPipelineWorkflow", + "stop", + ) + response = self._cdap_request(url=url, method="POST") + if response.status != 200: + raise AirflowException( + f"Stopping a pipeline failed with code {response.status}" + ) diff --git a/reference/providers/google/cloud/hooks/dataprep.py b/reference/providers/google/cloud/hooks/dataprep.py new file mode 100644 index 0000000..7e2648b --- /dev/null +++ b/reference/providers/google/cloud/hooks/dataprep.py @@ -0,0 +1,122 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
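On the same pattern, before the Dataprep hook below, a short sketch (again illustrative, not part of the vendored sources) of driving the DataFusionHook above: fetch the instance, read its endpoint, and start a deployed pipeline. The project, location, instance, and pipeline names are placeholders, and the `apiEndpoint` field is assumed to be present on the returned instance resource.

    from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook

    def run_example_pipeline() -> str:
        hook = DataFusionHook(gcp_conn_id="google_cloud_default")
        instance = hook.get_instance(
            instance_name="example-instance",   # placeholder instance
            location="us-central1",
            project_id="example-project",
        )
        # The CDAP REST endpoint is assumed to be exposed as "apiEndpoint".
        api_url = instance["apiEndpoint"]
        # start_pipeline blocks (via wait_for_pipeline_state) until the run
        # reaches RUNNING or a success state, and returns the run id.
        pipeline_id = hook.start_pipeline(
            pipeline_name="example_pipeline",    # placeholder pipeline
            instance_url=api_url,
            namespace="default",
        )
        return pipeline_id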
+"""This module contains Google Dataprep hook.""" +import json +import os +from typing import Any, Dict + +import requests +from airflow.hooks.base import BaseHook +from requests import HTTPError +from tenacity import retry, stop_after_attempt, wait_exponential + + +class GoogleDataprepHook(BaseHook): + """ + Hook for connection with Dataprep API. + To get connection Dataprep with Airflow you need Dataprep token. + https://clouddataprep.com/documentation/api#section/Authentication + + It should be added to the Connection in Airflow in JSON format. + + """ + + conn_name_attr = "dataprep_conn_id" + default_conn_name = "dataprep_default" + conn_type = "dataprep" + hook_name = "Google Dataprep" + + def __init__(self, dataprep_conn_id: str = default_conn_name) -> None: + super().__init__() + self.dataprep_conn_id = dataprep_conn_id + conn = self.get_connection(self.dataprep_conn_id) + extra_dejson = conn.extra_dejson + self._token = extra_dejson.get("extra__dataprep__token") + self._base_url = extra_dejson.get( + "extra__dataprep__base_url", "https://api.clouddataprep.com" + ) + + @property + def _headers(self) -> Dict[str, str]: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self._token}", + } + return headers + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def get_jobs_for_job_group(self, job_id: int) -> Dict[str, Any]: + """ + Get information about the batch jobs within a Cloud Dataprep job. + + :param job_id: The ID of the job that will be fetched + :type job_id: int + """ + endpoint_path = f"v4/jobGroups/{job_id}/jobs" + url: str = os.path.join(self._base_url, endpoint_path) + response = requests.get(url, headers=self._headers) + self._raise_for_status(response) + return response.json() + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def get_job_group( + self, job_group_id: int, embed: str, include_deleted: bool + ) -> Dict[str, Any]: + """ + Get the specified job group. + A job group is a job that is executed from a specific node in a flow. + + :param job_group_id: The ID of the job that will be fetched + :type job_group_id: int + :param embed: Comma-separated list of objects to pull in as part of the response + :type embed: str + :param include_deleted: if set to "true", will include deleted objects + :type include_deleted: bool + """ + params: Dict[str, Any] = {"embed": embed, "includeDeleted": include_deleted} + endpoint_path = f"v4/jobGroups/{job_group_id}" + url: str = os.path.join(self._base_url, endpoint_path) + response = requests.get(url, headers=self._headers, params=params) + self._raise_for_status(response) + return response.json() + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def run_job_group(self, body_request: dict) -> Dict[str, Any]: + """ + Creates a ``jobGroup``, which launches the specified job as the authenticated user. + This performs the same action as clicking on the Run Job button in the application. + To get recipe_id please follow the Dataprep API documentation + https://clouddataprep.com/documentation/api#operation/runJobGroup + + :param body_request: The identifier for the recipe you would like to run. 
+ :type body_request: dict + """ + endpoint_path = "v4/jobGroups" + url: str = os.path.join(self._base_url, endpoint_path) + response = requests.post( + url, headers=self._headers, data=json.dumps(body_request) + ) + self._raise_for_status(response) + return response.json() + + def _raise_for_status(self, response: requests.models.Response) -> None: + try: + response.raise_for_status() + except HTTPError: + self.log.error(response.json().get("exception")) + raise diff --git a/reference/providers/google/cloud/hooks/dataproc.py b/reference/providers/google/cloud/hooks/dataproc.py new file mode 100644 index 0000000..93de950 --- /dev/null +++ b/reference/providers/google/cloud/hooks/dataproc.py @@ -0,0 +1,951 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains a Google Cloud Dataproc hook.""" + +import time +import uuid +import warnings +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.version import version as airflow_version +from google.api_core.exceptions import ServerError +from google.api_core.retry import Retry +from google.cloud.dataproc_v1beta2 import ( # pylint: disable=no-name-in-module + Cluster, + ClusterControllerClient, + Job, + JobControllerClient, + JobStatus, + WorkflowTemplate, + WorkflowTemplateServiceClient, +) +from google.protobuf.duration_pb2 import Duration +from google.protobuf.field_mask_pb2 import FieldMask + + +class DataProcJobBuilder: + """A helper class for building Dataproc job.""" + + def __init__( + self, + project_id: str, + task_id: str, + cluster_name: str, + job_type: str, + properties: Optional[Dict[str, str]] = None, + ) -> None: + name = task_id + "_" + str(uuid.uuid4())[:8] + self.job_type = job_type + self.job = { + "job": { + "reference": {"project_id": project_id, "job_id": name}, + "placement": {"cluster_name": cluster_name}, + "labels": { + "airflow-version": "v" + + airflow_version.replace(".", "-").replace("+", "-") + }, + job_type: {}, + } + } # type: Dict[str, Any] + if properties is not None: + self.job["job"][job_type]["properties"] = properties + + def add_labels(self, labels: dict) -> None: + """ + Set labels for Dataproc job. + + :param labels: Labels for the job query. + :type labels: dict + """ + if labels: + self.job["job"]["labels"].update(labels) + + def add_variables(self, variables: List[str]) -> None: + """ + Set variables for Dataproc job. + + :param variables: Variables for the job query. + :type variables: List[str] + """ + if variables is not None: + self.job["job"][self.job_type]["script_variables"] = variables + + def add_args(self, args: List[str]) -> None: + """ + Set args for Dataproc job. 
+ + :param args: Args for the job query. + :type args: List[str] + """ + if args is not None: + self.job["job"][self.job_type]["args"] = args + + def add_query(self, query: List[str]) -> None: + """ + Set query uris for Dataproc job. + + :param query: URIs for the job queries. + :type query: List[str] + """ + self.job["job"][self.job_type]["query_list"] = {"queries": [query]} + + def add_query_uri(self, query_uri: str) -> None: + """ + Set query uri for Dataproc job. + + :param query_uri: URI for the job query. + :type query_uri: str + """ + self.job["job"][self.job_type]["query_file_uri"] = query_uri + + def add_jar_file_uris(self, jars: List[str]) -> None: + """ + Set jars uris for Dataproc job. + + :param jars: List of jars URIs + :type jars: List[str] + """ + if jars is not None: + self.job["job"][self.job_type]["jar_file_uris"] = jars + + def add_archive_uris(self, archives: List[str]) -> None: + """ + Set archives uris for Dataproc job. + + :param archives: List of archives URIs + :type archives: List[str] + """ + if archives is not None: + self.job["job"][self.job_type]["archive_uris"] = archives + + def add_file_uris(self, files: List[str]) -> None: + """ + Set file uris for Dataproc job. + + :param files: List of files URIs + :type files: List[str] + """ + if files is not None: + self.job["job"][self.job_type]["file_uris"] = files + + def add_python_file_uris(self, pyfiles: List[str]) -> None: + """ + Set python file uris for Dataproc job. + + :param pyfiles: List of python files URIs + :type pyfiles: List[str] + """ + if pyfiles is not None: + self.job["job"][self.job_type]["python_file_uris"] = pyfiles + + def set_main(self, main_jar: Optional[str], main_class: Optional[str]) -> None: + """ + Set Dataproc main class. + + :param main_jar: URI for the main file. + :type main_jar: str + :param main_class: Name of the main class. + :type main_class: str + :raises: Exception + """ + if main_class is not None and main_jar is not None: + raise Exception("Set either main_jar or main_class") + if main_jar: + self.job["job"][self.job_type]["main_jar_file_uri"] = main_jar + else: + self.job["job"][self.job_type]["main_class"] = main_class + + def set_python_main(self, main: str) -> None: + """ + Set Dataproc main python file uri. + + :param main: URI for the python main file. + :type main: str + """ + self.job["job"][self.job_type]["main_python_file_uri"] = main + + def set_job_name(self, name: str) -> None: + """ + Set Dataproc job name. + + :param name: Job name. + :type name: str + """ + self.job["job"]["reference"]["job_id"] = name + "_" + str(uuid.uuid4())[:8] + + def build(self) -> Dict: + """ + Returns Dataproc job. + + :return: Dataproc job + :rtype: dict + """ + return self.job + + +class DataprocHook(GoogleBaseHook): + """ + Hook for Google Cloud Dataproc APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. 
+ """ + + def get_cluster_client( + self, location: Optional[str] = None + ) -> ClusterControllerClient: + """Returns ClusterControllerClient.""" + client_options = None + if location and location != "global": + client_options = {"api_endpoint": f"{location}-dataproc.googleapis.com:443"} + + return ClusterControllerClient( + credentials=self._get_credentials(), + client_info=self.client_info, + client_options=client_options, + ) + + def get_template_client( + self, location: Optional[str] = None + ) -> WorkflowTemplateServiceClient: + """Returns WorkflowTemplateServiceClient.""" + client_options = None + if location and location != "global": + client_options = {"api_endpoint": f"{location}-dataproc.googleapis.com:443"} + + return WorkflowTemplateServiceClient( + credentials=self._get_credentials(), + client_info=self.client_info, + client_options=client_options, + ) + + def get_job_client(self, location: Optional[str] = None) -> JobControllerClient: + """Returns JobControllerClient.""" + client_options = None + if location and location != "global": + client_options = {"api_endpoint": f"{location}-dataproc.googleapis.com:443"} + + return JobControllerClient( + credentials=self._get_credentials(), + client_info=self.client_info, + client_options=client_options, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_cluster( + self, + region: str, + project_id: str, + cluster_name: str, + cluster_config: Union[Dict, Cluster], + labels: Optional[Dict[str, str]] = None, + request_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Creates a cluster in a project. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param cluster_name: Name of the cluster to create + :type cluster_name: str + :param labels: Labels that will be assigned to created cluster + :type labels: Dict[str, str] + :param cluster_config: Required. The cluster config to create. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.dataproc_v1.types.ClusterConfig` + :type cluster_config: Union[Dict, google.cloud.dataproc_v1.types.ClusterConfig] + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``CreateClusterRequest`` requests with the same id, then the second request will be ignored and + the first ``google.longrunning.Operation`` created and stored in the backend is returned. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + # Dataproc labels must conform to the following regex: + # [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows + # semantic versioning spec: x.y.z). 
+ labels = labels or {} + labels.update( + { + "airflow-version": "v" + + airflow_version.replace(".", "-").replace("+", "-") + } + ) + + cluster = { + "project_id": project_id, + "cluster_name": cluster_name, + "config": cluster_config, + "labels": labels, + } + + client = self.get_cluster_client(location=region) + result = client.create_cluster( + request={ + "project_id": project_id, + "region": region, + "cluster": cluster, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_cluster( + self, + region: str, + cluster_name: str, + project_id: str, + cluster_uuid: Optional[str] = None, + request_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Deletes a cluster in a project. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param cluster_name: Required. The cluster name. + :type cluster_name: str + :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail + if cluster with specified UUID does not exist. + :type cluster_uuid: str + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``DeleteClusterRequest`` requests with the same id, then the second request will be ignored and + the first ``google.longrunning.Operation`` created and stored in the backend is returned. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_cluster_client(location=region) + result = client.delete_cluster( + request={ + "project_id": project_id, + "region": region, + "cluster_name": cluster_name, + "cluster_uuid": cluster_uuid, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def diagnose_cluster( + self, + region: str, + cluster_name: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Gets cluster diagnostic information. After the operation completes GCS uri to + diagnose is returned + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param cluster_name: Required. The cluster name. + :type cluster_name: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. 
+ :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_cluster_client(location=region) + operation = client.diagnose_cluster( + request={ + "project_id": project_id, + "region": region, + "cluster_name": cluster_name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + operation.result() + gcs_uri = str(operation.operation.response.value) + return gcs_uri + + @GoogleBaseHook.fallback_to_default_project_id + def get_cluster( + self, + region: str, + cluster_name: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Gets the resource representation for a cluster in a project. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param cluster_name: Required. The cluster name. + :type cluster_name: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_cluster_client(location=region) + result = client.get_cluster( + request={ + "project_id": project_id, + "region": region, + "cluster_name": cluster_name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_clusters( + self, + region: str, + filter_: str, + project_id: str, + page_size: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Lists all regions/{region}/clusters in a project. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request. + :type region: str + :param filter_: Optional. A filter constraining the clusters to list. Filters are case-sensitive. + :type filter_: str + :param page_size: The maximum number of resources contained in the underlying API response. If page + streaming is performed per- resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number of resources in a page. + :type page_size: int + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_cluster_client(location=region) + result = client.list_clusters( + request={ + "project_id": project_id, + "region": region, + "filter": filter_, + "page_size": page_size, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def update_cluster( # pylint: disable=too-many-arguments + self, + location: str, + cluster_name: str, + cluster: Union[Dict, Cluster], + update_mask: Union[Dict, FieldMask], + project_id: str, + graceful_decommission_timeout: Optional[Union[Dict, Duration]] = None, + request_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Updates a cluster in a project. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. + :type location: str + :param cluster_name: Required. The cluster name. + :type cluster_name: str + :param cluster: Required. The changes to the cluster. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.dataproc_v1.types.Cluster` + :type cluster: Union[Dict, google.cloud.dataproc_v1.types.Cluster] + :param update_mask: Required. Specifies the path, relative to ``Cluster``, of the field to update. For + example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter would be + specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body would specify + the new value, as follows: + + :: + + { "config":{ "workerConfig":{ "numInstances":"5" } } } + + Similarly, to change the number of preemptible workers in a cluster to 5, the ``update_mask`` + parameter would be ``config.secondary_worker_config.num_instances``, and the ``PATCH`` request + body would be set as follows: + + :: + + { "config":{ "secondaryWorkerConfig":{ "numInstances":"5" } } } + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.dataproc_v1.types.FieldMask` + :type update_mask: Union[Dict, google.cloud.dataproc_v1.types.FieldMask] + :param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. Graceful + decommissioning allows removing nodes from the cluster without interrupting jobs in progress. + Timeout specifies how long to wait for jobs in progress to finish before forcefully removing nodes + (and potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the + maximum allowed timeout is 1 day. + + Only supported on Dataproc image versions 1.2 and higher. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.dataproc_v1.types.Duration` + :type graceful_decommission_timeout: Union[Dict, google.cloud.dataproc_v1.types.Duration] + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and + the first ``google.longrunning.Operation`` created and stored in the backend is returned. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. 
+ :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_cluster_client(location=location) + operation = client.update_cluster( + request={ + "project_id": project_id, + "region": location, + "cluster_name": cluster_name, + "cluster": cluster, + "update_mask": update_mask, + "graceful_decommission_timeout": graceful_decommission_timeout, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return operation + + @GoogleBaseHook.fallback_to_default_project_id + def create_workflow_template( + self, + location: str, + template: Union[Dict, WorkflowTemplate], + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> WorkflowTemplate: + """ + Creates new workflow template. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. + :type location: str + :param template: The Dataproc workflow template to create. If a dict is provided, + it must be of the same form as the protobuf message WorkflowTemplate. + :type template: Union[dict, WorkflowTemplate] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_template_client(location) + parent = f"projects/{project_id}/regions/{location}" + return client.create_workflow_template( + request={"parent": parent, "template": template}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def instantiate_workflow_template( + self, + location: str, + template_name: str, + project_id: str, + version: Optional[int] = None, + request_id: Optional[str] = None, + parameters: Optional[Dict[str, str]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Instantiates a template and begins execution. + + :param template_name: Name of template to instantiate. + :type template_name: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. + :type location: str + :param version: Optional. The version of workflow template to instantiate. If specified, + the workflow will be instantiated only if the current version of + the workflow template has the supplied version. + This option cannot be used to instantiate a previous version of + workflow template. + :type version: int + :param request_id: Optional. A tag that prevents multiple concurrent workflow instances + with the same tag from running. 
This mitigates risk of concurrent + instances started due to retries. + :type request_id: str + :param parameters: Optional. Map from parameter names to values that should be used for those + parameters. Values may not exceed 100 characters. + :type parameters: Dict[str, str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_template_client(location) + name = f"projects/{project_id}/regions/{location}/workflowTemplates/{template_name}" + operation = client.instantiate_workflow_template( + request={ + "name": name, + "version": version, + "request_id": request_id, + "parameters": parameters, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return operation + + @GoogleBaseHook.fallback_to_default_project_id + def instantiate_inline_workflow_template( + self, + location: str, + template: Union[Dict, WorkflowTemplate], + project_id: str, + request_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + Instantiates a template and begins execution. + + :param template: The workflow template to instantiate. If a dict is provided, + it must be of the same form as the protobuf message WorkflowTemplate + :type template: Union[Dict, WorkflowTemplate] + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. + :type location: str + :param request_id: Optional. A tag that prevents multiple concurrent workflow instances + with the same tag from running. This mitigates risk of concurrent + instances started due to retries. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_template_client(location) + parent = f"projects/{project_id}/regions/{location}" + operation = client.instantiate_inline_workflow_template( + request={"parent": parent, "template": template, "request_id": request_id}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return operation + + @GoogleBaseHook.fallback_to_default_project_id + def wait_for_job( + self, + job_id: str, + location: str, + project_id: str, + wait_time: int = 10, + timeout: Optional[int] = None, + ) -> None: + """ + Helper method which polls a job to check if it finishes. + + :param job_id: Id of the Dataproc job + :type job_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. 
+ :type location: str + :param wait_time: Number of seconds between checks + :type wait_time: int + :param timeout: How many seconds wait for job to be ready. Used only if ``asynchronous`` is False + :type timeout: int + """ + state = None + start = time.monotonic() + while state not in ( + JobStatus.State.ERROR, + JobStatus.State.DONE, + JobStatus.State.CANCELLED, + ): + if timeout and start + timeout < time.monotonic(): + raise AirflowException( + f"Timeout: dataproc job {job_id} is not ready after {timeout}s" + ) + time.sleep(wait_time) + try: + job = self.get_job( + project_id=project_id, location=location, job_id=job_id + ) + state = job.status.state + except ServerError as err: + self.log.info( + "Retrying. Dataproc API returned server error when waiting for job: %s", + err, + ) + + if state == JobStatus.State.ERROR: + raise AirflowException(f"Job failed:\n{job}") + if state == JobStatus.State.CANCELLED: + raise AirflowException(f"Job was cancelled:\n{job}") + + @GoogleBaseHook.fallback_to_default_project_id + def get_job( + self, + location: str, + job_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Job: + """ + Gets the resource representation for a job in a project. + + :param job_id: Id of the Dataproc job + :type job_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_job_client(location=location) + job = client.get_job( + request={"project_id": project_id, "region": location, "job_id": job_id}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return job + + @GoogleBaseHook.fallback_to_default_project_id + def submit_job( + self, + location: str, + job: Union[dict, Job], + project_id: str, + request_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Job: + """ + Submits a job to a cluster. + + :param job: The job resource. If a dict is provided, + it must be of the same form as the protobuf message Job + :type job: Union[Dict, Job] + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. + :type location: str + :param request_id: Optional. A tag that prevents multiple concurrent workflow instances + with the same tag from running. This mitigates risk of concurrent + instances started due to retries. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. 
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        client = self.get_job_client(location=location)
+        return client.submit_job(
+            request={
+                "project_id": project_id,
+                "region": location,
+                "job": job,
+                "request_id": request_id,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    def submit(
+        self,
+        project_id: str,
+        job: dict,
+        region: str = "global",
+        job_error_states: Optional[
+            Iterable[str]
+        ] = None,  # pylint: disable=unused-argument
+    ) -> None:
+        """
+        Submits a Google Cloud Dataproc job.
+
+        :param project_id: The ID of the Google Cloud Dataproc project.
+        :type project_id: str
+        :param job: The job to be submitted.
+        :type job: dict
+        :param region: The region of the Google Dataproc cluster.
+        :type region: str
+        :param job_error_states: Job states that should be considered error states.
+        :type job_error_states: List[str]
+        """
+        # TODO: Remove one day
+        warnings.warn(
+            "This method is deprecated. Please use `submit_job`",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        job_object = self.submit_job(location=region, project_id=project_id, job=job)
+        job_id = job_object.reference.job_id
+        self.wait_for_job(job_id=job_id, location=region, project_id=project_id)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def cancel_job(
+        self,
+        job_id: str,
+        project_id: str,
+        location: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> Job:
+        """
+        Starts a job cancellation request.
+
+        :param project_id: Required. The ID of the Google Cloud project that the job belongs to.
+        :type project_id: str
+        :param location: Required. The Cloud Dataproc region in which to handle the request.
+        :type location: str
+        :param job_id: Required. The job ID.
+        :type job_id: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        if location is None:
+            warnings.warn(
+                "Default location value `global` will be deprecated. Please provide a location value.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            location = "global"
+        client = self.get_job_client(location=location)
+
+        job = client.cancel_job(
+            request={"project_id": project_id, "region": location, "job_id": job_id},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return job
diff --git a/reference/providers/google/cloud/hooks/datastore.py b/reference/providers/google/cloud/hooks/datastore.py
new file mode 100644
index 0000000..c42680d
--- /dev/null
+++ b/reference/providers/google/cloud/hooks/datastore.py
@@ -0,0 +1,426 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""This module contains the Google Datastore hook."""
+
+
+import time
+import warnings
+from typing import Any, Dict, Optional, Sequence, Union
+
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from googleapiclient.discovery import Resource, build
+
+
+class DatastoreHook(GoogleBaseHook):
+    """
+    Interact with Google Cloud Datastore. This hook uses the Google Cloud connection.
+
+    This object is not thread safe. If you want to make multiple requests
+    simultaneously, you will need to create a hook per thread.
+
+    :param api_version: The version of the API it is going to connect to.
+    :type api_version: str
+    """
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        api_version: str = "v1",
+        datastore_conn_id: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+    ) -> None:
+        if datastore_conn_id:
+            warnings.warn(
+                "The datastore_conn_id parameter has been deprecated. You should pass "
+                "the gcp_conn_id parameter.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            gcp_conn_id = datastore_conn_id
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            delegate_to=delegate_to,
+            impersonation_chain=impersonation_chain,
+        )
+        self.connection = None
+        self.api_version = api_version
+
+    def get_conn(self) -> Resource:
+        """
+        Establishes a connection to the Google API.
+
+        :return: a Google Cloud Datastore service object.
+        :rtype: Resource
+        """
+        if not self.connection:
+            http_authorized = self._authorize()
+            self.connection = build(
+                "datastore",
+                self.api_version,
+                http=http_authorized,
+                cache_discovery=False,
+            )
+
+        return self.connection
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def allocate_ids(self, partial_keys: list, project_id: str) -> list:
+        """
+        Allocate IDs for incomplete keys.
+
+        .. seealso::
+            https://cloud.google.com/datastore/docs/reference/rest/v1/projects/allocateIds
+
+        :param partial_keys: a list of partial keys.
+        :type partial_keys: list
+        :param project_id: Google Cloud project ID against which to make the request.
+        :type project_id: str
+        :return: a list of full keys.
+        :rtype: list
+        """
+        conn = self.get_conn()  # type: Any
+
+        resp = (
+            conn.projects()  # pylint: disable=no-member
+            .allocateIds(projectId=project_id, body={"keys": partial_keys})
+            .execute(num_retries=self.num_retries)
+        )
+
+        return resp["keys"]
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def begin_transaction(
+        self, project_id: str, transaction_options: Dict[str, Any]
+    ) -> str:
+        """
+        Begins a new transaction.
+
+        .. seealso::
+            https://cloud.google.com/datastore/docs/reference/rest/v1/projects/beginTransaction
+
+        :param project_id: Google Cloud project ID against which to make the request.
+        :type project_id: str
+        :param transaction_options: Options for a new transaction.
+        :type transaction_options: Dict[str, Any]
+        :return: a transaction handle.
+ :rtype: str + """ + conn = self.get_conn() # type: Any + + resp = ( + conn.projects() # pylint: disable=no-member + .beginTransaction( + projectId=project_id, body={"transactionOptions": transaction_options} + ) + .execute(num_retries=self.num_retries) + ) + + return resp["transaction"] + + @GoogleBaseHook.fallback_to_default_project_id + def commit(self, body: dict, project_id: str) -> dict: + """ + Commit a transaction, optionally creating, deleting or modifying some entities. + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/rest/v1/projects/commit + + :param body: the body of the commit request. + :type body: dict + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + :return: the response body of the commit request. + :rtype: dict + """ + conn = self.get_conn() # type: Any + + resp = ( + conn.projects() # pylint: disable=no-member + .commit(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) + + return resp + + @GoogleBaseHook.fallback_to_default_project_id + def lookup( + self, + keys: list, + project_id: str, + read_consistency: Optional[str] = None, + transaction: Optional[str] = None, + ) -> dict: + """ + Lookup some entities by key. + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/rest/v1/projects/lookup + + :param keys: the keys to lookup. + :type keys: list + :param read_consistency: the read consistency to use. default, strong or eventual. + Cannot be used with a transaction. + :type read_consistency: str + :param transaction: the transaction to use, if any. + :type transaction: str + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + :return: the response body of the lookup request. + :rtype: dict + """ + conn = self.get_conn() # type: Any + + body = {"keys": keys} # type: Dict[str, Any] + if read_consistency: + body["readConsistency"] = read_consistency + if transaction: + body["transaction"] = transaction + resp = ( + conn.projects() # pylint: disable=no-member + .lookup(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) + + return resp + + @GoogleBaseHook.fallback_to_default_project_id + def rollback(self, transaction: str, project_id: str) -> None: + """ + Roll back a transaction. + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/rest/v1/projects/rollback + + :param transaction: the transaction to roll back. + :type transaction: str + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + """ + conn: Any = self.get_conn() + + conn.projects().rollback( # pylint: disable=no-member + projectId=project_id, body={"transaction": transaction} + ).execute(num_retries=self.num_retries) + + @GoogleBaseHook.fallback_to_default_project_id + def run_query(self, body: dict, project_id: str) -> dict: + """ + Run a query for entities. + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/rest/v1/projects/runQuery + + :param body: the body of the query request. + :type body: dict + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + :return: the batch of query results. 
+ :rtype: dict + """ + conn = self.get_conn() # type: Any + + resp = ( + conn.projects() # pylint: disable=no-member + .runQuery(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) + + return resp["batch"] + + def get_operation(self, name: str) -> dict: + """ + Gets the latest state of a long-running operation. + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects.operations/get + + :param name: the name of the operation resource. + :type name: str + :return: a resource operation instance. + :rtype: dict + """ + conn: Any = self.get_conn() + + resp = ( + conn.projects() # pylint: disable=no-member + .operations() + .get(name=name) + .execute(num_retries=self.num_retries) + ) + + return resp + + def delete_operation(self, name: str) -> dict: + """ + Deletes the long-running operation. + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects.operations/delete + + :param name: the name of the operation resource. + :type name: str + :return: none if successful. + :rtype: dict + """ + conn = self.get_conn() # type: Any + + resp = ( + conn.projects() # pylint: disable=no-member + .operations() + .delete(name=name) + .execute(num_retries=self.num_retries) + ) + + return resp + + def poll_operation_until_done( + self, name: str, polling_interval_in_seconds: int + ) -> Dict: + """ + Poll backup operation state until it's completed. + + :param name: the name of the operation resource + :type name: str + :param polling_interval_in_seconds: The number of seconds to wait before calling another request. + :type polling_interval_in_seconds: int + :return: a resource operation instance. + :rtype: dict + """ + while True: + result = self.get_operation(name) # type: Dict + + state = result["metadata"]["common"]["state"] # type: str + if state == "PROCESSING": + self.log.info( + "Operation is processing. Re-polling state in %s seconds", + polling_interval_in_seconds, + ) + time.sleep(polling_interval_in_seconds) + else: + return result + + @GoogleBaseHook.fallback_to_default_project_id + def export_to_storage_bucket( + self, + bucket: str, + project_id: str, + namespace: Optional[str] = None, + entity_filter: Optional[dict] = None, + labels: Optional[Dict[str, str]] = None, + ) -> dict: + """ + Export entities from Cloud Datastore to Cloud Storage for backup. + + .. note:: + Keep in mind that this requests the Admin API not the Data API. + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/admin/rest/v1/projects/export + + :param bucket: The name of the Cloud Storage bucket. + :type bucket: str + :param namespace: The Cloud Storage namespace path. + :type namespace: str + :param entity_filter: Description of what data from the project is included in the export. + :type entity_filter: dict + :param labels: Client-assigned labels. + :type labels: dict of str + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + :return: a resource operation instance. 
+ :rtype: dict + """ + admin_conn = self.get_conn() # type: Any + + output_uri_prefix = "gs://" + "/".join( + filter(None, [bucket, namespace]) + ) # type: str + if not entity_filter: + entity_filter = {} + if not labels: + labels = {} + body = { + "outputUrlPrefix": output_uri_prefix, + "entityFilter": entity_filter, + "labels": labels, + } # type: Dict + resp = ( + admin_conn.projects() # pylint: disable=no-member + .export(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) + + return resp + + @GoogleBaseHook.fallback_to_default_project_id + def import_from_storage_bucket( + self, + bucket: str, + file: str, + project_id: str, + namespace: Optional[str] = None, + entity_filter: Optional[dict] = None, + labels: Optional[Union[dict, str]] = None, + ) -> dict: + """ + Import a backup from Cloud Storage to Cloud Datastore. + + .. note:: + Keep in mind that this requests the Admin API not the Data API. + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/admin/rest/v1/projects/import + + :param bucket: The name of the Cloud Storage bucket. + :type bucket: str + :param file: the metadata file written by the projects.export operation. + :type file: str + :param namespace: The Cloud Storage namespace path. + :type namespace: str + :param entity_filter: specify which kinds/namespaces are to be imported. + :type entity_filter: dict + :param labels: Client-assigned labels. + :type labels: dict of str + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + :return: a resource operation instance. + :rtype: dict + """ + admin_conn = self.get_conn() # type: Any + + input_url = "gs://" + "/".join( + filter(None, [bucket, namespace, file]) + ) # type: str + if not entity_filter: + entity_filter = {} + if not labels: + labels = {} + body = { + "inputUrl": input_url, + "entityFilter": entity_filter, + "labels": labels, + } # type: Dict + resp = ( + admin_conn.projects() # pylint: disable=no-member + .import_(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) + + return resp diff --git a/reference/providers/google/cloud/hooks/dlp.py b/reference/providers/google/cloud/hooks/dlp.py new file mode 100644 index 0000000..68e05c7 --- /dev/null +++ b/reference/providers/google/cloud/hooks/dlp.py @@ -0,0 +1,1794 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This module contains a CloudDLPHook +which allows you to connect to Google Cloud DLP service. 
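+
+A minimal usage sketch is shown below. The import path assumes the upstream
+provider package layout (this copy lives under ``reference/``), and the
+connection, project, and template IDs are placeholder values rather than
+anything defined in this module::
+
+    from airflow.providers.google.cloud.hooks.dlp import CloudDLPHook
+
+    # Build the hook from an Airflow connection that carries GCP credentials.
+    hook = CloudDLPHook(gcp_conn_id="google_cloud_default")
+
+    # Read back a previously created inspect template (placeholder IDs).
+    template = hook.get_inspect_template(
+        template_id="example-inspect-template",
+        project_id="example-project",
+    )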
+"""
+
+import re
+import time
+from typing import List, Optional, Sequence, Tuple, Union
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from google.api_core.retry import Retry
+from google.cloud.dlp_v2 import DlpServiceClient
+from google.cloud.dlp_v2.types import (
+    ByteContentItem,
+    ContentItem,
+    DeidentifyConfig,
+    DeidentifyContentResponse,
+    DeidentifyTemplate,
+    DlpJob,
+    FieldMask,
+    InspectConfig,
+    InspectContentResponse,
+    InspectJobConfig,
+    InspectTemplate,
+    JobTrigger,
+    ListInfoTypesResponse,
+    RedactImageRequest,
+    RedactImageResponse,
+    ReidentifyContentResponse,
+    RiskAnalysisJobConfig,
+    StoredInfoType,
+    StoredInfoTypeConfig,
+)
+
+DLP_JOB_PATH_PATTERN = "^projects/[^/]+/dlpJobs/(?P<job>.*?)$"
+
+
+# pylint: disable=R0904, C0302
+class CloudDLPHook(GoogleBaseHook):
+    """
+    Hook for Google Cloud Data Loss Prevention (DLP) APIs.
+    Cloud DLP allows clients to detect the presence of Personally Identifiable
+    Information (PII) and other privacy-sensitive data in user-supplied,
+    unstructured data streams, like text blocks or images. The service also
+    includes methods for sensitive data redaction and scheduling of data scans
+    on Google Cloud-based data sets.
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account.
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+    ) -> None:
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            delegate_to=delegate_to,
+            impersonation_chain=impersonation_chain,
+        )
+        self._client = None
+
+    def get_conn(self) -> DlpServiceClient:
+        """
+        Provides a client for interacting with the Cloud DLP API.
+
+        :return: Google Cloud DLP API Client
+        :rtype: google.cloud.dlp_v2.DlpServiceClient
+        """
+        if not self._client:
+            self._client = DlpServiceClient(
+                credentials=self._get_credentials(), client_info=self.client_info
+            )
+        return self._client
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def cancel_dlp_job(
+        self,
+        dlp_job_id: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> None:
+        """
+        Starts asynchronous cancellation on a long-running DLP job.
+
+        :param dlp_job_id: ID of the DLP job resource to be cancelled.
+        :type dlp_job_id: str
+        :param project_id: (Optional) Google Cloud project ID where the
+            DLP Instance exists. If set to None or missing, the default project_id
+            from the Google Cloud connection is used.
+ :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + + if not dlp_job_id: + raise AirflowException( + "Please provide the ID of the DLP job resource to be cancelled." + ) + + name = DlpServiceClient.dlp_job_path(project_id, dlp_job_id) + client.cancel_dlp_job( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + def create_deidentify_template( + self, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + deidentify_template: Optional[Union[dict, DeidentifyTemplate]] = None, + template_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> DeidentifyTemplate: + """ + Creates a deidentify template for re-using frequently used configuration for + de-identifying content, images, and storage. + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param deidentify_template: (Optional) The de-identify template to create. + :type deidentify_template: dict or google.cloud.dlp_v2.types.DeidentifyTemplate + :param template_id: (Optional) The template ID. + :type template_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate + """ + client = self.get_conn() + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + parent = DlpServiceClient.organization_path(organization_id) + elif project_id: + parent = DlpServiceClient.project_path(project_id) + else: + raise AirflowException( + "Please provide either organization_id or project_id." 
+ ) + + return client.create_deidentify_template( + parent=parent, + deidentify_template=deidentify_template, + template_id=template_id, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_dlp_job( + self, + project_id: str, + inspect_job: Optional[Union[dict, InspectJobConfig]] = None, + risk_job: Optional[Union[dict, RiskAnalysisJobConfig]] = None, + job_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + wait_until_finished: bool = True, + time_to_sleep_in_seconds: int = 60, + ) -> DlpJob: + """ + Creates a new job to inspect storage or calculate risk metrics. + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param inspect_job: (Optional) The configuration for the inspect job. + :type inspect_job: dict or google.cloud.dlp_v2.types.InspectJobConfig + :param risk_job: (Optional) The configuration for the risk job. + :type risk_job: dict or google.cloud.dlp_v2.types.RiskAnalysisJobConfig + :param job_id: (Optional) The job ID. + :type job_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param wait_until_finished: (Optional) If true, it will keep polling the job state + until it is set to DONE. + :type wait_until_finished: bool + :rtype: google.cloud.dlp_v2.types.DlpJob + :param time_to_sleep_in_seconds: (Optional) Time to sleep, in seconds, between active checks + of the operation results. Defaults to 60. + :type time_to_sleep_in_seconds: int + """ + client = self.get_conn() + + parent = DlpServiceClient.project_path(project_id) + job = client.create_dlp_job( + parent=parent, + inspect_job=inspect_job, + risk_job=risk_job, + job_id=job_id, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + if wait_until_finished: + pattern = re.compile(DLP_JOB_PATH_PATTERN, re.IGNORECASE) + match = pattern.match(job.name) + if match is not None: + job_name = match.groupdict()["job"] + else: + raise AirflowException( + f"Unable to retrieve DLP job's ID from {job.name}." + ) + + while wait_until_finished: + job = self.get_dlp_job(dlp_job_id=job_name, project_id=project_id) + + self.log.info( + "DLP job %s state: %s.", job.name, DlpJob.JobState.Name(job.state) + ) + + if job.state == DlpJob.JobState.DONE: + return job + elif job.state in [ + DlpJob.JobState.PENDING, + DlpJob.JobState.RUNNING, + DlpJob.JobState.JOB_STATE_UNSPECIFIED, + ]: + time.sleep(time_to_sleep_in_seconds) + else: + raise AirflowException( + "Stopped polling DLP job state. 
DLP job {} state: {}.".format( + job.name, DlpJob.JobState.Name(job.state) + ) + ) + return job + + def create_inspect_template( + self, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + inspect_template: Optional[Union[dict, InspectTemplate]] = None, + template_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> InspectTemplate: + """ + Creates an inspect template for re-using frequently used configuration for + inspecting content, images, and storage. + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param inspect_template: (Optional) The inspect template to create. + :type inspect_template: dict or google.cloud.dlp_v2.types.InspectTemplate + :param template_id: (Optional) The template ID. + :type template_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.InspectTemplate + """ + client = self.get_conn() + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + parent = DlpServiceClient.organization_path(organization_id) + elif project_id: + parent = DlpServiceClient.project_path(project_id) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + return client.create_inspect_template( + parent=parent, + inspect_template=inspect_template, + template_id=template_id, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_job_trigger( + self, + project_id: str, + job_trigger: Optional[Union[dict, JobTrigger]] = None, + trigger_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> JobTrigger: + """ + Creates a job trigger to run DLP actions such as scanning storage for sensitive + information on a set schedule. + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param job_trigger: (Optional) The job trigger to create. + :type job_trigger: dict or google.cloud.dlp_v2.types.JobTrigger + :param trigger_id: (Optional) The job trigger ID. + :type trigger_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. 
+ :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.JobTrigger + """ + client = self.get_conn() + + parent = DlpServiceClient.project_path(project_id) + return client.create_job_trigger( + parent=parent, + job_trigger=job_trigger, + trigger_id=trigger_id, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def create_stored_info_type( + self, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + config: Optional[Union[dict, StoredInfoTypeConfig]] = None, + stored_info_type_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> StoredInfoType: + """ + Creates a pre-built stored info type to be used for inspection. + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param config: (Optional) The config for the stored info type. + :type config: dict or google.cloud.dlp_v2.types.StoredInfoTypeConfig + :param stored_info_type_id: (Optional) The stored info type ID. + :type stored_info_type_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.StoredInfoType + """ + client = self.get_conn() + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + parent = DlpServiceClient.organization_path(organization_id) + elif project_id: + parent = DlpServiceClient.project_path(project_id) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + return client.create_stored_info_type( + parent=parent, + config=config, + stored_info_type_id=stored_info_type_id, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def deidentify_content( + self, + project_id: str, + deidentify_config: Optional[Union[dict, DeidentifyConfig]] = None, + inspect_config: Optional[Union[dict, InspectConfig]] = None, + item: Optional[Union[dict, ContentItem]] = None, + inspect_template_name: Optional[str] = None, + deidentify_template_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> DeidentifyContentResponse: + """ + De-identifies potentially sensitive info from a content item. This method has limits + on input size and output size. + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. 
+ :type project_id: str + :param deidentify_config: (Optional) Configuration for the de-identification of the + content item. Items specified here will override the template referenced by the + deidentify_template_name argument. + :type deidentify_config: dict or google.cloud.dlp_v2.types.DeidentifyConfig + :param inspect_config: (Optional) Configuration for the inspector. Items specified + here will override the template referenced by the inspect_template_name argument. + :type inspect_config: dict or google.cloud.dlp_v2.types.InspectConfig + :param item: (Optional) The item to de-identify. Will be treated as text. + :type item: dict or google.cloud.dlp_v2.types.ContentItem + :param inspect_template_name: (Optional) Optional template to use. Any configuration + directly specified in inspect_config will override those set in the template. + :type inspect_template_name: str + :param deidentify_template_name: (Optional) Optional template to use. Any + configuration directly specified in deidentify_config will override those set + in the template. + :type deidentify_template_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.DeidentifyContentResponse + """ + client = self.get_conn() + + parent = DlpServiceClient.project_path(project_id) + return client.deidentify_content( + parent=parent, + deidentify_config=deidentify_config, + inspect_config=inspect_config, + item=item, + inspect_template_name=inspect_template_name, + deidentify_template_name=deidentify_template_name, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def delete_deidentify_template( + self, + template_id, + organization_id=None, + project_id=None, + retry=None, + timeout=None, + metadata=None, + ) -> None: + """ + Deletes a deidentify template. + + :param template_id: The ID of deidentify template to be deleted. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + + if not template_id: + raise AirflowException( + "Please provide the ID of deidentify template to be deleted." 
+ ) + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + name = DlpServiceClient.organization_deidentify_template_path( + organization_id, template_id + ) + elif project_id: + name = DlpServiceClient.project_deidentify_template_path( + project_id, template_id + ) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + client.delete_deidentify_template( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_dlp_job( + self, + dlp_job_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes a long-running DLP job. This method indicates that the client is no longer + interested in the DLP job result. The job will be cancelled if possible. + + :param dlp_job_id: The ID of the DLP job resource to be cancelled. + :type dlp_job_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + + if not dlp_job_id: + raise AirflowException( + "Please provide the ID of the DLP job resource to be cancelled." + ) + + name = DlpServiceClient.dlp_job_path(project_id, dlp_job_id) + client.delete_dlp_job( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + def delete_inspect_template( + self, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes an inspect template. + + :param template_id: The ID of the inspect template to be deleted. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + + if not template_id: + raise AirflowException( + "Please provide the ID of the inspect template to be deleted." 
+ ) + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + name = DlpServiceClient.organization_inspect_template_path( + organization_id, template_id + ) + elif project_id: + name = DlpServiceClient.project_inspect_template_path( + project_id, template_id + ) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + client.delete_inspect_template( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_job_trigger( + self, + job_trigger_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes a job trigger. + + :param job_trigger_id: The ID of the DLP job trigger to be deleted. + :type job_trigger_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + + if not job_trigger_id: + raise AirflowException( + "Please provide the ID of the DLP job trigger to be deleted." + ) + + name = DlpServiceClient.project_job_trigger_path(project_id, job_trigger_id) + client.delete_job_trigger( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + def delete_stored_info_type( + self, + stored_info_type_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes a stored info type. + + :param stored_info_type_id: The ID of the stored info type to be deleted. + :type stored_info_type_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + + if not stored_info_type_id: + raise AirflowException( + "Please provide the ID of the stored info type to be deleted." 
+ ) + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + name = DlpServiceClient.organization_stored_info_type_path( + organization_id, stored_info_type_id + ) + elif project_id: + name = DlpServiceClient.project_stored_info_type_path( + project_id, stored_info_type_id + ) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + client.delete_stored_info_type( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + def get_deidentify_template( + self, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> DeidentifyTemplate: + """ + Gets a deidentify template. + + :param template_id: The ID of deidentify template to be read. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate + """ + client = self.get_conn() + + if not template_id: + raise AirflowException( + "Please provide the ID of the deidentify template to be read." + ) + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + name = DlpServiceClient.organization_deidentify_template_path( + organization_id, template_id + ) + elif project_id: + name = DlpServiceClient.project_deidentify_template_path( + project_id, template_id + ) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + return client.get_deidentify_template( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_dlp_job( + self, + dlp_job_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> DlpJob: + """ + Gets the latest state of a long-running Dlp Job. + + :param dlp_job_id: The ID of the DLP job resource to be read. + :type dlp_job_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. 
+ :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.DlpJob + """ + client = self.get_conn() + + if not dlp_job_id: + raise AirflowException( + "Please provide the ID of the DLP job resource to be read." + ) + + name = DlpServiceClient.dlp_job_path(project_id, dlp_job_id) + return client.get_dlp_job( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + def get_inspect_template( + self, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> InspectTemplate: + """ + Gets an inspect template. + + :param template_id: The ID of inspect template to be read. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.InspectTemplate + """ + client = self.get_conn() + + if not template_id: + raise AirflowException( + "Please provide the ID of the inspect template to be read." + ) + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + name = DlpServiceClient.organization_inspect_template_path( + organization_id, template_id + ) + elif project_id: + name = DlpServiceClient.project_inspect_template_path( + project_id, template_id + ) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + return client.get_inspect_template( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_job_trigger( + self, + job_trigger_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> JobTrigger: + """ + Gets a DLP job trigger. + + :param job_trigger_id: The ID of the DLP job trigger to be read. + :type job_trigger_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.JobTrigger + """ + client = self.get_conn() + + if not job_trigger_id: + raise AirflowException( + "Please provide the ID of the DLP job trigger to be read." + ) + + name = DlpServiceClient.project_job_trigger_path(project_id, job_trigger_id) + return client.get_job_trigger( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + def get_stored_info_type( + self, + stored_info_type_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> StoredInfoType: + """ + Gets a stored info type. + + :param stored_info_type_id: The ID of the stored info type to be read. + :type stored_info_type_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.StoredInfoType + """ + client = self.get_conn() + + if not stored_info_type_id: + raise AirflowException( + "Please provide the ID of the stored info type to be read." + ) + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + name = DlpServiceClient.organization_stored_info_type_path( + organization_id, stored_info_type_id + ) + elif project_id: + name = DlpServiceClient.project_stored_info_type_path( + project_id, stored_info_type_id + ) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + return client.get_stored_info_type( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + def inspect_content( + self, + project_id: str, + inspect_config: Optional[Union[dict, InspectConfig]] = None, + item: Optional[Union[dict, ContentItem]] = None, + inspect_template_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> InspectContentResponse: + """ + Finds potentially sensitive info in content. This method has limits on input size, + processing time, and output size. + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param inspect_config: (Optional) Configuration for the inspector. Items specified + here will override the template referenced by the inspect_template_name argument. + :type inspect_config: dict or google.cloud.dlp_v2.types.InspectConfig + :param item: (Optional) The item to de-identify. Will be treated as text. 
+ :type item: dict or google.cloud.dlp_v2.types.ContentItem + :param inspect_template_name: (Optional) Optional template to use. Any configuration + directly specified in inspect_config will override those set in the template. + :type inspect_template_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.InspectContentResponse + """ + client = self.get_conn() + + parent = DlpServiceClient.project_path(project_id) + return client.inspect_content( + parent=parent, + inspect_config=inspect_config, + item=item, + inspect_template_name=inspect_template_name, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def list_deidentify_templates( + self, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + page_size: Optional[int] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[DeidentifyTemplate]: + """ + Lists deidentify templates. + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param order_by: (Optional) Optional comma separated list of fields to order by, + followed by asc or desc postfix. + :type order_by: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: List[google.cloud.dlp_v2.types.DeidentifyTemplate] + """ + client = self.get_conn() + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + parent = DlpServiceClient.organization_path(organization_id) + elif project_id: + parent = DlpServiceClient.project_path(project_id) + else: + raise AirflowException( + "Please provide either organization_id or project_id." 
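+                # Editorial note (added comment): the client call below returns a
+                # pager object; "return list(results)" walks every page, so
+                # page_size only caps each underlying API response rather than the
+                # total number of templates returned.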
+ ) + + results = client.list_deidentify_templates( + parent=parent, + page_size=page_size, + order_by=order_by, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + return list(results) + + @GoogleBaseHook.fallback_to_default_project_id + def list_dlp_jobs( + self, + project_id: str, + results_filter: Optional[str] = None, + page_size: Optional[int] = None, + job_type: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[DlpJob]: + """ + Lists DLP jobs that match the specified filter in the request. + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param results_filter: (Optional) Filter used to specify a subset of results. + :type results_filter: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param job_type: (Optional) The type of job. + :type job_type: str + :param order_by: (Optional) Optional comma separated list of fields to order by, + followed by asc or desc postfix. + :type order_by: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: List[google.cloud.dlp_v2.types.DlpJob] + """ + client = self.get_conn() + + parent = DlpServiceClient.project_path(project_id) + results = client.list_dlp_jobs( + parent=parent, + filter_=results_filter, + page_size=page_size, + type_=job_type, + order_by=order_by, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return list(results) + + def list_info_types( + self, + language_code: Optional[str] = None, + results_filter: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ListInfoTypesResponse: + """ + Returns a list of the sensitive information types that the DLP API supports. + + :param language_code: (Optional) Optional BCP-47 language code for localized info + type friendly names. If omitted, or if localized strings are not available, + en-US strings will be returned. + :type language_code: str + :param results_filter: (Optional) Filter used to specify a subset of results. + :type results_filter: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.ListInfoTypesResponse + """ + client = self.get_conn() + + return client.list_info_types( + language_code=language_code, + filter_=results_filter, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def list_inspect_templates( + self, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + page_size: Optional[int] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[InspectTemplate]: + """ + Lists inspect templates. + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param order_by: (Optional) Optional comma separated list of fields to order by, + followed by asc or desc postfix. + :type order_by: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: List[google.cloud.dlp_v2.types.InspectTemplate] + """ + client = self.get_conn() + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + parent = DlpServiceClient.organization_path(organization_id) + elif project_id: + parent = DlpServiceClient.project_path(project_id) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + results = client.list_inspect_templates( + parent=parent, + page_size=page_size, + order_by=order_by, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return list(results) + + @GoogleBaseHook.fallback_to_default_project_id + def list_job_triggers( + self, + project_id: str, + page_size: Optional[int] = None, + order_by: Optional[str] = None, + results_filter: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[JobTrigger]: + """ + Lists job triggers. + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param order_by: (Optional) Optional comma separated list of fields to order by, + followed by asc or desc postfix. + :type order_by: str + :param results_filter: (Optional) Filter used to specify a subset of results. + :type results_filter: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. 
+ :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: List[google.cloud.dlp_v2.types.JobTrigger] + """ + client = self.get_conn() + + parent = DlpServiceClient.project_path(project_id) + results = client.list_job_triggers( + parent=parent, + page_size=page_size, + order_by=order_by, + filter_=results_filter, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return list(results) + + def list_stored_info_types( + self, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + page_size: Optional[int] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[StoredInfoType]: + """ + Lists stored info types. + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param order_by: (Optional) Optional comma separated list of fields to order by, + followed by asc or desc postfix. + :type order_by: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: List[google.cloud.dlp_v2.types.StoredInfoType] + """ + client = self.get_conn() + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + parent = DlpServiceClient.organization_path(organization_id) + elif project_id: + parent = DlpServiceClient.project_path(project_id) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + results = client.list_stored_info_types( + parent=parent, + page_size=page_size, + order_by=order_by, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return list(results) + + @GoogleBaseHook.fallback_to_default_project_id + def redact_image( + self, + project_id: str, + inspect_config: Optional[Union[dict, InspectConfig]] = None, + image_redaction_configs: Optional[ + Union[List[dict], List[RedactImageRequest.ImageRedactionConfig]] + ] = None, + include_findings: Optional[bool] = None, + byte_item: Optional[Union[dict, ByteContentItem]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> RedactImageResponse: + """ + Redacts potentially sensitive info from an image. This method has limits on + input size, processing time, and output size. 
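+        The image to redact is supplied via ``byte_item``; see that parameter's
+        description below for the supported image formats.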
+ + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param inspect_config: (Optional) Configuration for the inspector. Items specified + here will override the template referenced by the inspect_template_name argument. + :type inspect_config: dict or google.cloud.dlp_v2.types.InspectConfig + :param image_redaction_configs: (Optional) The configuration for specifying what + content to redact from images. + :type image_redaction_configs: List[dict] or + List[google.cloud.dlp_v2.types.RedactImageRequest.ImageRedactionConfig] + :param include_findings: (Optional) Whether the response should include findings + along with the redacted image. + :type include_findings: bool + :param byte_item: (Optional) The content must be PNG, JPEG, SVG or BMP. + :type byte_item: dict or google.cloud.dlp_v2.types.ByteContentItem + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.RedactImageResponse + """ + client = self.get_conn() + + parent = DlpServiceClient.project_path(project_id) + return client.redact_image( + parent=parent, + inspect_config=inspect_config, + image_redaction_configs=image_redaction_configs, + include_findings=include_findings, + byte_item=byte_item, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def reidentify_content( + self, + project_id: str, + reidentify_config: Optional[Union[dict, DeidentifyConfig]] = None, + inspect_config: Optional[Union[dict, InspectConfig]] = None, + item: Optional[Union[dict, ContentItem]] = None, + inspect_template_name: Optional[str] = None, + reidentify_template_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ReidentifyContentResponse: + """ + Re-identifies content that has been de-identified. + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param reidentify_config: (Optional) Configuration for the re-identification of + the content item. + :type reidentify_config: dict or google.cloud.dlp_v2.types.DeidentifyConfig + :param inspect_config: (Optional) Configuration for the inspector. + :type inspect_config: dict or google.cloud.dlp_v2.types.InspectConfig + :param item: (Optional) The item to re-identify. Will be treated as text. + :type item: dict or google.cloud.dlp_v2.types.ContentItem + :param inspect_template_name: (Optional) Optional template to use. Any configuration + directly specified in inspect_config will override those set in the template. + :type inspect_template_name: str + :param reidentify_template_name: (Optional) Optional template to use. References an + instance of deidentify template. 
Any configuration directly specified in + reidentify_config or inspect_config will override those set in the template. + :type reidentify_template_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.ReidentifyContentResponse + """ + client = self.get_conn() + + parent = DlpServiceClient.project_path(project_id) + return client.reidentify_content( + parent=parent, + reidentify_config=reidentify_config, + inspect_config=inspect_config, + item=item, + inspect_template_name=inspect_template_name, + reidentify_template_name=reidentify_template_name, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def update_deidentify_template( + self, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + deidentify_template: Optional[Union[dict, DeidentifyTemplate]] = None, + update_mask: Optional[Union[dict, FieldMask]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> DeidentifyTemplate: + """ + Updates the deidentify template. + + :param template_id: The ID of deidentify template to be updated. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param deidentify_template: New deidentify template value. + :type deidentify_template: dict or google.cloud.dlp_v2.types.DeidentifyTemplate + :param update_mask: Mask to control which fields get updated. + :type update_mask: dict or google.cloud.dlp_v2.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate + """ + client = self.get_conn() + + if not template_id: + raise AirflowException( + "Please provide the ID of deidentify template to be updated." + ) + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + name = DlpServiceClient.organization_deidentify_template_path( + organization_id, template_id + ) + elif project_id: + name = DlpServiceClient.project_deidentify_template_path( + project_id, template_id + ) + else: + raise AirflowException( + "Please provide either organization_id or project_id." 
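+                # Illustrative editorial comment (not from the upstream code): an
+                # update_mask of {"paths": ["display_name"]} would restrict the
+                # update below to the template's display name; fields not listed
+                # in the mask are left unchanged.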
+ ) + + return client.update_deidentify_template( + name=name, + deidentify_template=deidentify_template, + update_mask=update_mask, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def update_inspect_template( + self, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + inspect_template: Optional[Union[dict, InspectTemplate]] = None, + update_mask: Optional[Union[dict, FieldMask]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> InspectTemplate: + """ + Updates the inspect template. + + :param template_id: The ID of the inspect template to be updated. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param inspect_template: New inspect template value. + :type inspect_template: dict or google.cloud.dlp_v2.types.InspectTemplate + :param update_mask: Mask to control which fields get updated. + :type update_mask: dict or google.cloud.dlp_v2.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.InspectTemplate + """ + client = self.get_conn() + + if not template_id: + raise AirflowException( + "Please provide the ID of the inspect template to be updated." + ) + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + name = DlpServiceClient.organization_inspect_template_path( + organization_id, template_id + ) + elif project_id: + name = DlpServiceClient.project_inspect_template_path( + project_id, template_id + ) + else: + raise AirflowException( + "Please provide either organization_id or project_id." + ) + + return client.update_inspect_template( + name=name, + inspect_template=inspect_template, + update_mask=update_mask, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def update_job_trigger( + self, + job_trigger_id: str, + project_id: str, + job_trigger: Optional[Union[dict, JobTrigger]] = None, + update_mask: Optional[Union[dict, FieldMask]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> JobTrigger: + """ + Updates a job trigger. + + :param job_trigger_id: The ID of the DLP job trigger to be updated. + :type job_trigger_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param job_trigger: New job trigger value. + :type job_trigger: dict or google.cloud.dlp_v2.types.JobTrigger + :param update_mask: Mask to control which fields get updated. 
+ :type update_mask: dict or google.cloud.dlp_v2.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.JobTrigger + """ + client = self.get_conn() + + if not job_trigger_id: + raise AirflowException( + "Please provide the ID of the DLP job trigger to be updated." + ) + + name = DlpServiceClient.project_job_trigger_path(project_id, job_trigger_id) + return client.update_job_trigger( + name=name, + job_trigger=job_trigger, + update_mask=update_mask, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def update_stored_info_type( + self, + stored_info_type_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + config: Optional[Union[dict, StoredInfoTypeConfig]] = None, + update_mask: Optional[Union[dict, FieldMask]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> StoredInfoType: + """ + Updates the stored info type by creating a new version. + + :param stored_info_type_id: The ID of the stored info type to be updated. + :type stored_info_type_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param config: Updated configuration for the stored info type. If not provided, a new + version of the stored info type will be created with the existing configuration. + :type config: dict or google.cloud.dlp_v2.types.StoredInfoTypeConfig + :param update_mask: Mask to control which fields get updated. + :type update_mask: dict or google.cloud.dlp_v2.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :rtype: google.cloud.dlp_v2.types.StoredInfoType + """ + client = self.get_conn() + + if not stored_info_type_id: + raise AirflowException( + "Please provide the ID of the stored info type to be updated." + ) + + # Handle project_id from connection configuration + project_id = project_id or self.project_id + + if organization_id: + name = DlpServiceClient.organization_stored_info_type_path( + organization_id, stored_info_type_id + ) + elif project_id: + name = DlpServiceClient.project_stored_info_type_path( + project_id, stored_info_type_id + ) + else: + raise AirflowException( + "Please provide either organization_id or project_id." 
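+                # Editorial note: as the docstring above says, the call below
+                # creates a new version of the stored info type rather than
+                # modifying the current version in place.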
+ ) + + return client.update_stored_info_type( + name=name, + config=config, + update_mask=update_mask, + retry=retry, + timeout=timeout, + metadata=metadata, + ) diff --git a/reference/providers/google/cloud/hooks/functions.py b/reference/providers/google/cloud/hooks/functions.py new file mode 100644 index 0000000..1715163 --- /dev/null +++ b/reference/providers/google/cloud/hooks/functions.py @@ -0,0 +1,258 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Functions Hook.""" +import time +from typing import Any, Dict, List, Optional, Sequence, Union + +import requests +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import build + +# Time to sleep between active checks of the operation results +TIME_TO_SLEEP_IN_SECONDS = 1 + + +class CloudFunctionsHook(GoogleBaseHook): + """ + Hook for the Google Cloud Functions APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + """ + + _conn = None # type: Optional[Any] + + def __init__( + self, + api_version: str, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.api_version = api_version + + @staticmethod + def _full_location(project_id: str, location: str) -> str: + """ + Retrieve full location of the function in the form of + ``projects//locations/`` + + :param project_id: The Google Cloud Project project_id where the function belongs. + :type project_id: str + :param location: The location where the function is created. + :type location: str + :return: + """ + return f"projects/{project_id}/locations/{location}" + + def get_conn(self) -> build: + """ + Retrieves the connection to Cloud Functions. + + :return: Google Cloud Function services object. + :rtype: dict + """ + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + "cloudfunctions", + self.api_version, + http=http_authorized, + cache_discovery=False, + ) + return self._conn + + def get_function(self, name: str) -> dict: + """ + Returns the Cloud Function with the given name. + + :param name: Name of the function. + :type name: str + :return: A Cloud Functions object representing the function. 
+ :rtype: dict + """ + # fmt: off + return self.get_conn().projects().locations().functions().get( # pylint: disable=no-member + name=name).execute(num_retries=self.num_retries) + # fmt: on + + @GoogleBaseHook.fallback_to_default_project_id + def create_new_function(self, location: str, body: dict, project_id: str) -> None: + """ + Creates a new function in Cloud Function in the location specified in the body. + + :param location: The location of the function. + :type location: str + :param body: The body required by the Cloud Functions insert API. + :type body: dict + :param project_id: Optional, Google Cloud Project project_id where the function belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + # fmt: off + response = self.get_conn().projects().locations().functions().create( # pylint: disable=no-member + location=self._full_location(project_id, location), + body=body + ).execute(num_retries=self.num_retries) + # fmt: on + operation_name = response["name"] + self._wait_for_operation_to_complete(operation_name=operation_name) + + def update_function(self, name: str, body: dict, update_mask: List[str]) -> None: + """ + Updates Cloud Functions according to the specified update mask. + + :param name: The name of the function. + :type name: str + :param body: The body required by the cloud function patch API. + :type body: dict + :param update_mask: The update mask - array of fields that should be patched. + :type update_mask: [str] + :return: None + """ + # fmt: off + response = self.get_conn().projects().locations().functions().patch( # pylint: disable=no-member + updateMask=",".join(update_mask), + name=name, + body=body + ).execute(num_retries=self.num_retries) + # fmt: on + operation_name = response["name"] + self._wait_for_operation_to_complete(operation_name=operation_name) + + @GoogleBaseHook.fallback_to_default_project_id + def upload_function_zip(self, location: str, zip_path: str, project_id: str) -> str: + """ + Uploads zip file with sources. + + :param location: The location where the function is created. + :type location: str + :param zip_path: The path of the valid .zip file to upload. + :type zip_path: str + :param project_id: Optional, Google Cloud Project project_id where the function belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: The upload URL that was returned by generateUploadUrl method. + :rtype: str + """ + # fmt: off + # pylint: disable=no-member # noqa + response = \ + self.get_conn().projects().locations().functions().generateUploadUrl( + parent=self._full_location(project_id, location) + ).execute(num_retries=self.num_retries) + # fmt: on + + upload_url = response.get("uploadUrl") + with open(zip_path, "rb") as file: + requests.put( + url=upload_url, + data=file, + # Those two headers needs to be specified according to: + # https://cloud.google.com/functions/docs/reference/rest/v1/projects.locations.functions/generateUploadUrl + # nopep8 + headers={ + "Content-type": "application/zip", + "x-goog-content-length-range": "0,104857600", + }, + ) + return upload_url + + def delete_function(self, name: str) -> None: + """ + Deletes the specified Cloud Function. + + :param name: The name of the function. 
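+            This is the fully qualified resource path, for example
+            ``projects/{project_id}/locations/{location}/functions/{function_id}``
+            (the same format that ``call_function`` builds later in this hook).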
+ :type name: str + :return: None + """ + # fmt: off + response = self.get_conn().projects().locations().functions().delete( # pylint: disable=no-member + name=name).execute(num_retries=self.num_retries) + # fmt: on + operation_name = response["name"] + self._wait_for_operation_to_complete(operation_name=operation_name) + + @GoogleBaseHook.fallback_to_default_project_id + def call_function( + self, + function_id: str, + input_data: Dict, + location: str, + project_id: str, + ) -> dict: + """ + Synchronously invokes a deployed Cloud Function. To be used for testing + purposes as very limited traffic is allowed. + + :param function_id: ID of the function to be called + :type function_id: str + :param input_data: Input to be passed to the function + :type input_data: Dict + :param location: The location where the function is located. + :type location: str + :param project_id: Optional, Google Cloud Project project_id where the function belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :return: None + """ + name = f"projects/{project_id}/locations/{location}/functions/{function_id}" + # fmt: off + response = self.get_conn().projects().locations().functions().call( # pylint: disable=no-member + name=name, + body=input_data + ).execute(num_retries=self.num_retries) + # fmt: on + if "error" in response: + raise AirflowException(response["error"]) + return response + + def _wait_for_operation_to_complete(self, operation_name: str) -> dict: + """ + Waits for the named operation to complete - checks status of the + asynchronous call. + + :param operation_name: The name of the operation. + :type operation_name: str + :return: The response returned by the operation. + :rtype: dict + :exception: AirflowException in case error is returned. + """ + service = self.get_conn() + while True: + # fmt: off + operation_response = service.operations().get( # pylint: disable=no-member + name=operation_name, + ).execute(num_retries=self.num_retries) + # fmt: on + if operation_response.get("done"): + response = operation_response.get("response") + error = operation_response.get("error") + # Note, according to documentation always either response or error is + # set when "done" == True + if error: + raise AirflowException(str(error)) + return response + time.sleep(TIME_TO_SLEEP_IN_SECONDS) diff --git a/reference/providers/google/cloud/hooks/gcs.py b/reference/providers/google/cloud/hooks/gcs.py new file mode 100644 index 0000000..93373e0 --- /dev/null +++ b/reference/providers/google/cloud/hooks/gcs.py @@ -0,0 +1,1323 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +"""This module contains a Google Cloud Storage hook.""" +import functools +import gzip as gz +import os +import shutil +import time +import warnings +from contextlib import contextmanager +from datetime import datetime +from functools import partial +from io import BytesIO +from os import path +from tempfile import NamedTemporaryFile +from typing import Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, cast +from urllib.parse import urlparse + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.utils import timezone +from airflow.version import version +from google.api_core.exceptions import NotFound +from google.cloud import storage +from google.cloud.exceptions import GoogleCloudError + +RT = TypeVar("RT") # pylint: disable=invalid-name +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + +# Use default timeout from google-cloud-storage +DEFAULT_TIMEOUT = 60 + + +def _fallback_object_url_to_object_name_and_bucket_name( + object_url_keyword_arg_name="object_url", + bucket_name_keyword_arg_name="bucket_name", + object_name_keyword_arg_name="object_name", +) -> Callable[[T], T]: + """ + Decorator factory that convert object URL parameter to object name and bucket name parameter. + + :param object_url_keyword_arg_name: Name of the object URL parameter + :type object_url_keyword_arg_name: str + :param bucket_name_keyword_arg_name: Name of the bucket name parameter + :type bucket_name_keyword_arg_name: str + :param object_name_keyword_arg_name: Name of the object name parameter + :type object_name_keyword_arg_name: str + :return: Decorator + """ + + def _wrapper(func: T): + @functools.wraps(func) + def _inner_wrapper(self: "GCSHook", *args, **kwargs) -> RT: + if args: + raise AirflowException( + "You must use keyword arguments in this methods rather than positional" + ) + + object_url = kwargs.get(object_url_keyword_arg_name) + bucket_name = kwargs.get(bucket_name_keyword_arg_name) + object_name = kwargs.get(object_name_keyword_arg_name) + + if object_url and bucket_name and object_name: + raise AirflowException( + "The mutually exclusive parameters. `object_url`, `bucket_name` together " + "with `object_name` parameters are present. " + "Please provide `object_url` or `bucket_name` and `object_name`." + ) + if object_url: + bucket_name, object_name = _parse_gcs_url(object_url) + kwargs[bucket_name_keyword_arg_name] = bucket_name + kwargs[object_name_keyword_arg_name] = object_name + del kwargs[object_url_keyword_arg_name] + + if not object_name or not bucket_name: + raise TypeError( + f"{func.__name__}() missing 2 required positional arguments: " + f"'{bucket_name_keyword_arg_name}' and '{object_name_keyword_arg_name}' " + f"or {object_url_keyword_arg_name}" + ) + if not object_name: + raise TypeError( + f"{func.__name__}() missing 1 required positional argument: " + f"'{object_name_keyword_arg_name}'" + ) + if not bucket_name: + raise TypeError( + f"{func.__name__}() missing 1 required positional argument: " + f"'{bucket_name_keyword_arg_name}'" + ) + + return func(self, *args, **kwargs) + + return cast(T, _inner_wrapper) + + return _wrapper + + +class GCSHook(GoogleBaseHook): + """ + Interact with Google Cloud Storage. This hook uses the Google Cloud + connection. 
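+
+    The ``download`` and ``upload`` methods retry failed transfers with a simple
+    exponential back-off whenever ``num_max_attempts`` is greater than one.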
+ """ + + _conn = None # type: Optional[storage.Client] + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + google_cloud_storage_conn_id: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + # To preserve backward compatibility + # TODO: remove one day + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=2, + ) + gcp_conn_id = google_cloud_storage_conn_id + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + + def get_conn(self) -> storage.Client: + """Returns a Google Cloud Storage service object.""" + if not self._conn: + self._conn = storage.Client( + credentials=self._get_credentials(), + client_info=self.client_info, + project=self.project_id, + ) + + return self._conn + + def copy( + self, + source_bucket: str, + source_object: str, + destination_bucket: Optional[str] = None, + destination_object: Optional[str] = None, + ) -> None: + """ + Copies an object from a bucket to another, with renaming if requested. + + destination_bucket or destination_object can be omitted, in which case + source bucket/object is used, but not both. + + :param source_bucket: The bucket of the object to copy from. + :type source_bucket: str + :param source_object: The object to copy. + :type source_object: str + :param destination_bucket: The destination of the object to copied to. + Can be omitted; then the same bucket is used. + :type destination_bucket: str + :param destination_object: The (renamed) path of the object if given. + Can be omitted; then the same name is used. + :type destination_object: str + """ + destination_bucket = destination_bucket or source_bucket + destination_object = destination_object or source_object + + if source_bucket == destination_bucket and source_object == destination_object: + + raise ValueError( + "Either source/destination bucket or source/destination object " + "must be different, not both the same: bucket=%s, object=%s" + % (source_bucket, source_object) + ) + if not source_bucket or not source_object: + raise ValueError("source_bucket and source_object cannot be empty.") + + client = self.get_conn() + source_bucket = client.bucket(source_bucket) + source_object = source_bucket.blob(source_object) # type: ignore[attr-defined] + destination_bucket = client.bucket(destination_bucket) + destination_object = source_bucket.copy_blob( # type: ignore[attr-defined] + blob=source_object, + destination_bucket=destination_bucket, + new_name=destination_object, + ) + + self.log.info( + "Object %s in bucket %s copied to object %s in bucket %s", + source_object.name, # type: ignore[attr-defined] + source_bucket.name, # type: ignore[attr-defined] + destination_object.name, # type: ignore[union-attr] + destination_bucket.name, # type: ignore[union-attr] + ) + + def rewrite( + self, + source_bucket: str, + source_object: str, + destination_bucket: str, + destination_object: Optional[str] = None, + ) -> None: + """ + Has the same functionality as copy, except that will work on files + over 5 TB, as well as when copying between locations and/or storage + classes. + + destination_object can be omitted, in which case source_object is used. + + :param source_bucket: The bucket of the object to copy from. 
+ :type source_bucket: str + :param source_object: The object to copy. + :type source_object: str + :param destination_bucket: The destination of the object to copied to. + :type destination_bucket: str + :param destination_object: The (renamed) path of the object if given. + Can be omitted; then the same name is used. + :type destination_object: str + """ + destination_object = destination_object or source_object + if source_bucket == destination_bucket and source_object == destination_object: + raise ValueError( + "Either source/destination bucket or source/destination object " + "must be different, not both the same: bucket=%s, object=%s" + % (source_bucket, source_object) + ) + if not source_bucket or not source_object: + raise ValueError("source_bucket and source_object cannot be empty.") + + client = self.get_conn() + source_bucket = client.bucket(source_bucket) + source_object = source_bucket.blob(blob_name=source_object) # type: ignore[attr-defined] + destination_bucket = client.bucket(destination_bucket) + + token, bytes_rewritten, total_bytes = destination_bucket.blob( # type: ignore[attr-defined] + blob_name=destination_object + ).rewrite( + source=source_object + ) + + self.log.info( + "Total Bytes: %s | Bytes Written: %s", total_bytes, bytes_rewritten + ) + + while token is not None: + token, bytes_rewritten, total_bytes = destination_bucket.blob( # type: ignore[attr-defined] + blob_name=destination_object + ).rewrite( + source=source_object, token=token + ) + + self.log.info( + "Total Bytes: %s | Bytes Written: %s", total_bytes, bytes_rewritten + ) + self.log.info( + "Object %s in bucket %s rewritten to object %s in bucket %s", + source_object.name, # type: ignore[attr-defined] + source_bucket.name, # type: ignore[attr-defined] + destination_object, + destination_bucket.name, # type: ignore[attr-defined] + ) + + def download( + self, + object_name: str, + bucket_name: Optional[str], + filename: Optional[str] = None, + chunk_size: Optional[int] = None, + timeout: Optional[int] = DEFAULT_TIMEOUT, + num_max_attempts: Optional[int] = 1, + ) -> Union[str, bytes]: + """ + Downloads a file from Google Cloud Storage. + + When no filename is supplied, the operator loads the file into memory and returns its + content. When a filename is supplied, it writes the file to the specified location and + returns the location. For file sizes that exceed the available memory it is recommended + to write to a file. + + :param object_name: The object to fetch. + :type object_name: str + :param bucket_name: The bucket to fetch from. + :type bucket_name: str + :param filename: If set, a local file path where the file should be written to. + :type filename: str + :param chunk_size: Blob chunk size. + :type chunk_size: int + :param timeout: Request timeout in seconds. + :type timeout: int + :param num_max_attempts: Number of attempts to download the file. 
+ :type num_max_attempts: int + """ + # TODO: future improvement check file size before downloading, + # to check for local space availability + + num_file_attempts = 0 + + while num_file_attempts < num_max_attempts: + try: + num_file_attempts += 1 + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size) + + if filename: + blob.download_to_filename(filename, timeout=timeout) + self.log.info("File downloaded to %s", filename) + return filename + else: + return blob.download_as_string() + + except GoogleCloudError: + if num_file_attempts == num_max_attempts: + self.log.error( + "Download attempt of object: %s from %s has failed. Attempt: %s, max %s.", + object_name, + object_name, + num_file_attempts, + num_max_attempts, + ) + raise + + # Wait with exponential backoff scheme before retrying. + timeout_seconds = 1.0 * 2 ** (num_file_attempts - 1) + time.sleep(timeout_seconds) + continue + + @_fallback_object_url_to_object_name_and_bucket_name() + @contextmanager + def provide_file( + self, + bucket_name: Optional[str] = None, + object_name: Optional[str] = None, + object_url: Optional[str] = None, # pylint: disable=unused-argument + ): + """ + Downloads the file to a temporary directory and returns a file handle + + You can use this method by passing the bucket_name and object_name parameters + or just object_url parameter. + + :param bucket_name: The bucket to fetch from. + :type bucket_name: str + :param object_name: The object to fetch. + :type object_name: str + :param object_url: File reference url. Must start with "gs: //" + :type object_url: str + :return: File handler + """ + if object_name is None: + raise ValueError("Object name can not be empty") + _, _, file_name = object_name.rpartition("/") + with NamedTemporaryFile(suffix=file_name) as tmp_file: + self.download( + bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name + ) + tmp_file.flush() + yield tmp_file + + @_fallback_object_url_to_object_name_and_bucket_name() + @contextmanager + def provide_file_and_upload( + self, + bucket_name: Optional[str] = None, + object_name: Optional[str] = None, + object_url: Optional[str] = None, # pylint: disable=unused-argument + ): + """ + Creates temporary file, returns a file handle and uploads the files content + on close. + + You can use this method by passing the bucket_name and object_name parameters + or just object_url parameter. + + :param bucket_name: The bucket to fetch from. + :type bucket_name: str + :param object_name: The object to fetch. + :type object_name: str + :param object_url: File reference url. Must start with "gs: //" + :type object_url: str + :return: File handler + """ + if object_name is None: + raise ValueError("Object name can not be empty") + + _, _, file_name = object_name.rpartition("/") + with NamedTemporaryFile(suffix=file_name) as tmp_file: + yield tmp_file + tmp_file.flush() + self.upload( + bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name + ) + + def upload( # pylint: disable=too-many-arguments + self, + bucket_name: str, + object_name: str, + filename: Optional[str] = None, + data: Optional[Union[str, bytes]] = None, + mime_type: Optional[str] = None, + gzip: bool = False, + encoding: str = "utf-8", + chunk_size: Optional[int] = None, + timeout: Optional[int] = DEFAULT_TIMEOUT, + num_max_attempts: int = 1, + ) -> None: + """ + Uploads a local file or file data as string or bytes to Google Cloud Storage. 
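+        Exactly one of ``filename`` or ``data`` must be supplied; passing both,
+        or neither, raises ``ValueError``.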
+ + :param bucket_name: The bucket to upload to. + :type bucket_name: str + :param object_name: The object name to set when uploading the file. + :type object_name: str + :param filename: The local file path to the file to be uploaded. + :type filename: str + :param data: The file's data as a string or bytes to be uploaded. + :type data: str + :param mime_type: The file's mime type set when uploading the file. + :type mime_type: str + :param gzip: Option to compress local file or file data for upload + :type gzip: bool + :param encoding: bytes encoding for file data if provided as string + :type encoding: str + :param chunk_size: Blob chunk size. + :type chunk_size: int + :param timeout: Request timeout in seconds. + :type timeout: int + :param num_max_attempts: Number of attempts to try to upload the file. + :type num_max_attempts: int + """ + + def _call_with_retry(f: Callable[[], None]) -> None: + """Helper functions to upload a file or a string with a retry mechanism and exponential back-off. + :param f: Callable that should be retried. + :type f: Callable[[], None] + """ + num_file_attempts = 0 + + while num_file_attempts < num_max_attempts: + try: + num_file_attempts += 1 + f() + + except GoogleCloudError as e: + if num_file_attempts == num_max_attempts: + self.log.error( + "Upload attempt of object: %s from %s has failed. Attempt: %s, max %s.", + object_name, + object_name, + num_file_attempts, + num_max_attempts, + ) + raise e + + # Wait with exponential backoff scheme before retrying. + timeout_seconds = 1.0 * 2 ** (num_file_attempts - 1) + time.sleep(timeout_seconds) + continue + + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size) + if filename and data: + raise ValueError( + "'filename' and 'data' parameter provided. Please " + "specify a single parameter, either 'filename' for " + "local file uploads or 'data' for file content uploads." + ) + elif filename: + if not mime_type: + mime_type = "application/octet-stream" + if gzip: + filename_gz = filename + ".gz" + + with open(filename, "rb") as f_in: + with gz.open(filename_gz, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + filename = filename_gz + + _call_with_retry( + partial( + blob.upload_from_filename, + filename=filename, + content_type=mime_type, + timeout=timeout, + ) + ) + + if gzip: + os.remove(filename) + self.log.info( + "File %s uploaded to %s in %s bucket", + filename, + object_name, + bucket_name, + ) + elif data: + if not mime_type: + mime_type = "text/plain" + if gzip: + if isinstance(data, str): + data = bytes(data, encoding) + out = BytesIO() + with gz.GzipFile(fileobj=out, mode="w") as f: + f.write(data) + data = out.getvalue() + + _call_with_retry( + partial( + blob.upload_from_string, + data, + content_type=mime_type, + timeout=timeout, + ) + ) + + self.log.info( + "Data stream uploaded to %s in %s bucket", object_name, bucket_name + ) + else: + raise ValueError( + "'filename' and 'data' parameter missing. One is required to upload to gcs." + ) + + def exists(self, bucket_name: str, object_name: str) -> bool: + """ + Checks for the existence of a file in Google Cloud Storage. + + :param bucket_name: The Google Cloud Storage bucket where the object is. + :type bucket_name: str + :param object_name: The name of the blob_name to check in the Google cloud + storage bucket. 
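The upload() implementation above accepts either a local file or in-memory data; a hedged sketch with placeholder names, assuming the provider's GCSHook:

from airflow.providers.google.cloud.hooks.gcs import GCSHook

gcs_hook = GCSHook(gcp_conn_id="google_cloud_default")

# Local file path: gzip=True compresses the file before Blob.upload_from_filename.
gcs_hook.upload(
    bucket_name="example-bucket",
    object_name="raw/events.json.gz",
    filename="/tmp/events.json",
    gzip=True,
    num_max_attempts=3,  # retried with the same exponential backoff as download()
)

# In-memory data: uploaded via Blob.upload_from_string, defaulting to text/plain.
gcs_hook.upload(
    bucket_name="example-bucket",
    object_name="status/_SUCCESS",
    data="pipeline finished",
)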
+ :type object_name: str + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name=object_name) + return blob.exists() + + def get_blob_update_time(self, bucket_name: str, object_name: str): + """ + Get the update time of a file in Google Cloud Storage + + :param bucket_name: The Google Cloud Storage bucket where the object is. + :type bucket_name: str + :param object_name: The name of the blob to get updated time from the Google cloud + storage bucket. + :type object_name: str + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.get_blob(blob_name=object_name) + if blob is None: + raise ValueError( + f"Object ({object_name}) not found in Bucket ({bucket_name})" + ) + return blob.updated + + def is_updated_after( + self, bucket_name: str, object_name: str, ts: datetime + ) -> bool: + """ + Checks if an blob_name is updated in Google Cloud Storage. + + :param bucket_name: The Google Cloud Storage bucket where the object is. + :type bucket_name: str + :param object_name: The name of the object to check in the Google cloud + storage bucket. + :type object_name: str + :param ts: The timestamp to check against. + :type ts: datetime.datetime + """ + blob_update_time = self.get_blob_update_time(bucket_name, object_name) + if blob_update_time is not None: + + if not ts.tzinfo: + ts = ts.replace(tzinfo=timezone.utc) + self.log.info("Verify object date: %s > %s", blob_update_time, ts) + if blob_update_time > ts: + return True + return False + + def is_updated_between( + self, bucket_name: str, object_name: str, min_ts: datetime, max_ts: datetime + ) -> bool: + """ + Checks if an blob_name is updated in Google Cloud Storage. + + :param bucket_name: The Google Cloud Storage bucket where the object is. + :type bucket_name: str + :param object_name: The name of the object to check in the Google cloud + storage bucket. + :type object_name: str + :param min_ts: The minimum timestamp to check against. + :type min_ts: datetime.datetime + :param max_ts: The maximum timestamp to check against. + :type max_ts: datetime.datetime + """ + blob_update_time = self.get_blob_update_time(bucket_name, object_name) + if blob_update_time is not None: + + if not min_ts.tzinfo: + min_ts = min_ts.replace(tzinfo=timezone.utc) + if not max_ts.tzinfo: + max_ts = max_ts.replace(tzinfo=timezone.utc) + self.log.info( + "Verify object date: %s is between %s and %s", + blob_update_time, + min_ts, + max_ts, + ) + if min_ts <= blob_update_time < max_ts: + return True + return False + + def is_updated_before( + self, bucket_name: str, object_name: str, ts: datetime + ) -> bool: + """ + Checks if an blob_name is updated before given time in Google Cloud Storage. + + :param bucket_name: The Google Cloud Storage bucket where the object is. + :type bucket_name: str + :param object_name: The name of the object to check in the Google cloud + storage bucket. + :type object_name: str + :param ts: The timestamp to check against. + :type ts: datetime.datetime + """ + blob_update_time = self.get_blob_update_time(bucket_name, object_name) + if blob_update_time is not None: + + if not ts.tzinfo: + ts = ts.replace(tzinfo=timezone.utc) + self.log.info("Verify object date: %s < %s", blob_update_time, ts) + if blob_update_time < ts: + return True + return False + + def is_older_than(self, bucket_name: str, object_name: str, seconds: int) -> bool: + """ + Check if object is older than given time + + :param bucket_name: The Google Cloud Storage bucket where the object is. 
+ :type bucket_name: str + :param object_name: The name of the object to check in the Google cloud + storage bucket. + :type object_name: str + :param seconds: The time in seconds to check against + :type seconds: int + """ + blob_update_time = self.get_blob_update_time(bucket_name, object_name) + if blob_update_time is not None: + from datetime import timedelta + + current_time = timezone.utcnow() + given_time = current_time - timedelta(seconds=seconds) + self.log.info( + "Verify object date: %s is older than %s", blob_update_time, given_time + ) + if blob_update_time < given_time: + return True + return False + + def delete(self, bucket_name: str, object_name: str) -> None: + """ + Deletes an object from the bucket. + + :param bucket_name: name of the bucket, where the object resides + :type bucket_name: str + :param object_name: name of the object to delete + :type object_name: str + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name=object_name) + blob.delete() + + self.log.info("Blob %s deleted.", object_name) + + def delete_bucket(self, bucket_name: str, force: bool = False) -> None: + """ + Delete a bucket object from the Google Cloud Storage. + + :param bucket_name: name of the bucket which will be deleted + :type bucket_name: str + :param force: false not allow to delete non empty bucket, set force=True + allows to delete non empty bucket + :type: bool + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + + self.log.info("Deleting %s bucket", bucket_name) + try: + bucket.delete(force=force) + self.log.info("Bucket %s has been deleted", bucket_name) + except NotFound: + self.log.info("Bucket %s not exists", bucket_name) + + def list( + self, bucket_name, versions=None, max_results=None, prefix=None, delimiter=None + ) -> list: + """ + List all objects from the bucket with the give string prefix in name + + :param bucket_name: bucket name + :type bucket_name: str + :param versions: if true, list all versions of the objects + :type versions: bool + :param max_results: max count of items to return in a single page of responses + :type max_results: int + :param prefix: prefix string which filters objects whose name begin with + this prefix + :type prefix: str + :param delimiter: filters objects based on the delimiter (for e.g '.csv') + :type delimiter: str + :return: a stream of object names matching the filtering criteria + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + + ids = [] + page_token = None + while True: + blobs = bucket.list_blobs( + max_results=max_results, + page_token=page_token, + prefix=prefix, + delimiter=delimiter, + versions=versions, + ) + + blob_names = [] + for blob in blobs: + blob_names.append(blob.name) + + prefixes = blobs.prefixes + if prefixes: + ids += list(prefixes) + else: + ids += blob_names + + page_token = blobs.next_page_token + if page_token is None: + # empty next page token + break + return ids + + def list_by_timespan( + self, + bucket_name: str, + timespan_start: datetime, + timespan_end: datetime, + versions: bool = None, + max_results: int = None, + prefix: str = None, + delimiter: str = None, + ) -> list: + """ + List all objects from the bucket with the give string prefix in name that were + updated in the time between ``timespan_start`` and ``timespan_end``. 
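Tying the existence and freshness checks together with list(), a sketch under the same assumptions (hypothetical bucket and prefix):

from datetime import datetime, timedelta, timezone

from airflow.providers.google.cloud.hooks.gcs import GCSHook

gcs_hook = GCSHook(gcp_conn_id="google_cloud_default")
cutoff = datetime.now(timezone.utc) - timedelta(hours=1)

if gcs_hook.exists("example-bucket", "reports/daily.csv") and gcs_hook.is_updated_after(
    "example-bucket", "reports/daily.csv", cutoff
):
    # list() returns object (or prefix) names, not Blob objects.
    names = gcs_hook.list("example-bucket", prefix="reports/", delimiter=".csv")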
+ + :param bucket_name: bucket name + :type bucket_name: str + :param timespan_start: will return objects that were updated at or after this datetime (UTC) + :type timespan_start: datetime + :param timespan_end: will return objects that were updated before this datetime (UTC) + :type timespan_end: datetime + :param versions: if true, list all versions of the objects + :type versions: bool + :param max_results: max count of items to return in a single page of responses + :type max_results: int + :param prefix: prefix string which filters objects whose name begin with + this prefix + :type prefix: str + :param delimiter: filters objects based on the delimiter (for e.g '.csv') + :type delimiter: str + :return: a stream of object names matching the filtering criteria + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + + ids = [] + page_token = None + + while True: + blobs = bucket.list_blobs( + max_results=max_results, + page_token=page_token, + prefix=prefix, + delimiter=delimiter, + versions=versions, + ) + + blob_names = [] + for blob in blobs: + if ( + timespan_start + <= blob.updated.replace(tzinfo=timezone.utc) + < timespan_end + ): + blob_names.append(blob.name) + + prefixes = blobs.prefixes + if prefixes: + ids += list(prefixes) + else: + ids += blob_names + + page_token = blobs.next_page_token + if page_token is None: + # empty next page token + break + return ids + + def get_size(self, bucket_name: str, object_name: str) -> int: + """ + Gets the size of a file in Google Cloud Storage. + + :param bucket_name: The Google Cloud Storage bucket where the blob_name is. + :type bucket_name: str + :param object_name: The name of the object to check in the Google + cloud storage bucket_name. + :type object_name: str + + """ + self.log.info( + "Checking the file size of object: %s in bucket_name: %s", + object_name, + bucket_name, + ) + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.get_blob(blob_name=object_name) + blob_size = blob.size + self.log.info("The file size of %s is %s bytes.", object_name, blob_size) + return blob_size + + def get_crc32c(self, bucket_name: str, object_name: str): + """ + Gets the CRC32c checksum of an object in Google Cloud Storage. + + :param bucket_name: The Google Cloud Storage bucket where the blob_name is. + :type bucket_name: str + :param object_name: The name of the object to check in the Google cloud + storage bucket_name. + :type object_name: str + """ + self.log.info( + "Retrieving the crc32c checksum of object_name: %s in bucket_name: %s", + object_name, + bucket_name, + ) + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.get_blob(blob_name=object_name) + blob_crc32c = blob.crc32c + self.log.info("The crc32c checksum of %s is %s", object_name, blob_crc32c) + return blob_crc32c + + def get_md5hash(self, bucket_name: str, object_name: str) -> str: + """ + Gets the MD5 hash of an object in Google Cloud Storage. + + :param bucket_name: The Google Cloud Storage bucket where the blob_name is. + :type bucket_name: str + :param object_name: The name of the object to check in the Google cloud + storage bucket_name. 
+ :type object_name: str
+ """
+ self.log.info(
+ "Retrieving the MD5 hash of object: %s in bucket: %s",
+ object_name,
+ bucket_name,
+ )
+ client = self.get_conn()
+ bucket = client.bucket(bucket_name)
+ blob = bucket.get_blob(blob_name=object_name)
+ blob_md5hash = blob.md5_hash
+ self.log.info("The md5Hash of %s is %s", object_name, blob_md5hash)
+ return blob_md5hash
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_bucket(
+ self,
+ bucket_name: str,
+ resource: Optional[dict] = None,
+ storage_class: str = "MULTI_REGIONAL",
+ location: str = "US",
+ project_id: Optional[str] = None,
+ labels: Optional[dict] = None,
+ ) -> str:
+ """
+ Creates a new bucket. Google Cloud Storage uses a flat namespace, so
+ you can't create a bucket with a name that is already in use.
+
+ .. seealso::
+ For more information, see Bucket Naming Guidelines:
+ https://cloud.google.com/storage/docs/bucketnaming.html#requirements
+
+ :param bucket_name: The name of the bucket.
+ :type bucket_name: str
+ :param resource: An optional dict with parameters for creating the bucket.
+ For information on available parameters, see Cloud Storage API doc:
+ https://cloud.google.com/storage/docs/json_api/v1/buckets/insert
+ :type resource: dict
+ :param storage_class: This defines how objects in the bucket are stored
+ and determines the SLA and the cost of storage. Values include
+
+ - ``MULTI_REGIONAL``
+ - ``REGIONAL``
+ - ``STANDARD``
+ - ``NEARLINE``
+ - ``COLDLINE``.
+
+ If this value is not specified when the bucket is
+ created, it will default to STANDARD.
+ :type storage_class: str
+ :param location: The location of the bucket.
+ Object data for objects in the bucket resides in physical storage
+ within this region. Defaults to US.
+
+ .. seealso::
+ https://developers.google.com/storage/docs/bucket-locations
+
+ :type location: str
+ :param project_id: The ID of the Google Cloud Project.
+ :type project_id: str
+ :param labels: User-provided labels, in key/value pairs.
+ :type labels: dict
+ :return: If successful, it returns the ``id`` of the bucket.
+ """
+ self.log.info(
+ "Creating Bucket: %s; Location: %s; Storage Class: %s",
+ bucket_name,
+ location,
+ storage_class,
+ )
+
+ # Add airflow-version label to the bucket
+ labels = labels or {}
+ labels["airflow-version"] = "v" + version.replace(".", "-").replace("+", "-")
+
+ client = self.get_conn()
+ bucket = client.bucket(bucket_name=bucket_name)
+ bucket_resource = resource or {}
+
+ for item in bucket_resource:
+ if item != "name":
+ bucket._patch_property( # pylint: disable=protected-access
+ name=item, value=resource[item] # type: ignore[index]
+ )
+
+ bucket.storage_class = storage_class
+ bucket.labels = labels
+ bucket.create(project=project_id, location=location)
+ return bucket.id
+
+ def insert_bucket_acl(
+ self,
+ bucket_name: str,
+ entity: str,
+ role: str,
+ user_project: Optional[str] = None,
+ ) -> None:
+ """
+ Creates a new ACL entry on the specified bucket_name.
+ See: https://cloud.google.com/storage/docs/json_api/v1/bucketAccessControls/insert
+
+ :param bucket_name: Name of a bucket_name.
+ :type bucket_name: str
+ :param entity: The entity holding the permission, in one of the following forms:
+ user-userId, user-email, group-groupId, group-email, domain-domain,
+ project-team-projectId, allUsers, allAuthenticatedUsers.
+ See: https://cloud.google.com/storage/docs/access-control/lists#scopes
+ :type entity: str
+ :param role: The access permission for the entity.
+ Acceptable values are: "OWNER", "READER", "WRITER".
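A sketch of create_bucket() plus insert_bucket_acl(), assuming the default project id is resolved from the connection via fallback_to_default_project_id (all names hypothetical):

from airflow.providers.google.cloud.hooks.gcs import GCSHook

gcs_hook = GCSHook(gcp_conn_id="google_cloud_default")

bucket_id = gcs_hook.create_bucket(
    bucket_name="example-new-bucket",
    storage_class="STANDARD",
    location="US",
    labels={"team": "data-eng"},  # an "airflow-version" label is appended automatically
)
gcs_hook.insert_bucket_acl(
    bucket_name="example-new-bucket",
    entity="user-analyst@example.com",
    role="READER",
)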
+ :type role: str + :param user_project: (Optional) The project to be billed for this request. + Required for Requester Pays buckets. + :type user_project: str + """ + self.log.info("Creating a new ACL entry in bucket: %s", bucket_name) + client = self.get_conn() + bucket = client.bucket(bucket_name=bucket_name) + bucket.acl.reload() + bucket.acl.entity_from_dict(entity_dict={"entity": entity, "role": role}) + if user_project: + bucket.acl.user_project = user_project + bucket.acl.save() + + self.log.info("A new ACL entry created in bucket: %s", bucket_name) + + def insert_object_acl( + self, + bucket_name: str, + object_name: str, + entity: str, + role: str, + generation: Optional[int] = None, + user_project: Optional[str] = None, + ) -> None: + """ + Creates a new ACL entry on the specified object. + See: https://cloud.google.com/storage/docs/json_api/v1/objectAccessControls/insert + + :param bucket_name: Name of a bucket_name. + :type bucket_name: str + :param object_name: Name of the object. For information about how to URL encode + object names to be path safe, see: + https://cloud.google.com/storage/docs/json_api/#encoding + :type object_name: str + :param entity: The entity holding the permission, in one of the following forms: + user-userId, user-email, group-groupId, group-email, domain-domain, + project-team-projectId, allUsers, allAuthenticatedUsers + See: https://cloud.google.com/storage/docs/access-control/lists#scopes + :type entity: str + :param role: The access permission for the entity. + Acceptable values are: "OWNER", "READER". + :type role: str + :param generation: Optional. If present, selects a specific revision of this object. + :type generation: long + :param user_project: (Optional) The project to be billed for this request. + Required for Requester Pays buckets. + :type user_project: str + """ + self.log.info( + "Creating a new ACL entry for object: %s in bucket: %s", + object_name, + bucket_name, + ) + client = self.get_conn() + bucket = client.bucket(bucket_name=bucket_name) + blob = bucket.blob(blob_name=object_name, generation=generation) + # Reload fetches the current ACL from Cloud Storage. + blob.acl.reload() + blob.acl.entity_from_dict(entity_dict={"entity": entity, "role": role}) + if user_project: + blob.acl.user_project = user_project + blob.acl.save() + + self.log.info( + "A new ACL entry created for object: %s in bucket: %s", + object_name, + bucket_name, + ) + + def compose( + self, bucket_name: str, source_objects: List, destination_object: str + ) -> None: + """ + Composes a list of existing object into a new object in the same storage bucket_name + + Currently it only supports up to 32 objects that can be concatenated + in a single operation + + https://cloud.google.com/storage/docs/json_api/v1/objects/compose + + :param bucket_name: The name of the bucket containing the source objects. + This is also the same bucket to store the composed destination object. + :type bucket_name: str + :param source_objects: The list of source objects that will be composed + into a single object. + :type source_objects: list + :param destination_object: The path of the object if given. 
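As the compose() docstring above describes, objects can be concatenated within one bucket (at most 32 sources per call); a hedged sketch with placeholder names:

from airflow.providers.google.cloud.hooks.gcs import GCSHook

gcs_hook = GCSHook(gcp_conn_id="google_cloud_default")
gcs_hook.compose(
    bucket_name="example-bucket",
    source_objects=["parts/part-0000.csv", "parts/part-0001.csv"],
    destination_object="merged/all_parts.csv",
)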
+ :type destination_object: str + """ + if not source_objects: + raise ValueError("source_objects cannot be empty.") + + if not bucket_name or not destination_object: + raise ValueError("bucket_name and destination_object cannot be empty.") + + self.log.info( + "Composing %s to %s in the bucket %s", + source_objects, + destination_object, + bucket_name, + ) + client = self.get_conn() + bucket = client.bucket(bucket_name) + destination_blob = bucket.blob(destination_object) + destination_blob.compose( + sources=[ + bucket.blob(blob_name=source_object) for source_object in source_objects + ] + ) + + self.log.info("Completed successfully.") + + def sync( + self, + source_bucket: str, + destination_bucket: str, + source_object: Optional[str] = None, + destination_object: Optional[str] = None, + recursive: bool = True, + allow_overwrite: bool = False, + delete_extra_files: bool = False, + ) -> None: + """ + Synchronizes the contents of the buckets. + + Parameters ``source_object`` and ``destination_object`` describe the root sync directories. If they + are not passed, the entire bucket will be synchronized. If they are passed, they should point + to directories. + + .. note:: + The synchronization of individual files is not supported. Only entire directories can be + synchronized. + + :param source_bucket: The name of the bucket containing the source objects. + :type source_bucket: str + :param destination_bucket: The name of the bucket containing the destination objects. + :type destination_bucket: str + :param source_object: The root sync directory in the source bucket. + :type source_object: Optional[str] + :param destination_object: The root sync directory in the destination bucket. + :type destination_object: Optional[str] + :param recursive: If True, subdirectories will be considered + :type recursive: bool + :param recursive: If True, subdirectories will be considered + :type recursive: bool + :param allow_overwrite: if True, the files will be overwritten if a mismatched file is found. + By default, overwriting files is not allowed + :type allow_overwrite: bool + :param delete_extra_files: if True, deletes additional files from the source that not found in the + destination. By default extra files are not deleted. + + .. note:: + This option can delete data quickly if you specify the wrong source/destination combination. + + :type delete_extra_files: bool + :return: none + """ + client = self.get_conn() + # Create bucket object + source_bucket_obj = client.bucket(source_bucket) + destination_bucket_obj = client.bucket(destination_bucket) + # Normalize parameters when they are passed + source_object = self._normalize_directory_path(source_object) + destination_object = self._normalize_directory_path(destination_object) + # Calculate the number of characters that remove from the name, because they contain information + # about the parent's path + source_object_prefix_len = len(source_object) if source_object else 0 + # Prepare synchronization plan + to_copy_blobs, to_delete_blobs, to_rewrite_blobs = self._prepare_sync_plan( + source_bucket=source_bucket_obj, + destination_bucket=destination_bucket_obj, + source_object=source_object, + destination_object=destination_object, + recursive=recursive, + ) + self.log.info( + "Planned synchronization. 
To delete blobs count: %s, to upload blobs count: %s, " + "to rewrite blobs count: %s", + len(to_delete_blobs), + len(to_copy_blobs), + len(to_rewrite_blobs), + ) + + # Copy missing object to new bucket + if not to_copy_blobs: + self.log.info("Skipped blobs copying.") + else: + for blob in to_copy_blobs: + dst_object = self._calculate_sync_destination_path( + blob, destination_object, source_object_prefix_len + ) + self.copy( + source_bucket=source_bucket_obj.name, + source_object=blob.name, + destination_bucket=destination_bucket_obj.name, + destination_object=dst_object, + ) + self.log.info("Blobs copied.") + # Delete redundant files + if not to_delete_blobs: + self.log.info("Skipped blobs deleting.") + elif delete_extra_files: + # TODO: Add batch. I tried to do it, but the Google library is not stable at the moment. + for blob in to_delete_blobs: + self.delete(blob.bucket.name, blob.name) + self.log.info("Blobs deleted.") + + # Overwrite files that are different + if not to_rewrite_blobs: + self.log.info("Skipped blobs overwriting.") + elif allow_overwrite: + for blob in to_rewrite_blobs: + dst_object = self._calculate_sync_destination_path( + blob, destination_object, source_object_prefix_len + ) + self.rewrite( + source_bucket=source_bucket_obj.name, + source_object=blob.name, + destination_bucket=destination_bucket_obj.name, + destination_object=dst_object, + ) + self.log.info("Blobs rewritten.") + + self.log.info("Synchronization finished.") + + def _calculate_sync_destination_path( + self, + blob: storage.Blob, + destination_object: Optional[str], + source_object_prefix_len: int, + ) -> str: + return ( + path.join(destination_object, blob.name[source_object_prefix_len:]) + if destination_object + else blob.name[source_object_prefix_len:] + ) + + def _normalize_directory_path(self, source_object: Optional[str]) -> Optional[str]: + return ( + source_object + "/" + if source_object and not source_object.endswith("/") + else source_object + ) + + @staticmethod + def _prepare_sync_plan( + source_bucket: storage.Bucket, + destination_bucket: storage.Bucket, + source_object: Optional[str], + destination_object: Optional[str], + recursive: bool, + ) -> Tuple[Set[storage.Blob], Set[storage.Blob], Set[storage.Blob]]: + # Calculate the number of characters that remove from the name, because they contain information + # about the parent's path + source_object_prefix_len = len(source_object) if source_object else 0 + destination_object_prefix_len = ( + len(destination_object) if destination_object else 0 + ) + delimiter = "/" if not recursive else None + # Fetch blobs list + source_blobs = list( + source_bucket.list_blobs(prefix=source_object, delimiter=delimiter) + ) + destination_blobs = list( + destination_bucket.list_blobs( + prefix=destination_object, delimiter=delimiter + ) + ) + # Create indexes that allow you to identify blobs based on their name + source_names_index = { + a.name[source_object_prefix_len:]: a for a in source_blobs + } + destination_names_index = { + a.name[destination_object_prefix_len:]: a for a in destination_blobs + } + # Create sets with names without parent object name + source_names = set(source_names_index.keys()) + destination_names = set(destination_names_index.keys()) + # Determine objects to copy and delete + to_copy = source_names - destination_names + to_delete = destination_names - source_names + to_copy_blobs = { + source_names_index[a] for a in to_copy + } # type: Set[storage.Blob] + to_delete_blobs = { + destination_names_index[a] for a in to_delete + } 
# type: Set[storage.Blob]
+ # Find names that are in both buckets
+ names_to_check = source_names.intersection(destination_names)
+ to_rewrite_blobs = set() # type: Set[storage.Blob]
+ # Compare objects based on crc32
+ for current_name in names_to_check:
+ source_blob = source_names_index[current_name]
+ destination_blob = destination_names_index[current_name]
+ # if the objects are different, save it
+ if source_blob.crc32c != destination_blob.crc32c:
+ to_rewrite_blobs.add(source_blob)
+ return to_copy_blobs, to_delete_blobs, to_rewrite_blobs
+
+
+def gcs_object_is_directory(bucket: str) -> bool:
+ """
+ Return True if given Google Cloud Storage URL (gs://<bucket>/<blob>)
+ is a directory or an empty bucket. Otherwise return False.
+ """
+ _, blob = _parse_gcs_url(bucket)
+
+ return len(blob) == 0 or blob.endswith("/")
+
+
+def _parse_gcs_url(gsurl: str) -> Tuple[str, str]:
+ """
+ Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a
+ tuple containing the corresponding bucket and blob.
+ """
+ parsed_url = urlparse(gsurl)
+ if not parsed_url.netloc:
+ raise AirflowException("Please provide a bucket name")
+ if parsed_url.scheme.lower() != "gs":
+ raise AirflowException(
+ f"Schema must be 'gs://': Current schema: '{parsed_url.scheme}://'"
+ )
+
+ bucket = parsed_url.netloc
+ # Remove leading '/' but NOT trailing one
+ blob = parsed_url.path.lstrip("/")
+ return bucket, blob
diff --git a/reference/providers/google/cloud/hooks/gdm.py b/reference/providers/google/cloud/hooks/gdm.py
new file mode 100644
index 0000000..cbb02ec
--- /dev/null
+++ b/reference/providers/google/cloud/hooks/gdm.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from googleapiclient.discovery import Resource, build
+
+
+class GoogleDeploymentManagerHook(GoogleBaseHook): # pylint: disable=abstract-method
+ """
+ Interact with Google Cloud Deployment Manager using the Google Cloud connection.
+ This allows for scheduled and programmatic inspection and deletion of resources managed by GDM.
+ """
+
+ def __init__(
+ self,
+ gcp_conn_id: str = "google_cloud_default",
+ delegate_to: Optional[str] = None,
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ ) -> None:
+ super().__init__(
+ gcp_conn_id=gcp_conn_id,
+ delegate_to=delegate_to,
+ impersonation_chain=impersonation_chain,
+ )
+
+ def get_conn(self) -> Resource:
+ """
+ Returns a Google Deployment Manager service object.
+ + :rtype: googleapiclient.discovery.Resource + """ + http_authorized = self._authorize() + return build( + "deploymentmanager", "v2", http=http_authorized, cache_discovery=False + ) + + @GoogleBaseHook.fallback_to_default_project_id + def list_deployments( + self, + project_id: Optional[str] = None, # pylint: disable=too-many-arguments + deployment_filter: Optional[str] = None, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """ + Lists deployments in a google cloud project. + + :param project_id: The project ID for this request. + :type project_id: str + :param deployment_filter: A filter expression which limits resources returned in the response. + :type deployment_filter: str + :param order_by: A field name to order by, ex: "creationTimestamp desc" + :type order_by: Optional[str] + :rtype: list + """ + deployments = [] # type: List[Dict] + conn = self.get_conn() + # pylint: disable=no-member + request = conn.deployments().list( + project=project_id, filter=deployment_filter, orderBy=order_by + ) + + while request is not None: + response = request.execute(num_retries=self.num_retries) + deployments.extend(response.get("deployments", [])) + request = conn.deployments().list_next( # pylint: disable=no-member + previous_request=request, previous_response=response + ) + + return deployments + + @GoogleBaseHook.fallback_to_default_project_id + def delete_deployment( + self, + project_id: Optional[str], + deployment: Optional[str] = None, + delete_policy: Optional[str] = None, + ) -> None: + """ + Deletes a deployment and all associated resources in a google cloud project. + + :param project_id: The project ID for this request. + :type project_id: str + :param deployment: The name of the deployment for this request. + :type deployment: str + :param delete_policy: Sets the policy to use for deleting resources. (ABANDON | DELETE) + :type delete_policy: string + + :rtype: None + """ + conn = self.get_conn() + # pylint: disable=no-member + request = conn.deployments().delete( + project=project_id, deployment=deployment, deletePolicy=delete_policy + ) + resp = request.execute() + if "error" in resp.keys(): + raise AirflowException( + "Errors deleting deployment: ", + ", ".join([err["message"] for err in resp["error"]["errors"]]), + ) diff --git a/reference/providers/google/cloud/hooks/kms.py b/reference/providers/google/cloud/hooks/kms.py new file mode 100644 index 0000000..a15875b --- /dev/null +++ b/reference/providers/google/cloud/hooks/kms.py @@ -0,0 +1,175 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
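Before the KMS hook, a usage sketch for the GoogleDeploymentManagerHook above, assuming the upstream provider import path; project and deployment names are placeholders:

from airflow.providers.google.cloud.hooks.gdm import GoogleDeploymentManagerHook

gdm_hook = GoogleDeploymentManagerHook(gcp_conn_id="google_cloud_default")
deployments = gdm_hook.list_deployments(
    project_id="example-project",
    order_by="creationTimestamp desc",
)
# ABANDON keeps the underlying resources; DELETE would remove them as well.
gdm_hook.delete_deployment(
    project_id="example-project",
    deployment="staging-stack",
    delete_policy="ABANDON",
)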
+# +"""This module contains a Google Cloud KMS hook""" + + +import base64 +from typing import Optional, Sequence, Tuple, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.retry import Retry +from google.cloud.kms_v1 import KeyManagementServiceClient + + +def _b64encode(s: bytes) -> str: + """Base 64 encodes a bytes object to a string""" + return base64.b64encode(s).decode("ascii") + + +def _b64decode(s: str) -> bytes: + """Base 64 decodes a string to bytes""" + return base64.b64decode(s.encode("utf-8")) + + +class CloudKMSHook(GoogleBaseHook): + """ + Hook for Google Cloud Key Management service. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._conn = None # type: Optional[KeyManagementServiceClient] + + def get_conn(self) -> KeyManagementServiceClient: + """ + Retrieves connection to Cloud Key Management service. + + :return: Cloud Key Management service object + :rtype: google.cloud.kms_v1.KeyManagementServiceClient + """ + if not self._conn: + self._conn = KeyManagementServiceClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._conn + + def encrypt( + self, + key_name: str, + plaintext: bytes, + authenticated_data: Optional[bytes] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> str: + """ + Encrypts a plaintext message using Google Cloud KMS. + + :param key_name: The Resource Name for the key (or key version) + to be used for encryption. Of the form + ``projects/*/locations/*/keyRings/*/cryptoKeys/**`` + :type key_name: str + :param plaintext: The message to be encrypted. + :type plaintext: bytes + :param authenticated_data: Optional additional authenticated data that + must also be provided to decrypt the message. + :type authenticated_data: bytes + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :return: The base 64 encoded ciphertext of the original message. + :rtype: str + """ + response = self.get_conn().encrypt( + request={ + "name": key_name, + "plaintext": plaintext, + "additional_authenticated_data": authenticated_data, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + ciphertext = _b64encode(response.ciphertext) + return ciphertext + + def decrypt( + self, + key_name: str, + ciphertext: str, + authenticated_data: Optional[bytes] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> bytes: + """ + Decrypts a ciphertext message using Google Cloud KMS. + + :param key_name: The Resource Name for the key to be used for decryption. + Of the form ``projects/*/locations/*/keyRings/*/cryptoKeys/**`` + :type key_name: str + :param ciphertext: The message to be decrypted. + :type ciphertext: str + :param authenticated_data: Any additional authenticated data that was + provided when encrypting the message. + :type authenticated_data: bytes + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :return: The original message. + :rtype: bytes + """ + response = self.get_conn().decrypt( + request={ + "name": key_name, + "ciphertext": _b64decode(ciphertext), + "additional_authenticated_data": authenticated_data, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + return response.plaintext diff --git a/reference/providers/google/cloud/hooks/kubernetes_engine.py b/reference/providers/google/cloud/hooks/kubernetes_engine.py new file mode 100644 index 0000000..1f675ce --- /dev/null +++ b/reference/providers/google/cloud/hooks/kubernetes_engine.py @@ -0,0 +1,311 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
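A round-trip sketch for CloudKMSHook's encrypt()/decrypt(), assuming the upstream import path and a placeholder key resource name (encrypt() returns base64 text, decrypt() returns bytes):

from airflow.providers.google.cloud.hooks.kms import CloudKMSHook

kms_hook = CloudKMSHook(gcp_conn_id="google_cloud_default")
key = "projects/example-project/locations/global/keyRings/example-ring/cryptoKeys/example-key"

ciphertext = kms_hook.encrypt(key_name=key, plaintext=b"db-password")
assert kms_hook.decrypt(key_name=key, ciphertext=ciphertext) == b"db-password"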
+# +"""This module contains a Google Kubernetes Engine Hook.""" + +import time +import warnings +from typing import Dict, Optional, Sequence, Union + +from airflow import version +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.exceptions import AlreadyExists, NotFound +from google.api_core.gapic_v1.method import DEFAULT +from google.api_core.retry import Retry +from google.cloud import container_v1, exceptions +from google.cloud.container_v1.gapic.enums import Operation +from google.cloud.container_v1.types import Cluster +from google.protobuf.json_format import ParseDict + +OPERATIONAL_POLL_INTERVAL = 15 + + +class GKEHook(GoogleBaseHook): + """ + Hook for Google Kubernetes Engine APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client = None + self.location = location + + def get_conn(self) -> container_v1.ClusterManagerClient: + """ + Returns ClusterManagerCLinet object. + + :rtype: google.cloud.container_v1.ClusterManagerClient + """ + if self._client is None: + credentials = self._get_credentials() + self._client = container_v1.ClusterManagerClient( + credentials=credentials, client_info=self.client_info + ) + return self._client + + # To preserve backward compatibility + # TODO: remove one day + def get_client( + self, + ) -> container_v1.ClusterManagerClient: # pylint: disable=missing-docstring + warnings.warn( + "The get_client method has been deprecated. 
" + "You should use the get_conn method.", + DeprecationWarning, + ) + return self.get_conn() + + def wait_for_operation( + self, operation: Operation, project_id: Optional[str] = None + ) -> Operation: + """ + Given an operation, continuously fetches the status from Google Cloud until either + completion or an error occurring + + :param operation: The Operation to wait for + :type operation: google.cloud.container_V1.gapic.enums.Operation + :param project_id: Google Cloud project ID + :type project_id: str + :return: A new, updated operation fetched from Google Cloud + """ + self.log.info("Waiting for OPERATION_NAME %s", operation.name) + time.sleep(OPERATIONAL_POLL_INTERVAL) + while operation.status != Operation.Status.DONE: + if ( + operation.status == Operation.Status.RUNNING + or operation.status == Operation.Status.PENDING + ): + time.sleep(OPERATIONAL_POLL_INTERVAL) + else: + raise exceptions.GoogleCloudError( + f"Operation has failed with status: {operation.status}" + ) + # To update status of operation + operation = self.get_operation( + operation.name, project_id=project_id or self.project_id + ) + return operation + + def get_operation( + self, operation_name: str, project_id: Optional[str] = None + ) -> Operation: + """ + Fetches the operation from Google Cloud + + :param operation_name: Name of operation to fetch + :type operation_name: str + :param project_id: Google Cloud project ID + :type project_id: str + :return: The new, updated operation from Google Cloud + """ + return self.get_conn().get_operation( + project_id=project_id or self.project_id, + zone=self.location, + operation_id=operation_name, + ) + + @staticmethod + def _append_label(cluster_proto: Cluster, key: str, val: str) -> Cluster: + """ + Append labels to provided Cluster Protobuf + + Labels must fit the regex ``[a-z]([-a-z0-9]*[a-z0-9])?`` (current + airflow version string follows semantic versioning spec: x.y.z). + + :param cluster_proto: The proto to append resource_label airflow + version to + :type cluster_proto: google.cloud.container_v1.types.Cluster + :param key: The key label + :type key: str + :param val: + :type val: str + :return: The cluster proto updated with new label + """ + val = val.replace(".", "-").replace("+", "-") + cluster_proto.resource_labels.update({key: val}) + return cluster_proto + + @GoogleBaseHook.fallback_to_default_project_id + def delete_cluster( + self, + name: str, + project_id: str, + retry: Retry = DEFAULT, + timeout: float = DEFAULT, + ) -> Optional[str]: + """ + Deletes the cluster, including the Kubernetes endpoint and all + worker nodes. Firewalls and routes that were configured during + cluster creation are also deleted. Other Google Compute Engine + resources that might be in use by the cluster (e.g. load balancer + resources) will not be deleted if they were not present at the + initial create time. + + :param name: The name of the cluster to delete + :type name: str + :param project_id: Google Cloud project ID + :type project_id: str + :param retry: Retry object used to determine when/if to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each + individual attempt. 
+ :type timeout: float + :return: The full url to the delete operation if successful, else None + """ + self.log.info( + "Deleting (project_id=%s, zone=%s, cluster_id=%s)", + project_id, + self.location, + name, + ) + + try: + resource = self.get_conn().delete_cluster( + project_id=project_id, + zone=self.location, + cluster_id=name, + retry=retry, + timeout=timeout, + ) + resource = self.wait_for_operation(resource) + # Returns server-defined url for the resource + return resource.self_link + except NotFound as error: + self.log.info("Assuming Success: %s", error.message) + return None + + @GoogleBaseHook.fallback_to_default_project_id + def create_cluster( + self, + cluster: Union[Dict, Cluster], + project_id: str, + retry: Retry = DEFAULT, + timeout: float = DEFAULT, + ) -> str: + """ + Creates a cluster, consisting of the specified number and type of Google Compute + Engine instances. + + :param cluster: A Cluster protobuf or dict. If dict is provided, it must + be of the same form as the protobuf message + :class:`google.cloud.container_v1.types.Cluster` + :type cluster: dict or google.cloud.container_v1.types.Cluster + :param project_id: Google Cloud project ID + :type project_id: str + :param retry: A retry object (``google.api_core.retry.Retry``) used to + retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :return: The full url to the new, or existing, cluster + :raises: + ParseError: On JSON parsing problems when trying to convert dict + AirflowException: cluster is not dict type nor Cluster proto type + """ + if isinstance(cluster, dict): + cluster_proto = Cluster() + cluster = ParseDict(cluster, cluster_proto) + elif not isinstance(cluster, Cluster): + raise AirflowException( + "cluster is not instance of Cluster proto or python dict" + ) + + self._append_label(cluster, "airflow-version", "v" + version.version) + + self.log.info( + "Creating (project_id=%s, zone=%s, cluster_name=%s)", + project_id, + self.location, + cluster.name, + ) + try: + resource = self.get_conn().create_cluster( + project_id=project_id, + zone=self.location, + cluster=cluster, + retry=retry, + timeout=timeout, + ) + resource = self.wait_for_operation(resource) + + return resource.target_link + except AlreadyExists as error: + self.log.info("Assuming Success: %s", error.message) + return self.get_cluster(name=cluster.name, project_id=project_id) + + @GoogleBaseHook.fallback_to_default_project_id + def get_cluster( + self, + name: str, + project_id: str, + retry: Retry = DEFAULT, + timeout: float = DEFAULT, + ) -> Cluster: + """ + Gets details of specified cluster + + :param name: The name of the cluster to retrieve + :type name: str + :param project_id: Google Cloud project ID + :type project_id: str + :param retry: A retry object used to retry requests. If None is specified, + requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each + individual attempt. 
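A hedged sketch of GKEHook's create/delete cycle; the dict form is parsed into a Cluster proto by create_cluster(), and the names, zone and project below are placeholders:

from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook

gke_hook = GKEHook(gcp_conn_id="google_cloud_default", location="us-central1-a")
cluster_url = gke_hook.create_cluster(
    cluster={"name": "example-cluster", "initial_node_count": 1},
    project_id="example-project",
)
# ... run workloads against the cluster ...
gke_hook.delete_cluster(name="example-cluster", project_id="example-project")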
+ :type timeout: float + :return: google.cloud.container_v1.types.Cluster + """ + self.log.info( + "Fetching cluster (project_id=%s, zone=%s, cluster_name=%s)", + project_id or self.project_id, + self.location, + name, + ) + + return ( + self.get_conn() + .get_cluster( + project_id=project_id, + zone=self.location, + cluster_id=name, + retry=retry, + timeout=timeout, + ) + .self_link + ) diff --git a/reference/providers/google/cloud/hooks/life_sciences.py b/reference/providers/google/cloud/hooks/life_sciences.py new file mode 100644 index 0000000..57ded89 --- /dev/null +++ b/reference/providers/google/cloud/hooks/life_sciences.py @@ -0,0 +1,169 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Google Cloud Life Sciences service""" + +import time +from typing import Any, Optional, Sequence, Union + +import google.api_core.path_template +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import build + +# Time to sleep between active checks of the operation results +TIME_TO_SLEEP_IN_SECONDS = 5 + + +class LifeSciencesHook(GoogleBaseHook): + """ + Hook for the Google Cloud Life Sciences APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param api_version: API version used (for example v1 or v1beta1). + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + _conn = None # type: Optional[Any] + + def __init__( + self, + api_version: str = "v2beta", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.api_version = api_version + + def get_conn(self) -> build: + """ + Retrieves the connection to Cloud Life Sciences. + + :return: Google Cloud Life Sciences service object. + """ + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + "lifesciences", + self.api_version, + http=http_authorized, + cache_discovery=False, + ) + return self._conn + + @GoogleBaseHook.fallback_to_default_project_id + def run_pipeline(self, body: dict, location: str, project_id: str) -> dict: + """ + Runs a pipeline + + :param body: The request body. + :type body: dict + :param location: The location of the project. For example: "us-east1". + :type location: str + :param project_id: Optional, Google Cloud Project project_id where the function belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :rtype: dict + """ + parent = self._location_path(project_id=project_id, location=location) + service = self.get_conn() + + request = ( + service.projects() # pylint: disable=no-member + .locations() + .pipelines() + .run(parent=parent, body=body) + ) + + response = request.execute(num_retries=self.num_retries) + + # wait + operation_name = response["name"] + self._wait_for_operation_to_complete(operation_name) + + return response + + @GoogleBaseHook.fallback_to_default_project_id + def _location_path(self, project_id: str, location: str) -> str: + """ + Return a location string. + + :param project_id: Optional, Google Cloud Project project_id where the + function belongs. If set to None or missing, the default project_id + from the Google Cloud connection is used. + :type project_id: str + :param location: The location of the project. For example: "us-east1". + :type location: str + """ + return google.api_core.path_template.expand( + "projects/{project}/locations/{location}", + project=project_id, + location=location, + ) + + def _wait_for_operation_to_complete(self, operation_name: str) -> None: + """ + Waits for the named operation to complete - checks status of the + asynchronous call. + + :param operation_name: The name of the operation. + :type operation_name: str + :return: The response returned by the operation. + :rtype: dict + :exception: AirflowException in case error is returned. 
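A sketch of run_pipeline(); the request body below is an assumed, minimal example of the v2beta pipelines.run payload, and the project and location are placeholders:

from airflow.providers.google.cloud.hooks.life_sciences import LifeSciencesHook

ls_hook = LifeSciencesHook(gcp_conn_id="google_cloud_default")
# run_pipeline() polls the returned operation and only returns once it reports done=True.
response = ls_hook.run_pipeline(
    body={"pipeline": {"actions": [{"imageUri": "bash", "commands": ["-c", "echo hello"]}]}},
    location="us-central1",
    project_id="example-project",
)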
+ """ + service = self.get_conn() + while True: + operation_response = ( + service.projects() # pylint: disable=no-member + .locations() + .operations() + .get(name=operation_name) + .execute(num_retries=self.num_retries) + ) + self.log.info("Waiting for pipeline operation to complete") + if operation_response.get("done"): + response = operation_response.get("response") + error = operation_response.get("error") + # Note, according to documentation always either response or error is + # set when "done" == True + if error: + raise AirflowException(str(error)) + return response + time.sleep(TIME_TO_SLEEP_IN_SECONDS) diff --git a/reference/providers/google/cloud/hooks/mlengine.py b/reference/providers/google/cloud/hooks/mlengine.py new file mode 100644 index 0000000..398ea4c --- /dev/null +++ b/reference/providers/google/cloud/hooks/mlengine.py @@ -0,0 +1,592 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google ML Engine Hook.""" +import logging +import random +import time +from typing import Callable, Dict, List, Optional + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.version import version as airflow_version +from googleapiclient.discovery import Resource, build +from googleapiclient.errors import HttpError + +log = logging.getLogger(__name__) + +_AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-") + + +def _poll_with_exponential_delay( + request, execute_num_retries, max_n, is_done_func, is_error_func +): + """ + Execute request with exponential delay. + + This method is intended to handle and retry in case of api-specific errors, + such as 429 "Too Many Requests", unlike the `request.execute` which handles + lower level errors like `ConnectionError`/`socket.timeout`/`ssl.SSLError`. + + :param request: request to be executed. + :type request: googleapiclient.http.HttpRequest + :param execute_num_retries: num_retries for `request.execute` method. + :type execute_num_retries: int + :param max_n: number of times to retry request in this method. + :type max_n: int + :param is_done_func: callable to determine if operation is done. + :type is_done_func: callable + :param is_error_func: callable to determine if operation is failed. + :type is_error_func: callable + :return: response + :rtype: httplib2.Response + """ + for i in range(0, max_n): + try: + response = request.execute(num_retries=execute_num_retries) + if is_error_func(response): + raise ValueError(f"The response contained an error: {response}") + if is_done_func(response): + log.info("Operation is done: %s", response) + return response + + time.sleep((2 ** i) + (random.randint(0, 1000) / 1000)) + except HttpError as e: + if e.resp.status != 429: + log.info("Something went wrong. 
Not retrying: %s", format(e))
+                raise
+            else:
+                time.sleep((2 ** i) + (random.randint(0, 1000) / 1000))
+
+    raise ValueError(f"Connection could not be established after {max_n} retries.")
+
+
+class MLEngineHook(GoogleBaseHook):
+    """
+    Hook for Google ML Engine APIs.
+
+    All the methods in the hook where project_id is used must be called with
+    keyword arguments rather than positional.
+    """
+
+    def get_conn(self) -> Resource:
+        """
+        Retrieves the connection to MLEngine.
+
+        :return: Google MLEngine services object.
+        """
+        authed_http = self._authorize()
+        return build("ml", "v1", http=authed_http, cache_discovery=False)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_job(
+        self, job: dict, project_id: str, use_existing_job_fn: Optional[Callable] = None
+    ) -> dict:
+        """
+        Launches an MLEngine job and waits for it to reach a terminal state.
+
+        :param project_id: The Google Cloud project id within which the MLEngine
+            job will be launched. If set to None or missing, the default project_id from the Google Cloud
+            connection is used.
+        :type project_id: str
+        :param job: MLEngine Job object that should be provided to the MLEngine
+            API, such as: ::
+
+                {
+                    'jobId': 'my_job_id',
+                    'trainingInput': {
+                        'scaleTier': 'STANDARD_1',
+                        ...
+                    }
+                }
+
+        :type job: dict
+        :param use_existing_job_fn: In case an MLEngine job with the same
+            job_id already exists, this callable (if provided) decides whether
+            we should reuse the existing job, continue waiting for it to finish
+            and return the job object. It should accept an MLEngine job
+            object and return a boolean value indicating whether it is OK to
+            reuse the existing job. If 'use_existing_job_fn' is not provided,
+            the existing MLEngine job is reused by default.
+        :type use_existing_job_fn: function
+        :return: The MLEngine job object if the job successfully reaches a
+            terminal state (which might be the FAILED or CANCELLED state).
+        :rtype: dict
+        """
+        hook = self.get_conn()
+
+        self._append_label(job)
+        self.log.info("Creating job.")
+        # pylint: disable=no-member
+        request = (
+            hook.projects().jobs().create(parent=f"projects/{project_id}", body=job)
+        )
+        job_id = job["jobId"]
+
+        try:
+            request.execute(num_retries=self.num_retries)
+        except HttpError as e:
+            # 409 means there is an existing job with the same job ID.
+            if e.resp.status == 409:
+                if use_existing_job_fn is not None:
+                    existing_job = self._get_job(project_id, job_id)
+                    if not use_existing_job_fn(existing_job):
+                        self.log.error(
+                            "Job with job_id %s already exists, but it does not match our expectation: %s",
+                            job_id,
+                            existing_job,
+                        )
+                        raise
+                self.log.info(
+                    "Job with job_id %s already exists. Waiting for it to finish",
+                    job_id,
+                )
+            else:
+                self.log.error("Failed to create MLEngine job: %s", e)
+                raise
+
+        return self._wait_for_job_done(project_id, job_id)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def cancel_job(
+        self,
+        job_id: str,
+        project_id: str,
+    ) -> dict:
+        """
+        Cancels an MLEngine job.
+
+        :param project_id: The Google Cloud project id within which the MLEngine
+            job will be cancelled. If set to None or missing, the default project_id from the Google Cloud
+            connection is used.
+        :type project_id: str
+        :param job_id: A unique id of the Google MLEngine training job to cancel.
+ :type job_id: str + + :return: Empty dict if cancelled successfully + :rtype: dict + :raises: googleapiclient.errors.HttpError + """ + hook = self.get_conn() + # pylint: disable=no-member + request = ( + hook.projects().jobs().cancel(name=f"projects/{project_id}/jobs/{job_id}") + ) + + try: + return request.execute(num_retries=self.num_retries) + except HttpError as e: + if e.resp.status == 404: + self.log.error("Job with job_id %s does not exist. ", job_id) + raise + elif e.resp.status == 400: + self.log.info( + "Job with job_id %s is already complete, cancellation aborted.", + job_id, + ) + return {} + else: + self.log.error("Failed to cancel MLEngine job: %s", e) + raise + + def _get_job(self, project_id: str, job_id: str) -> dict: + """ + Gets a MLEngine job based on the job id. + + :param project_id: The project in which the Job is located. If set to None or missing, the default + project_id from the Google Cloud connection is used. (templated) + :type project_id: str + :param job_id: A unique id for the Google MLEngine job. (templated) + :type job_id: str + :return: MLEngine job object if succeed. + :rtype: dict + :raises: googleapiclient.errors.HttpError + """ + hook = self.get_conn() + job_name = f"projects/{project_id}/jobs/{job_id}" + request = hook.projects().jobs().get(name=job_name) # pylint: disable=no-member + while True: + try: + return request.execute(num_retries=self.num_retries) + except HttpError as e: + if e.resp.status == 429: + # polling after 30 seconds when quota failure occurs + time.sleep(30) + else: + self.log.error("Failed to get MLEngine job: %s", e) + raise + + def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30): + """ + Waits for the Job to reach a terminal state. + + This method will periodically check the job state until the job reach + a terminal state. + + :param project_id: The project in which the Job is located. If set to None or missing, the default + project_id from the Google Cloud connection is used. (templated) + :type project_id: str + :param job_id: A unique id for the Google MLEngine job. (templated) + :type job_id: str + :param interval: Time expressed in seconds after which the job status is checked again. (templated) + :type interval: int + :raises: googleapiclient.errors.HttpError + """ + self.log.info("Waiting for job. job_id=%s", job_id) + + if interval <= 0: + raise ValueError("Interval must be > 0") + while True: + job = self._get_job(project_id, job_id) + if job["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]: + return job + time.sleep(interval) + + @GoogleBaseHook.fallback_to_default_project_id + def create_version( + self, + model_name: str, + version_spec: Dict, + project_id: str, + ) -> dict: + """ + Creates the Version on Google Cloud ML Engine. + + :param version_spec: A dictionary containing the information about the version. (templated) + :type version_spec: dict + :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. + (templated) + :type model_name: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :return: If the version was created successfully, returns the operation. + Otherwise raises an error . 
+ :rtype: dict + """ + hook = self.get_conn() + parent_name = f"projects/{project_id}/models/{model_name}" + + self._append_label(version_spec) + + # pylint: disable=no-member + create_request = ( + hook.projects() + .models() + .versions() + .create(parent=parent_name, body=version_spec) + ) + response = create_request.execute(num_retries=self.num_retries) + get_request = ( + hook.projects().operations().get(name=response["name"]) + ) # pylint: disable=no-member + + return _poll_with_exponential_delay( + request=get_request, + execute_num_retries=self.num_retries, + max_n=9, + is_done_func=lambda resp: resp.get("done", False), + is_error_func=lambda resp: resp.get("error", None) is not None, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def set_default_version( + self, + model_name: str, + version_name: str, + project_id: str, + ) -> dict: + """ + Sets a version to be the default. Blocks until finished. + + :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. + (templated) + :type model_name: str + :param version_name: A name to use for the version being operated upon. (templated) + :type version_name: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None + or missing, the default project_id from the Google Cloud connection is used. (templated) + :type project_id: str + :return: If successful, return an instance of Version. + Otherwise raises an error. + :rtype: dict + :raises: googleapiclient.errors.HttpError + """ + hook = self.get_conn() + full_version_name = ( + f"projects/{project_id}/models/{model_name}/versions/{version_name}" + ) + # pylint: disable=no-member + request = ( + hook.projects() + .models() + .versions() + .setDefault(name=full_version_name, body={}) + ) + + try: + response = request.execute(num_retries=self.num_retries) + self.log.info("Successfully set version: %s to default", response) + return response + except HttpError as e: + self.log.error("Something went wrong: %s", e) + raise + + @GoogleBaseHook.fallback_to_default_project_id + def list_versions( + self, + model_name: str, + project_id: str, + ) -> List[dict]: + """ + Lists all available versions of a model. Blocks until finished. + + :param model_name: The name of the Google Cloud ML Engine model that the version + belongs to. (templated) + :type model_name: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or + missing, the default project_id from the Google Cloud connection is used. (templated) + :type project_id: str + :return: return an list of instance of Version. + :rtype: List[Dict] + :raises: googleapiclient.errors.HttpError + """ + hook = self.get_conn() + result = [] # type: List[Dict] + full_parent_name = f"projects/{project_id}/models/{model_name}" + # pylint: disable=no-member + request = ( + hook.projects() + .models() + .versions() + .list(parent=full_parent_name, pageSize=100) + ) + + while request is not None: + response = request.execute(num_retries=self.num_retries) + result.extend(response.get("versions", [])) + # pylint: disable=no-member + request = ( + hook.projects() + .models() + .versions() + .list_next(previous_request=request, previous_response=response) + ) + time.sleep(5) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_version( + self, + model_name: str, + version_name: str, + project_id: str, + ) -> dict: + """ + Deletes the given version of a model. Blocks until finished. 
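+
+        A call sketch (the model, version, and project names below are
+        placeholders): ::
+
+            hook.delete_version(model_name="my_model", version_name="v1",
+                                project_id="example-project")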
+ + :param model_name: The name of the Google Cloud ML Engine model that the version + belongs to. (templated) + :type model_name: str + :param project_id: The Google Cloud project name to which MLEngine + model belongs. + :type project_id: str + :return: If the version was deleted successfully, returns the operation. + Otherwise raises an error. + :rtype: Dict + """ + hook = self.get_conn() + full_name = f"projects/{project_id}/models/{model_name}/versions/{version_name}" + delete_request = ( + hook.projects() + .models() + .versions() + .delete(name=full_name) # pylint: disable=no-member + ) + response = delete_request.execute(num_retries=self.num_retries) + get_request = ( + hook.projects().operations().get(name=response["name"]) + ) # pylint: disable=no-member + + return _poll_with_exponential_delay( + request=get_request, + execute_num_retries=self.num_retries, + max_n=9, + is_done_func=lambda resp: resp.get("done", False), + is_error_func=lambda resp: resp.get("error", None) is not None, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_model( + self, + model: dict, + project_id: str, + ) -> dict: + """ + Create a Model. Blocks until finished. + + :param model: A dictionary containing the information about the model. + :type model: dict + :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or + missing, the default project_id from the Google Cloud connection is used. (templated) + :type project_id: str + :return: If the version was created successfully, returns the instance of Model. + Otherwise raises an error. + :rtype: Dict + :raises: googleapiclient.errors.HttpError + """ + hook = self.get_conn() + if "name" not in model or not model["name"]: + raise ValueError( + "Model name must be provided and " "could not be an empty string" + ) + project = f"projects/{project_id}" + + self._append_label(model) + try: + request = ( + hook.projects().models().create(parent=project, body=model) + ) # pylint: disable=no-member + response = request.execute(num_retries=self.num_retries) + except HttpError as e: + if e.resp.status != 409: + raise e + str(e) # Fills in the error_details field + if not e.error_details or len(e.error_details) != 1: + raise e + + error_detail = e.error_details[0] + if error_detail["@type"] != "type.googleapis.com/google.rpc.BadRequest": + raise e + + if ( + "fieldViolations" not in error_detail + or len(error_detail["fieldViolations"]) != 1 + ): + raise e + + field_violation = error_detail["fieldViolations"][0] + if ( + field_violation["field"] != "model.name" + or field_violation["description"] + != "A model with the same name already exists." + ): + raise e + response = self.get_model(model_name=model["name"], project_id=project_id) + + return response + + @GoogleBaseHook.fallback_to_default_project_id + def get_model( + self, + model_name: str, + project_id: str, + ) -> Optional[dict]: + """ + Gets a Model. Blocks until finished. + + :param model_name: The name of the model. + :type model_name: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None + or missing, the default project_id from the Google Cloud connection is used. (templated) + :type project_id: str + :return: If the model exists, returns the instance of Model. + Otherwise return None. 
+ :rtype: Dict + :raises: googleapiclient.errors.HttpError + """ + hook = self.get_conn() + if not model_name: + raise ValueError( + "Model name must be provided and " "it could not be an empty string" + ) + full_model_name = f"projects/{project_id}/models/{model_name}" + request = ( + hook.projects().models().get(name=full_model_name) + ) # pylint: disable=no-member + try: + return request.execute(num_retries=self.num_retries) + except HttpError as e: + if e.resp.status == 404: + self.log.error("Model was not found: %s", e) + return None + raise + + @GoogleBaseHook.fallback_to_default_project_id + def delete_model( + self, + model_name: str, + project_id: str, + delete_contents: bool = False, + ) -> None: + """ + Delete a Model. Blocks until finished. + + :param model_name: The name of the model. + :type model_name: str + :param delete_contents: Whether to force the deletion even if the models is not empty. + Will delete all version (if any) in the dataset if set to True. + The default value is False. + :type delete_contents: bool + :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None + or missing, the default project_id from the Google Cloud connection is used. (templated) + :type project_id: str + :raises: googleapiclient.errors.HttpError + """ + hook = self.get_conn() + + if not model_name: + raise ValueError( + "Model name must be provided and it could not be an empty string" + ) + model_path = f"projects/{project_id}/models/{model_name}" + if delete_contents: + self._delete_all_versions(model_name, project_id) + request = ( + hook.projects().models().delete(name=model_path) + ) # pylint: disable=no-member + try: + request.execute(num_retries=self.num_retries) + except HttpError as e: + if e.resp.status == 404: + self.log.error("Model was not found: %s", e) + return + raise + + def _delete_all_versions(self, model_name: str, project_id: str): + versions = self.list_versions(project_id=project_id, model_name=model_name) + # The default version can only be deleted when it is the last one in the model + non_default_versions = ( + version for version in versions if not version.get("isDefault", False) + ) + for version in non_default_versions: + _, _, version_name = version["name"].rpartition("/") + self.delete_version( + project_id=project_id, model_name=model_name, version_name=version_name + ) + default_versions = ( + version for version in versions if version.get("isDefault", False) + ) + for version in default_versions: + _, _, version_name = version["name"].rpartition("/") + self.delete_version( + project_id=project_id, model_name=model_name, version_name=version_name + ) + + def _append_label(self, model: dict) -> None: + model["labels"] = model.get("labels", {}) + model["labels"]["airflow-version"] = _AIRFLOW_VERSION diff --git a/reference/providers/google/cloud/hooks/natural_language.py b/reference/providers/google/cloud/hooks/natural_language.py new file mode 100644 index 0000000..6e58b92 --- /dev/null +++ b/reference/providers/google/cloud/hooks/natural_language.py @@ -0,0 +1,305 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Natural Language Hook.""" +from typing import Optional, Sequence, Tuple, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.retry import Retry +from google.cloud.language_v1 import LanguageServiceClient, enums +from google.cloud.language_v1.types import ( + AnalyzeEntitiesResponse, + AnalyzeEntitySentimentResponse, + AnalyzeSentimentResponse, + AnalyzeSyntaxResponse, + AnnotateTextRequest, + AnnotateTextResponse, + ClassifyTextResponse, + Document, +) + + +class CloudNaturalLanguageHook(GoogleBaseHook): + """ + Hook for Google Cloud Natural Language Service. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._conn = None + + def get_conn(self) -> LanguageServiceClient: + """ + Retrieves connection to Cloud Natural Language service. + + :return: Cloud Natural Language service object + :rtype: google.cloud.language_v1.LanguageServiceClient + """ + if not self._conn: + self._conn = LanguageServiceClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._conn + + @GoogleBaseHook.quota_retry() + def analyze_entities( + self, + document: Union[dict, Document], + encoding_type: Optional[enums.EncodingType] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> AnalyzeEntitiesResponse: + """ + Finds named entities in the text along with entity types, + salience, mentions for each entity, and other properties. + + :param document: Input document. + If a dict is provided, it must be of the same form as the protobuf message Document + :type document: dict or google.cloud.language_v1.types.Document + :param encoding_type: The encoding type used by the API to calculate offsets. 
+ :type encoding_type: google.cloud.language_v1.enums.EncodingType + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.language_v1.types.AnalyzeEntitiesResponse + """ + client = self.get_conn() + + return client.analyze_entities( + document=document, + encoding_type=encoding_type, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.quota_retry() + def analyze_entity_sentiment( + self, + document: Union[dict, Document], + encoding_type: Optional[enums.EncodingType] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> AnalyzeEntitySentimentResponse: + """ + Finds entities, similar to AnalyzeEntities in the text and analyzes sentiment associated with each + entity and its mentions. + + :param document: Input document. + If a dict is provided, it must be of the same form as the protobuf message Document + :type document: dict or google.cloud.language_v1.types.Document + :param encoding_type: The encoding type used by the API to calculate offsets. + :type encoding_type: google.cloud.language_v1.enums.EncodingType + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.language_v1.types.AnalyzeEntitiesResponse + """ + client = self.get_conn() + + return client.analyze_entity_sentiment( + document=document, + encoding_type=encoding_type, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.quota_retry() + def analyze_sentiment( + self, + document: Union[dict, Document], + encoding_type: Optional[enums.EncodingType] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> AnalyzeSentimentResponse: + """ + Analyzes the sentiment of the provided text. + + :param document: Input document. + If a dict is provided, it must be of the same form as the protobuf message Document + :type document: dict or google.cloud.language_v1.types.Document + :param encoding_type: The encoding type used by the API to calculate offsets. + :type encoding_type: google.cloud.language_v1.enums.EncodingType + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.language_v1.types.AnalyzeSentimentResponse + """ + client = self.get_conn() + + return client.analyze_sentiment( + document=document, + encoding_type=encoding_type, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.quota_retry() + def analyze_syntax( + self, + document: Union[dict, Document], + encoding_type: Optional[enums.EncodingType] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> AnalyzeSyntaxResponse: + """ + Analyzes the syntax of the text and provides sentence boundaries and tokenization along with part + of speech tags, dependency trees, and other properties. + + :param document: Input document. + If a dict is provided, it must be of the same form as the protobuf message Document + :type document: dict or google.cloud.language_v1.types.Document + :param encoding_type: The encoding type used by the API to calculate offsets. + :type encoding_type: google.cloud.language_v1.enums.EncodingType + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.language_v1.types.AnalyzeSyntaxResponse + """ + client = self.get_conn() + + return client.analyze_syntax( + document=document, + encoding_type=encoding_type, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.quota_retry() + def annotate_text( + self, + document: Union[dict, Document], + features: Union[dict, AnnotateTextRequest.Features], + encoding_type: enums.EncodingType = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> AnnotateTextResponse: + """ + A convenience method that provides all the features that analyzeSentiment, + analyzeEntities, and analyzeSyntax provide in one call. + + :param document: Input document. + If a dict is provided, it must be of the same form as the protobuf message Document + :type document: dict or google.cloud.language_v1.types.Document + :param features: The enabled features. + If a dict is provided, it must be of the same form as the protobuf message Features + :type features: dict or google.cloud.language_v1.types.AnnotateTextRequest.Features + :param encoding_type: The encoding type used by the API to calculate offsets. + :type encoding_type: google.cloud.language_v1.enums.EncodingType + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.language_v1.types.AnnotateTextResponse + """ + client = self.get_conn() + + return client.annotate_text( + document=document, + features=features, + encoding_type=encoding_type, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.quota_retry() + def classify_text( + self, + document: Union[dict, Document], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ClassifyTextResponse: + """ + Classifies a document into categories. + + :param document: Input document. + If a dict is provided, it must be of the same form as the protobuf message Document + :type document: dict or google.cloud.language_v1.types.Document + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.language_v1.types.ClassifyTextResponse + """ + client = self.get_conn() + + return client.classify_text( + document=document, retry=retry, timeout=timeout, metadata=metadata + ) diff --git a/reference/providers/google/cloud/hooks/os_login.py b/reference/providers/google/cloud/hooks/os_login.py new file mode 100644 index 0000000..2c49d6d --- /dev/null +++ b/reference/providers/google/cloud/hooks/os_login.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, Optional, Sequence, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.cloud.oslogin_v1 import ImportSshPublicKeyResponse, OsLoginServiceClient + + +class OSLoginHook(GoogleBaseHook): + """ + Hook for Google OS login APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. 
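+
+    A minimal usage sketch (the user, key, and project values below are
+    placeholders, not values defined in this module): ::
+
+        hook = OSLoginHook(gcp_conn_id="google_cloud_default")
+        hook.import_ssh_public_key(
+            user="user@example.com",
+            ssh_public_key={"key": "ssh-rsa AAAA... user@example.com"},
+            project_id="example-project",
+        )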
+ """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._conn = None # type: Optional[OsLoginServiceClient] + + def get_conn(self) -> OsLoginServiceClient: + """Return OS Login service client""" + if self._conn: + return self._conn + + self._conn = OsLoginServiceClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._conn + + @GoogleBaseHook.fallback_to_default_project_id + def import_ssh_public_key( + self, + user: str, + ssh_public_key: Dict, + project_id: str, + retry=None, + timeout=None, + metadata=None, + ) -> ImportSshPublicKeyResponse: + """ + Adds an SSH public key and returns the profile information. Default POSIX + account information is set when no username and UID exist as part of the + login profile. + + :param user: The unique ID for the user + :type user: str + :param ssh_public_key: The SSH public key and expiration time. + :type ssh_public_key: dict + :param project_id: The project ID of the Google Cloud project. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will + be retried using a default configuration. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that + if ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :return: A :class:`~google.cloud.oslogin_v1.ImportSshPublicKeyResponse` instance. + """ + conn = self.get_conn() + return conn.import_ssh_public_key( + request=dict( + parent=f"users/{user}", + ssh_public_key=ssh_public_key, + project_id=project_id, + ), + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) diff --git a/reference/providers/google/cloud/hooks/pubsub.py b/reference/providers/google/cloud/hooks/pubsub.py new file mode 100644 index 0000000..27df47b --- /dev/null +++ b/reference/providers/google/cloud/hooks/pubsub.py @@ -0,0 +1,675 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains a Google Pub/Sub Hook.""" +import warnings +from base64 import b64decode +from typing import Dict, List, Optional, Sequence, Tuple, Union +from uuid import uuid4 + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.version import version +from google.api_core.exceptions import AlreadyExists, GoogleAPICallError +from google.api_core.retry import Retry +from google.cloud.exceptions import NotFound +from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient +from google.cloud.pubsub_v1.types import ( + DeadLetterPolicy, + Duration, + ExpirationPolicy, + MessageStoragePolicy, + PushConfig, + ReceivedMessage, + RetryPolicy, +) +from googleapiclient.errors import HttpError + + +class PubSubException(Exception): + """Alias for Exception.""" + + +class PubSubHook(GoogleBaseHook): + """ + Hook for accessing Google Pub/Sub. + + The Google Cloud project against which actions are applied is determined by + the project embedded in the Connection referenced by gcp_conn_id. + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client = None + + def get_conn(self) -> PublisherClient: + """ + Retrieves connection to Google Cloud Pub/Sub. + + :return: Google Cloud Pub/Sub client object. + :rtype: google.cloud.pubsub_v1.PublisherClient + """ + if not self._client: + self._client = PublisherClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._client + + @cached_property + def subscriber_client(self) -> SubscriberClient: + """ + Creates SubscriberClient. + + :return: Google Cloud Pub/Sub client object. + :rtype: google.cloud.pubsub_v1.SubscriberClient + """ + return SubscriberClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + + @GoogleBaseHook.fallback_to_default_project_id + def publish( + self, + topic: str, + messages: List[dict], + project_id: str, + ) -> None: + """ + Publishes messages to a Pub/Sub topic. + + :param topic: the Pub/Sub topic to which to publish; do not + include the ``projects/{project}/topics/`` prefix. + :type topic: str + :param messages: messages to publish; if the data field in a + message is set, it should be a bytestring (utf-8 encoded) + :type messages: list of PubSub messages; see + http://cloud.google.com/pubsub/docs/reference/rest/v1/PubsubMessage + :param project_id: Optional, the Google Cloud project ID in which to publish. + If set to None or missing, the default project_id from the Google Cloud connection is used. 
+ :type project_id: str + """ + self._validate_messages(messages) + + publisher = self.get_conn() + topic_path = f"projects/{project_id}/topics/{topic}" + + self.log.info( + "Publish %d messages to topic (path) %s", len(messages), topic_path + ) + try: + for message in messages: + future = publisher.publish( + topic=topic_path, + data=message.get("data", b""), + **message.get("attributes", {}), + ) + future.result() + except GoogleAPICallError as e: + raise PubSubException(f"Error publishing to topic {topic_path}", e) + + self.log.info( + "Published %d messages to topic (path) %s", len(messages), topic_path + ) + + @staticmethod + def _validate_messages(messages) -> None: + for message in messages: + # To warn about broken backward compatibility + # TODO: remove one day + if "data" in message and isinstance(message["data"], str): + try: + b64decode(message["data"]) + warnings.warn( + "The base 64 encoded string as 'data' field has been deprecated. " + "You should pass bytestring (utf-8 encoded).", + DeprecationWarning, + stacklevel=4, + ) + except ValueError: + pass + + if not isinstance(message, dict): + raise PubSubException("Wrong message type. Must be a dictionary.") + if "data" not in message and "attributes" not in message: + raise PubSubException( + "Wrong message. Dictionary must contain 'data' or 'attributes'." + ) + if "data" in message and not isinstance(message["data"], bytes): + raise PubSubException( + "Wrong message. 'data' must be send as a bytestring" + ) + if ( + "data" not in message + and "attributes" in message + and not message["attributes"] + ) or ( + "attributes" in message and not isinstance(message["attributes"], dict) + ): + raise PubSubException( + "Wrong message. If 'data' is not provided 'attributes' must be a non empty dictionary." + ) + + # pylint: disable=too-many-arguments + @GoogleBaseHook.fallback_to_default_project_id + def create_topic( + self, + topic: str, + project_id: str, + fail_if_exists: bool = False, + labels: Optional[Dict[str, str]] = None, + message_storage_policy: Union[Dict, MessageStoragePolicy] = None, + kms_key_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Creates a Pub/Sub topic, if it does not already exist. + + :param topic: the Pub/Sub topic name to create; do not + include the ``projects/{project}/topics/`` prefix. + :type topic: str + :param project_id: Optional, the Google Cloud project ID in which to create the topic + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param fail_if_exists: if set, raise an exception if the topic + already exists + :type fail_if_exists: bool + :param labels: Client-assigned labels; see + https://cloud.google.com/pubsub/docs/labels + :type labels: Dict[str, str] + :param message_storage_policy: Policy constraining the set + of Google Cloud regions where messages published to + the topic may be stored. If not present, then no constraints + are in effect. + :type message_storage_policy: + Union[Dict, google.cloud.pubsub_v1.types.MessageStoragePolicy] + :param kms_key_name: The resource name of the Cloud KMS CryptoKey + to be used to protect access to messages published on this topic. + The expected format is + ``projects/*/locations/*/keyRings/*/cryptoKeys/*``. + :type kms_key_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. 
+ :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + """ + publisher = self.get_conn() + topic_path = f"projects/{project_id}/topics/{topic}" + + # Add airflow-version label to the topic + labels = labels or {} + labels["airflow-version"] = "v" + version.replace(".", "-").replace("+", "-") + + self.log.info("Creating topic (path) %s", topic_path) + try: + # pylint: disable=no-member + publisher.create_topic( + request={ + "name": topic_path, + "labels": labels, + "message_storage_policy": message_storage_policy, + "kms_key_name": kms_key_name, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + except AlreadyExists: + self.log.warning("Topic already exists: %s", topic) + if fail_if_exists: + raise PubSubException(f"Topic already exists: {topic}") + except GoogleAPICallError as e: + raise PubSubException(f"Error creating topic {topic}", e) + + self.log.info("Created topic (path) %s", topic_path) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_topic( + self, + topic: str, + project_id: str, + fail_if_not_exists: bool = False, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes a Pub/Sub topic if it exists. + + :param topic: the Pub/Sub topic name to delete; do not + include the ``projects/{project}/topics/`` prefix. + :type topic: str + :param project_id: Optional, the Google Cloud project ID in which to delete the topic. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param fail_if_not_exists: if set, raise an exception if the topic + does not exist + :type fail_if_not_exists: bool + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]]] + """ + publisher = self.get_conn() + topic_path = f"projects/{project_id}/topics/{topic}" + + self.log.info("Deleting topic (path) %s", topic_path) + try: + # pylint: disable=no-member + publisher.delete_topic( + request={"topic": topic_path}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + except NotFound: + self.log.warning("Topic does not exist: %s", topic_path) + if fail_if_not_exists: + raise PubSubException(f"Topic does not exist: {topic_path}") + except GoogleAPICallError as e: + raise PubSubException(f"Error deleting topic {topic}", e) + self.log.info("Deleted topic (path) %s", topic_path) + + # pylint: disable=too-many-arguments + @GoogleBaseHook.fallback_to_default_project_id + def create_subscription( + self, + topic: str, + project_id: str, + subscription: Optional[str] = None, + subscription_project_id: Optional[str] = None, + ack_deadline_secs: int = 10, + fail_if_exists: bool = False, + push_config: Optional[Union[dict, PushConfig]] = None, + retain_acked_messages: Optional[bool] = None, + message_retention_duration: Optional[Union[dict, Duration]] = None, + labels: Optional[Dict[str, str]] = None, + enable_message_ordering: bool = False, + expiration_policy: Optional[Union[dict, ExpirationPolicy]] = None, + filter_: Optional[str] = None, + dead_letter_policy: Optional[Union[dict, DeadLetterPolicy]] = None, + retry_policy: Optional[Union[dict, RetryPolicy]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> str: + """ + Creates a Pub/Sub subscription, if it does not already exist. + + :param topic: the Pub/Sub topic name that the subscription will be bound + to create; do not include the ``projects/{project}/subscriptions/`` prefix. + :type topic: str + :param project_id: Optional, the Google Cloud project ID of the topic that the subscription will be + bound to. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :param subscription: the Pub/Sub subscription name. If empty, a random + name will be generated using the uuid module + :type subscription: str + :param subscription_project_id: the Google Cloud project ID where the subscription + will be created. If unspecified, ``project_id`` will be used. + :type subscription_project_id: str + :param ack_deadline_secs: Number of seconds that a subscriber has to + acknowledge each message pulled from the subscription + :type ack_deadline_secs: int + :param fail_if_exists: if set, raise an exception if the topic + already exists + :type fail_if_exists: bool + :param push_config: If push delivery is used with this subscription, + this field is used to configure it. An empty ``pushConfig`` signifies + that the subscriber will pull and ack messages using API methods. + :type push_config: Union[Dict, google.cloud.pubsub_v1.types.PushConfig] + :param retain_acked_messages: Indicates whether to retain acknowledged + messages. If true, then messages are not expunged from the subscription's + backlog, even if they are acknowledged, until they fall out of the + ``message_retention_duration`` window. This must be true if you would + like to Seek to a timestamp. + :type retain_acked_messages: bool + :param message_retention_duration: How long to retain unacknowledged messages + in the subscription's backlog, from the moment a message is published. 
If + ``retain_acked_messages`` is true, then this also configures the + retention of acknowledged messages, and thus configures how far back in + time a ``Seek`` can be done. Defaults to 7 days. Cannot be more than 7 + days or less than 10 minutes. + :type message_retention_duration: Union[Dict, google.cloud.pubsub_v1.types.Duration] + :param labels: Client-assigned labels; see + https://cloud.google.com/pubsub/docs/labels + :type labels: Dict[str, str] + :param enable_message_ordering: If true, messages published with the same + ordering_key in PubsubMessage will be delivered to the subscribers in the order + in which they are received by the Pub/Sub system. Otherwise, they may be + delivered in any order. + :type enable_message_ordering: bool + :param expiration_policy: A policy that specifies the conditions for this + subscription’s expiration. A subscription is considered active as long as any + connected subscriber is successfully consuming messages from the subscription or + is issuing operations on the subscription. If expiration_policy is not set, + a default policy with ttl of 31 days will be used. The minimum allowed value for + expiration_policy.ttl is 1 day. + :type expiration_policy: Union[Dict, google.cloud.pubsub_v1.types.ExpirationPolicy`] + :param filter_: An expression written in the Cloud Pub/Sub filter language. If + non-empty, then only PubsubMessages whose attributes field matches the filter are + delivered on this subscription. If empty, then no messages are filtered out. + :type filter_: str + :param dead_letter_policy: A policy that specifies the conditions for dead lettering + messages in this subscription. If dead_letter_policy is not set, dead lettering is + disabled. + :type dead_letter_policy: Union[Dict, google.cloud.pubsub_v1.types.DeadLetterPolicy] + :param retry_policy: A policy that specifies how Pub/Sub retries message delivery + for this subscription. If not set, the default retry policy is applied. This + generally implies that messages will be retried as soon as possible for healthy + subscribers. RetryPolicy will be triggered on NACKs or acknowledgement deadline + exceeded events for a given message. + :type retry_policy: Union[Dict, google.cloud.pubsub_v1.types.RetryPolicy] + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]]] + :return: subscription name which will be the system-generated value if + the ``subscription`` parameter is not supplied + :rtype: str + """ + subscriber = self.subscriber_client + + if not subscription: + subscription = f"sub-{uuid4()}" + if not subscription_project_id: + subscription_project_id = project_id + + # Add airflow-version label to the subscription + labels = labels or {} + labels["airflow-version"] = "v" + version.replace(".", "-").replace("+", "-") + + # pylint: disable=no-member + subscription_path = ( + f"projects/{subscription_project_id}/subscriptions/{subscription}" + ) + topic_path = f"projects/{project_id}/topics/{topic}" + + self.log.info( + "Creating subscription (path) %s for topic (path) %a", + subscription_path, + topic_path, + ) + try: + subscriber.create_subscription( + request={ + "name": subscription_path, + "topic": topic_path, + "push_config": push_config, + "ack_deadline_seconds": ack_deadline_secs, + "retain_acked_messages": retain_acked_messages, + "message_retention_duration": message_retention_duration, + "labels": labels, + "enable_message_ordering": enable_message_ordering, + "expiration_policy": expiration_policy, + "filter": filter_, + "dead_letter_policy": dead_letter_policy, + "retry_policy": retry_policy, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + except AlreadyExists: + self.log.warning("Subscription already exists: %s", subscription_path) + if fail_if_exists: + raise PubSubException( + f"Subscription already exists: {subscription_path}" + ) + except GoogleAPICallError as e: + raise PubSubException(f"Error creating subscription {subscription_path}", e) + + self.log.info( + "Created subscription (path) %s for topic (path) %s", + subscription_path, + topic_path, + ) + return subscription + + @GoogleBaseHook.fallback_to_default_project_id + def delete_subscription( + self, + subscription: str, + project_id: str, + fail_if_not_exists: bool = False, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes a Pub/Sub subscription, if it exists. + + :param subscription: the Pub/Sub subscription name to delete; do not + include the ``projects/{project}/subscriptions/`` prefix. + :param project_id: Optional, the Google Cloud project ID where the subscription exists + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :type subscription: str + :param fail_if_not_exists: if set, raise an exception if the topic does not exist + :type fail_if_not_exists: bool + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]]] + """ + subscriber = self.subscriber_client + # noqa E501 # pylint: disable=no-member + subscription_path = f"projects/{project_id}/subscriptions/{subscription}" + + self.log.info("Deleting subscription (path) %s", subscription_path) + try: + # pylint: disable=no-member + subscriber.delete_subscription( + request={"subscription": subscription_path}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + except NotFound: + self.log.warning("Subscription does not exist: %s", subscription_path) + if fail_if_not_exists: + raise PubSubException( + f"Subscription does not exist: {subscription_path}" + ) + except GoogleAPICallError as e: + raise PubSubException(f"Error deleting subscription {subscription_path}", e) + + self.log.info("Deleted subscription (path) %s", subscription_path) + + @GoogleBaseHook.fallback_to_default_project_id + def pull( + self, + subscription: str, + max_messages: int, + project_id: str, + return_immediately: bool = False, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[ReceivedMessage]: + """ + Pulls up to ``max_messages`` messages from Pub/Sub subscription. + + :param subscription: the Pub/Sub subscription name to pull from; do not + include the 'projects/{project}/topics/' prefix. + :type subscription: str + :param max_messages: The maximum number of messages to return from + the Pub/Sub API. + :type max_messages: int + :param project_id: Optional, the Google Cloud project ID where the subscription exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param return_immediately: If set, the Pub/Sub API will immediately + return if no messages are available. Otherwise, the request will + block for an undisclosed, but bounded period of time + :type return_immediately: bool + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + :return: A list of Pub/Sub ReceivedMessage objects each containing + an ``ackId`` property and a ``message`` property, which includes + the base64-encoded message content. 
See + https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/pull#ReceivedMessage + """ + subscriber = self.subscriber_client + # noqa E501 # pylint: disable=no-member,line-too-long + subscription_path = f"projects/{project_id}/subscriptions/{subscription}" + + self.log.info( + "Pulling max %d messages from subscription (path) %s", + max_messages, + subscription_path, + ) + try: + # pylint: disable=no-member + response = subscriber.pull( + request={ + "subscription": subscription_path, + "max_messages": max_messages, + "return_immediately": return_immediately, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + result = getattr(response, "received_messages", []) + self.log.info( + "Pulled %d messages from subscription (path) %s", + len(result), + subscription_path, + ) + return result + except (HttpError, GoogleAPICallError) as e: + raise PubSubException( + f"Error pulling messages from subscription {subscription_path}", e + ) + + @GoogleBaseHook.fallback_to_default_project_id + def acknowledge( + self, + subscription: str, + project_id: str, + ack_ids: Optional[List[str]] = None, + messages: Optional[List[ReceivedMessage]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Acknowledges the messages associated with the ``ack_ids`` from Pub/Sub subscription. + + :param subscription: the Pub/Sub subscription name to delete; do not + include the 'projects/{project}/topics/' prefix. + :type subscription: str + :param ack_ids: List of ReceivedMessage ackIds from a previous pull response. + Mutually exclusive with ``messages`` argument. + :type ack_ids: list + :param messages: List of ReceivedMessage objects to acknowledge. + Mutually exclusive with ``ack_ids`` argument. + :type messages: list + :param project_id: Optional, the Google Cloud project name or ID in which to create the topic + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]]] + """ + if ack_ids is not None and messages is None: + pass + elif ack_ids is None and messages is not None: + ack_ids = [message.ack_id for message in messages] + else: + raise ValueError( + "One and only one of 'ack_ids' and 'messages' arguments have to be provided" + ) + + subscriber = self.subscriber_client + # noqa E501 # pylint: disable=no-member + subscription_path = f"projects/{project_id}/subscriptions/{subscription}" + + self.log.info( + "Acknowledging %d ack_ids from subscription (path) %s", + len(ack_ids), + subscription_path, + ) + try: + # pylint: disable=no-member + subscriber.acknowledge( + request={"subscription": subscription_path, "ack_ids": ack_ids}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + except (HttpError, GoogleAPICallError) as e: + raise PubSubException( + "Error acknowledging {} messages pulled from subscription {}".format( + len(ack_ids), subscription_path + ), + e, + ) + + self.log.info( + "Acknowledged ack_ids from subscription (path) %s", subscription_path + ) diff --git a/reference/providers/google/cloud/hooks/secret_manager.py b/reference/providers/google/cloud/hooks/secret_manager.py new file mode 100644 index 0000000..6693248 --- /dev/null +++ b/reference/providers/google/cloud/hooks/secret_manager.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Secrets Manager service""" +from typing import Optional, Sequence, Union + +from airflow.providers.google.cloud._internal_client.secret_manager_client import ( # noqa + _SecretManagerClient, +) +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + + +class SecretsManagerHook(GoogleBaseHook): + """ + Hook for the Google Secret Manager API. + + See https://cloud.google.com/secret-manager + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.client = _SecretManagerClient(credentials=self._get_credentials()) + + def get_conn(self) -> _SecretManagerClient: + """ + Retrieves the connection to Secret Manager. + + :return: Secret Manager client. + :rtype: airflow.providers.google.cloud._internal_client.secret_manager_client._SecretManagerClient + """ + return self.client + + @GoogleBaseHook.fallback_to_default_project_id + def get_secret( + self, + secret_id: str, + secret_version: str = "latest", + project_id: Optional[str] = None, + ) -> Optional[str]: + """ + Get secret value from the Secret Manager. + + :param secret_id: Secret Key + :type secret_id: str + :param secret_version: version of the secret (default is 'latest') + :type secret_version: str + :param project_id: Project id (if you want to override the project_id from credentials) + :type project_id: str + """ + return self.get_conn().get_secret( + secret_id=secret_id, secret_version=secret_version, project_id=project_id # type: ignore + ) diff --git a/reference/providers/google/cloud/hooks/spanner.py b/reference/providers/google/cloud/hooks/spanner.py new file mode 100644 index 0000000..4146e28 --- /dev/null +++ b/reference/providers/google/cloud/hooks/spanner.py @@ -0,0 +1,442 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Spanner Hook.""" +from typing import Callable, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.exceptions import AlreadyExists, GoogleAPICallError +from google.cloud.spanner_v1.client import Client +from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.instance import Instance +from google.cloud.spanner_v1.transaction import Transaction +from google.longrunning.operations_grpc_pb2 import Operation # noqa: F401 + + +class SpannerHook(GoogleBaseHook): + """ + Hook for Google Cloud Spanner APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. 
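get_secret above is essentially the whole surface of this hook. A minimal usage sketch, assuming the upstream airflow.providers import path (rather than this repo's reference/ copy) and a placeholder secret name:

    from airflow.providers.google.cloud.hooks.secret_manager import SecretsManagerHook

    hook = SecretsManagerHook(gcp_conn_id="google_cloud_default")
    # Returns the decoded payload of the requested version; project_id falls back
    # to the connection default when omitted.
    api_key = hook.get_secret(secret_id="my-api-key", secret_version="latest")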
+ """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client = None + + def _get_client(self, project_id: str) -> Client: + """ + Provides a client for interacting with the Cloud Spanner API. + + :param project_id: The ID of the Google Cloud project. + :type project_id: str + :return: Client + :rtype: google.cloud.spanner_v1.client.Client + """ + if not self._client: + self._client = Client( + project=project_id, + credentials=self._get_credentials(), + client_info=self.client_info, + ) + return self._client + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance( + self, + instance_id: str, + project_id: str, + ) -> Instance: + """ + Gets information about a particular instance. + + :param project_id: Optional, The ID of the Google Cloud project that owns the Cloud Spanner + database. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :param instance_id: The ID of the Cloud Spanner instance. + :type instance_id: str + :return: Spanner instance + :rtype: google.cloud.spanner_v1.instance.Instance + """ + instance = self._get_client(project_id=project_id).instance( + instance_id=instance_id + ) + if not instance.exists(): + return None + return instance + + def _apply_to_instance( + self, + project_id: str, + instance_id: str, + configuration_name: str, + node_count: int, + display_name: str, + func: Callable[[Instance], Operation], + ) -> None: + """ + Invokes a method on a given instance by applying a specified Callable. + + :param project_id: The ID of the Google Cloud project that owns the Cloud Spanner database. + :type project_id: str + :param instance_id: The ID of the instance. + :type instance_id: str + :param configuration_name: Name of the instance configuration defining how the + instance will be created. Required for instances which do not yet exist. + :type configuration_name: str + :param node_count: (Optional) Number of nodes allocated to the instance. + :type node_count: int + :param display_name: (Optional) The display name for the instance in the Cloud + Console UI. (Must be between 4 and 30 characters.) If this value is not set + in the constructor, will fall back to the instance ID. + :type display_name: str + :param func: Method of the instance to be called. + :type func: Callable[google.cloud.spanner_v1.instance.Instance] + """ + instance = self._get_client(project_id=project_id).instance( + instance_id=instance_id, + configuration_name=configuration_name, + node_count=node_count, + display_name=display_name, + ) + try: + operation = func(instance) # type: Operation + except GoogleAPICallError as e: + self.log.error("An error occurred: %s. Exiting.", e.message) + raise e + + if operation: + result = operation.result() + self.log.info(result) + + @GoogleBaseHook.fallback_to_default_project_id + def create_instance( + self, + instance_id: str, + configuration_name: str, + node_count: int, + display_name: str, + project_id: str, + ) -> None: + """ + Creates a new Cloud Spanner instance. + + :param instance_id: The ID of the Cloud Spanner instance. + :type instance_id: str + :param configuration_name: The name of the instance configuration defining how the + instance will be created. 
Possible configuration values can be retrieved via + https://cloud.google.com/spanner/docs/reference/rest/v1/projects.instanceConfigs/list + :type configuration_name: str + :param node_count: (Optional) The number of nodes allocated to the Cloud Spanner + instance. + :type node_count: int + :param display_name: (Optional) The display name for the instance in the Google Cloud Console. + Must be between 4 and 30 characters. If this value is not passed, the name falls back + to the instance ID. + :type display_name: str + :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner + database. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :return: None + """ + self._apply_to_instance( + project_id, + instance_id, + configuration_name, + node_count, + display_name, + lambda x: x.create(), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def update_instance( + self, + instance_id: str, + configuration_name: str, + node_count: int, + display_name: str, + project_id: str, + ) -> None: + """ + Updates an existing Cloud Spanner instance. + + :param instance_id: The ID of the Cloud Spanner instance. + :type instance_id: str + :param configuration_name: The name of the instance configuration defining how the + instance will be created. Possible configuration values can be retrieved via + https://cloud.google.com/spanner/docs/reference/rest/v1/projects.instanceConfigs/list + :type configuration_name: str + :param node_count: (Optional) The number of nodes allocated to the Cloud Spanner + instance. + :type node_count: int + :param display_name: (Optional) The display name for the instance in the Google Cloud + Console. Must be between 4 and 30 characters. If this value is not set in + the constructor, the name falls back to the instance ID. + :type display_name: str + :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner + database. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :return: None + """ + self._apply_to_instance( + project_id, + instance_id, + configuration_name, + node_count, + display_name, + lambda x: x.update(), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_instance(self, instance_id: str, project_id: str) -> None: + """ + Deletes an existing Cloud Spanner instance. + + :param instance_id: The ID of the Cloud Spanner instance. + :type instance_id: str + :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner + database. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :return: None + """ + instance = self._get_client(project_id=project_id).instance(instance_id) + try: + instance.delete() + return + except GoogleAPICallError as e: + self.log.error("An error occurred: %s. Exiting.", e.message) + raise e + + @GoogleBaseHook.fallback_to_default_project_id + def get_database( + self, + instance_id: str, + database_id: str, + project_id: str, + ) -> Optional[Database]: + """ + Retrieves a database in Cloud Spanner. If the database does not exist + in the specified instance, it returns None. + + :param instance_id: The ID of the Cloud Spanner instance. + :type instance_id: str + :param database_id: The ID of the database in Cloud Spanner. 
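A sketch of the instance-level Spanner calls above; get_instance returns None for a missing instance, so it doubles as an existence check. All identifiers below are placeholders:

    from airflow.providers.google.cloud.hooks.spanner import SpannerHook

    hook = SpannerHook(gcp_conn_id="google_cloud_default")
    if hook.get_instance(instance_id="my-instance", project_id="my-project") is None:
        hook.create_instance(
            instance_id="my-instance",
            configuration_name="projects/my-project/instanceConfigs/regional-us-central1",
            node_count=1,
            display_name="my-instance",
            project_id="my-project",
        )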
+ :type database_id: str + :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner + database. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :return: Database object or None if database does not exist + :rtype: google.cloud.spanner_v1.database.Database or None + """ + instance = self._get_client(project_id=project_id).instance( + instance_id=instance_id + ) + if not instance.exists(): + raise AirflowException( + f"The instance {instance_id} does not exist in project {project_id} !" + ) + database = instance.database(database_id=database_id) + if not database.exists(): + return None + + return database + + @GoogleBaseHook.fallback_to_default_project_id + def create_database( + self, + instance_id: str, + database_id: str, + ddl_statements: List[str], + project_id: str, + ) -> None: + """ + Creates a new database in Cloud Spanner. + + :type project_id: str + :param instance_id: The ID of the Cloud Spanner instance. + :type instance_id: str + :param database_id: The ID of the database to create in Cloud Spanner. + :type database_id: str + :param ddl_statements: The string list containing DDL for the new database. + :type ddl_statements: list[str] + :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner + database. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :return: None + """ + instance = self._get_client(project_id=project_id).instance( + instance_id=instance_id + ) + if not instance.exists(): + raise AirflowException( + f"The instance {instance_id} does not exist in project {project_id} !" + ) + database = instance.database( + database_id=database_id, ddl_statements=ddl_statements + ) + try: + operation = database.create() # type: Operation + except GoogleAPICallError as e: + self.log.error("An error occurred: %s. Exiting.", e.message) + raise e + + if operation: + result = operation.result() + self.log.info(result) + + @GoogleBaseHook.fallback_to_default_project_id + def update_database( + self, + instance_id: str, + database_id: str, + ddl_statements: List[str], + project_id: str, + operation_id: Optional[str] = None, + ) -> None: + """ + Updates DDL of a database in Cloud Spanner. + + :type project_id: str + :param instance_id: The ID of the Cloud Spanner instance. + :type instance_id: str + :param database_id: The ID of the database in Cloud Spanner. + :type database_id: str + :param ddl_statements: The string list containing DDL for the new database. + :type ddl_statements: list[str] + :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner + database. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :param operation_id: (Optional) The unique per database operation ID that can be + specified to implement idempotency check. + :type operation_id: str + :return: None + """ + instance = self._get_client(project_id=project_id).instance( + instance_id=instance_id + ) + if not instance.exists(): + raise AirflowException( + f"The instance {instance_id} does not exist in project {project_id} !" 
+ ) + database = instance.database(database_id=database_id) + try: + operation = database.update_ddl( + ddl_statements=ddl_statements, operation_id=operation_id + ) + if operation: + result = operation.result() + self.log.info(result) + return + except AlreadyExists as e: + if e.code == 409 and operation_id in e.message: + self.log.info( + "Replayed update_ddl message - the operation id %s " + "was already done before.", + operation_id, + ) + return + except GoogleAPICallError as e: + self.log.error("An error occurred: %s. Exiting.", e.message) + raise e + + @GoogleBaseHook.fallback_to_default_project_id + def delete_database(self, instance_id: str, database_id, project_id: str) -> bool: + """ + Drops a database in Cloud Spanner. + + :type project_id: str + :param instance_id: The ID of the Cloud Spanner instance. + :type instance_id: str + :param database_id: The ID of the database in Cloud Spanner. + :type database_id: str + :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner + database. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :return: True if everything succeeded + :rtype: bool + """ + instance = self._get_client(project_id=project_id).instance( + instance_id=instance_id + ) + if not instance.exists(): + raise AirflowException( + f"The instance {instance_id} does not exist in project {project_id} !" + ) + database = instance.database(database_id=database_id) + if not database.exists(): + self.log.info( + "The database %s is already deleted from instance %s. Exiting.", + database_id, + instance_id, + ) + return False + try: + database.drop() # pylint: disable=E1111 + except GoogleAPICallError as e: + self.log.error("An error occurred: %s. Exiting.", e.message) + raise e + + return True + + @GoogleBaseHook.fallback_to_default_project_id + def execute_dml( + self, + instance_id: str, + database_id: str, + queries: List[str], + project_id: str, + ) -> None: + """ + Executes an arbitrary DML query (INSERT, UPDATE, DELETE). + + :param instance_id: The ID of the Cloud Spanner instance. + :type instance_id: str + :param database_id: The ID of the database in Cloud Spanner. + :type database_id: str + :param queries: The queries to execute. + :type queries: List[str] + :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner + database. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + """ + self._get_client(project_id=project_id).instance( + instance_id=instance_id + ).database(database_id=database_id).run_in_transaction( + lambda transaction: self._execute_sql_in_transaction(transaction, queries) + ) + + @staticmethod + def _execute_sql_in_transaction(transaction: Transaction, queries: List[str]): + for sql in queries: + transaction.execute_update(sql) diff --git a/reference/providers/google/cloud/hooks/speech_to_text.py b/reference/providers/google/cloud/hooks/speech_to_text.py new file mode 100644 index 0000000..30e23fd --- /dev/null +++ b/reference/providers/google/cloud/hooks/speech_to_text.py @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
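Continuing the Spanner sketch for the database-level methods above (create_database, then execute_dml, which runs every statement in one read-write transaction); the DDL and DML strings are illustrative only:

    from airflow.providers.google.cloud.hooks.spanner import SpannerHook

    hook = SpannerHook(gcp_conn_id="google_cloud_default")
    if hook.get_database(
        instance_id="my-instance", database_id="my-db", project_id="my-project"
    ) is None:
        hook.create_database(
            instance_id="my-instance",
            database_id="my-db",
            ddl_statements=["CREATE TABLE users (id INT64) PRIMARY KEY (id)"],
            project_id="my-project",
        )
    hook.execute_dml(
        instance_id="my-instance",
        database_id="my-db",
        queries=["INSERT INTO users (id) VALUES (1)"],
        project_id="my-project",
    )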
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Speech Hook.""" +from typing import Dict, Optional, Sequence, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.retry import Retry +from google.cloud.speech_v1 import SpeechClient +from google.cloud.speech_v1.types import RecognitionAudio, RecognitionConfig + + +class CloudSpeechToTextHook(GoogleBaseHook): + """ + Hook for Google Cloud Speech API. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client = None + + def get_conn(self) -> SpeechClient: + """ + Retrieves connection to Cloud Speech. + + :return: Google Cloud Speech client object. + :rtype: google.cloud.speech_v1.SpeechClient + """ + if not self._client: + self._client = SpeechClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._client + + @GoogleBaseHook.quota_retry() + def recognize_speech( + self, + config: Union[Dict, RecognitionConfig], + audio: Union[Dict, RecognitionAudio], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + ): + """ + Recognizes audio input + + :param config: information to the recognizer that specifies how to process the request. + https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/types.html#google.cloud.speech_v1.types.RecognitionConfig + :type config: dict or google.cloud.speech_v1.types.RecognitionConfig + :param audio: audio data to be recognized + https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/types.html#google.cloud.speech_v1.types.RecognitionAudio + :type audio: dict or google.cloud.speech_v1.types.RecognitionAudio + :param retry: (Optional) A retry object used to retry requests. If None is specified, + requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to complete. 
+ Note that if retry is specified, the timeout applies to each individual attempt. + :type timeout: float + """ + client = self.get_conn() + response = client.recognize( + config=config, audio=audio, retry=retry, timeout=timeout + ) + self.log.info("Recognised speech: %s", response) + return response diff --git a/reference/providers/google/cloud/hooks/stackdriver.py b/reference/providers/google/cloud/hooks/stackdriver.py new file mode 100644 index 0000000..459294f --- /dev/null +++ b/reference/providers/google/cloud/hooks/stackdriver.py @@ -0,0 +1,653 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains Google Cloud Stackdriver operators.""" + +import json +from typing import Any, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.exceptions import InvalidArgument +from google.api_core.gapic_v1.method import DEFAULT +from google.cloud import monitoring_v3 +from google.cloud.monitoring_v3 import AlertPolicy, NotificationChannel +from google.protobuf.field_mask_pb2 import FieldMask +from googleapiclient.errors import HttpError + + +class StackdriverHook(GoogleBaseHook): + """Stackdriver Hook for connecting with Google Cloud Stackdriver""" + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._policy_client = None + self._channel_client = None + + def _get_policy_client(self): + if not self._policy_client: + self._policy_client = monitoring_v3.AlertPolicyServiceClient() + return self._policy_client + + def _get_channel_client(self): + if not self._channel_client: + self._channel_client = monitoring_v3.NotificationChannelServiceClient() + return self._channel_client + + @GoogleBaseHook.fallback_to_default_project_id + def list_alert_policies( + self, + project_id: str, + format_: Optional[str] = None, + filter_: Optional[str] = None, + order_by: Optional[str] = None, + page_size: Optional[int] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + ) -> Any: + """ + Fetches all the Alert Policies identified by the filter passed as + filter parameter. The desired return type can be specified by the + format parameter, the supported formats are "dict", "json" and None + which returns python dictionary, stringified JSON and protobuf + respectively. + + :param format_: (Optional) Desired output format of the result. 
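A minimal recognize_speech sketch; the GCS URI and config values are placeholders, and the dicts follow the RecognitionConfig/RecognitionAudio shapes referenced in the docstring above:

    from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook

    hook = CloudSpeechToTextHook(gcp_conn_id="google_cloud_default")
    response = hook.recognize_speech(
        config={"encoding": "LINEAR16", "sample_rate_hertz": 16000, "language_code": "en-US"},
        audio={"uri": "gs://my-bucket/sample.wav"},  # placeholder object
    )
    for result in response.results:
        # Each result carries one or more alternatives ranked by confidence.
        print(result.alternatives[0].transcript)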
The + supported formats are "dict", "json" and None which returns + python dictionary, stringified JSON and protobuf respectively. + :type format_: str + :param filter_: If provided, this field specifies the criteria that + must be met by alert policies to be included in the response. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param order_by: A comma-separated list of fields by which to sort the result. + Supports the same set of field references as the ``filter`` field. Entries + can be prefixed with a minus sign to sort by the field in descending order. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type order_by: str + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per- + resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number + of resources in a page. + :type page_size: int + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + :param project_id: The project to fetch alerts from. + :type project_id: str + """ + client = self._get_policy_client() + policies_ = client.list_alert_policies( + request={ + "name": f"projects/{project_id}", + "filter": filter_, + "order_by": order_by, + "page_size": page_size, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + if format_ == "dict": + return [AlertPolicy.to_dict(policy) for policy in policies_] + elif format_ == "json": + return [AlertPolicy.to_json(policy) for policy in policies_] + else: + return policies_ + + @GoogleBaseHook.fallback_to_default_project_id + def _toggle_policy_status( + self, + new_state: bool, + project_id: str, + filter_: Optional[str] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + ): + client = self._get_policy_client() + policies_ = self.list_alert_policies(project_id=project_id, filter_=filter_) + for policy in policies_: + if policy.enabled != bool(new_state): + policy.enabled = bool(new_state) + mask = FieldMask(paths=["enabled"]) + client.update_alert_policy( + request={"alert_policy": policy, "update_mask": mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def enable_alert_policies( + self, + project_id: str, + filter_: Optional[str] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + ) -> None: + """ + Enables one or more disabled alerting policies identified by filter + parameter. Inoperative in case the policy is already enabled. + + :param project_id: The project in which alert needs to be enabled. + :type project_id: str + :param filter_: If provided, this field specifies the criteria that + must be met by alert policies to be enabled. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param retry: A retry object used to retry requests.
If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + """ + self._toggle_policy_status( + new_state=True, + project_id=project_id, + filter_=filter_, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def disable_alert_policies( + self, + project_id: str, + filter_: Optional[str] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + ) -> None: + """ + Disables one or more enabled alerting policies identified by filter + parameter. Inoperative in case the policy is already disabled. + + :param project_id: The project in which alert needs to be disabled. + :type project_id: str + :param filter_: If provided, this field specifies the criteria that + must be met by alert policies to be disabled. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + """ + self._toggle_policy_status( + filter_=filter_, + project_id=project_id, + new_state=False, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def upsert_alert( + self, + alerts: str, + project_id: str, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + ) -> None: + """ + Creates a new alert or updates an existing policy identified + the name field in the alerts parameter. + + :param project_id: The project in which alert needs to be created/updated. + :type project_id: str + :param alerts: A JSON string or file that specifies all the alerts that needs + to be either created or updated. For more details, see + https://cloud.google.com/monitoring/api/ref_v3/rest/v3/projects.alertPolicies#AlertPolicy. + (templated) + :type alerts: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
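A sketch of the alert-policy helpers above: list policies as dicts, then toggle a subset by filter. The project ID and the Cloud Monitoring filter string are placeholders:

    from airflow.providers.google.cloud.hooks.stackdriver import StackdriverHook

    hook = StackdriverHook(gcp_conn_id="google_cloud_default")
    # format_="dict" returns plain dictionaries, convenient for XCom or logging.
    policies = hook.list_alert_policies(project_id="my-project", format_="dict")
    hook.disable_alert_policies(
        project_id="my-project",
        filter_='display_name="High CPU"',  # placeholder filter
    )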
+ :type metadata: str + """ + policy_client = self._get_policy_client() + channel_client = self._get_channel_client() + + record = json.loads(alerts) + existing_policies = [ + policy["name"] + for policy in self.list_alert_policies( + project_id=project_id, format_="dict" + ) + ] + existing_channels = [ + channel["name"] + for channel in self.list_notification_channels( + project_id=project_id, format_="dict" + ) + ] + policies_ = [] + channels = [] + for channel in record.get("channels", []): + channels.append(NotificationChannel(**channel)) + for policy in record.get("policies", []): + policies_.append(AlertPolicy(**policy)) + + channel_name_map = {} + + for channel in channels: + channel.verification_status = ( + monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED + ) + + if channel.name in existing_channels: + channel_client.update_notification_channel( + request={"notification_channel": channel}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + else: + old_name = channel.name + channel.name = None + new_channel = channel_client.create_notification_channel( + request={ + "name": f"projects/{project_id}", + "notification_channel": channel, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + channel_name_map[old_name] = new_channel.name + + for policy in policies_: + policy.creation_record = None + policy.mutation_record = None + + for i, channel in enumerate(policy.notification_channels): + new_channel = channel_name_map.get(channel) + if new_channel: + policy.notification_channels[i] = new_channel + + if policy.name in existing_policies: + try: + policy_client.update_alert_policy( + request={"alert_policy": policy}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + except InvalidArgument: + pass + else: + policy.name = None + for condition in policy.conditions: + condition.name = None + policy_client.create_alert_policy( + request={"name": f"projects/{project_id}", "alert_policy": policy}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + def delete_alert_policy( + self, + name: str, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + ) -> None: + """ + Deletes an alerting policy. + + :param name: The alerting policy to delete. The format is: + ``projects/[PROJECT_ID]/alertPolicies/[ALERT_POLICY_ID]``. + :type name: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + """ + policy_client = self._get_policy_client() + try: + policy_client.delete_alert_policy( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + except HttpError as err: + raise AirflowException( + f"Delete alerting policy failed. 
Error was {err.content}" + ) + + @GoogleBaseHook.fallback_to_default_project_id + def list_notification_channels( + self, + project_id: str, + format_: Optional[str] = None, + filter_: Optional[str] = None, + order_by: Optional[str] = None, + page_size: Optional[int] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[str] = DEFAULT, + metadata: Optional[str] = None, + ) -> Any: + """ + Fetches all the Notification Channels identified by the filter passed as + filter parameter. The desired return type can be specified by the + format parameter, the supported formats are "dict", "json" and None + which returns python dictionary, stringified JSON and protobuf + respectively. + + :param format_: (Optional) Desired output format of the result. The + supported formats are "dict", "json" and None which returns + python dictionary, stringified JSON and protobuf respectively. + :type format_: str + :param filter_: If provided, this field specifies the criteria that + must be met by notification channels to be included in the response. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param order_by: A comma-separated list of fields by which to sort the result. + Supports the same set of field references as the ``filter`` field. Entries + can be prefixed with a minus sign to sort by the field in descending order. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type order_by: str + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per- + resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number + of resources in a page. + :type page_size: int + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + :param project_id: The project to fetch notification channels from. 
+ :type project_id: str + """ + client = self._get_channel_client() + channels = client.list_notification_channels( + request={ + "name": f"projects/{project_id}", + "filter": filter_, + "order_by": order_by, + "page_size": page_size, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + if format_ == "dict": + return [NotificationChannel.to_dict(channel) for channel in channels] + elif format_ == "json": + return [NotificationChannel.to_json(channel) for channel in channels] + else: + return channels + + @GoogleBaseHook.fallback_to_default_project_id + def _toggle_channel_status( + self, + new_state: bool, + project_id: str, + filter_: Optional[str] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[str] = DEFAULT, + metadata: Optional[str] = None, + ) -> None: + client = self._get_channel_client() + channels = client.list_notification_channels( + request={"name": f"projects/{project_id}", "filter": filter_} + ) + for channel in channels: + if channel.enabled != bool(new_state): + channel.enabled = bool(new_state) + mask = FieldMask(paths=["enabled"]) + client.update_notification_channel( + request={"notification_channel": channel, "update_mask": mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def enable_notification_channels( + self, + project_id: str, + filter_: Optional[str] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[str] = DEFAULT, + metadata: Optional[str] = None, + ) -> None: + """ + Enables one or more disabled alerting policies identified by filter + parameter. Inoperative in case the policy is already enabled. + + :param project_id: The project in which notification channels needs to be enabled. + :type project_id: str + :param filter_: If provided, this field specifies the criteria that + must be met by notification channels to be enabled. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + """ + self._toggle_channel_status( + project_id=project_id, + filter_=filter_, + new_state=True, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def disable_notification_channels( + self, + project_id: str, + filter_: Optional[str] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[str] = DEFAULT, + metadata: Optional[str] = None, + ) -> None: + """ + Disables one or more enabled notification channels identified by filter + parameter. Inoperative in case the policy is already disabled. + + :param project_id: The project in which notification channels needs to be enabled. + :type project_id: str + :param filter_: If provided, this field specifies the criteria that + must be met by alert policies to be disabled. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. 
+ :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + """ + self._toggle_channel_status( + filter_=filter_, + project_id=project_id, + new_state=False, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def upsert_channel( + self, + channels: str, + project_id: str, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + ) -> dict: + """ + Creates a new notification or updates an existing notification channel + identified the name field in the alerts parameter. + + :param channels: A JSON string or file that specifies all the alerts that needs + to be either created or updated. For more details, see + https://cloud.google.com/monitoring/api/ref_v3/rest/v3/projects.notificationChannels. + (templated) + :type channels: str + :param project_id: The project in which notification channels needs to be created/updated. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + """ + channel_client = self._get_channel_client() + + record = json.loads(channels) + existing_channels = [ + channel["name"] + for channel in self.list_notification_channels( + project_id=project_id, format_="dict" + ) + ] + channels_list = [] + channel_name_map = {} + + for channel in record["channels"]: + channels_list.append(NotificationChannel(**channel)) + + for channel in channels_list: + channel.verification_status = ( + monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED + ) + + if channel.name in existing_channels: + channel_client.update_notification_channel( + request={"notification_channel": channel}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + else: + old_name = channel.name + channel.name = None + new_channel = channel_client.create_notification_channel( + request={ + "name": f"projects/{project_id}", + "notification_channel": channel, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + channel_name_map[old_name] = new_channel.name + + return channel_name_map + + def delete_notification_channel( + self, + name: str, + retry: Optional[str] = DEFAULT, + timeout: Optional[str] = DEFAULT, + metadata: Optional[str] = None, + ) -> None: + """ + Deletes a notification channel. + + :param name: The alerting policy to delete. The format is: + ``projects/[PROJECT_ID]/notificationChannels/[CHANNEL_ID]``. + :type name: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
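The notification-channel helpers mirror the policy ones; a sketch of listing channels and re-enabling a subset by filter (project ID and filter string are placeholders):

    from airflow.providers.google.cloud.hooks.stackdriver import StackdriverHook

    hook = StackdriverHook(gcp_conn_id="google_cloud_default")
    channels = hook.list_notification_channels(project_id="my-project", format_="dict")
    # Re-enable any e-mail channels that were previously switched off.
    hook.enable_notification_channels(
        project_id="my-project",
        filter_='type="email"',
    )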
+ :type metadata: str + """ + channel_client = self._get_channel_client() + try: + channel_client.delete_notification_channel( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + except HttpError as err: + raise AirflowException( + f"Delete notification channel failed. Error was {err.content}" + ) diff --git a/reference/providers/google/cloud/hooks/tasks.py b/reference/providers/google/cloud/hooks/tasks.py new file mode 100644 index 0000000..d52c3a7 --- /dev/null +++ b/reference/providers/google/cloud/hooks/tasks.py @@ -0,0 +1,728 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This module contains a CloudTasksHook +which allows you to connect to Google Cloud Tasks service, +performing actions to queues or tasks. +""" + +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.retry import Retry +from google.cloud.tasks_v2 import CloudTasksClient +from google.cloud.tasks_v2.types import Queue, Task +from google.protobuf.field_mask_pb2 import FieldMask + + +class CloudTasksHook(GoogleBaseHook): + """ + Hook for Google Cloud Tasks APIs. Cloud Tasks allows developers to manage + the execution of background work in their applications. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client = None + + def get_conn(self) -> CloudTasksClient: + """ + Provides a client for interacting with the Google Cloud Tasks API. + + :return: Google Cloud Tasks API Client + :rtype: google.cloud.tasks_v2.CloudTasksClient + """ + if not self._client: + self._client = CloudTasksClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._client + + @GoogleBaseHook.fallback_to_default_project_id + def create_queue( + self, + location: str, + task_queue: Union[dict, Queue], + project_id: str, + queue_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Queue: + """ + Creates a queue in Cloud Tasks. + + :param location: The location name in which the queue will be created. + :type location: str + :param task_queue: The task queue to create. + Queue's name cannot be the same as an existing queue. + If a dict is provided, it must be of the same form as the protobuf message Queue. + :type task_queue: dict or google.cloud.tasks_v2.types.Queue + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param queue_name: (Optional) The queue's name. + If provided, it will be used to construct the full queue path. + :type queue_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Queue + """ + client = self.get_conn() + + if queue_name: + full_queue_name = ( + f"projects/{project_id}/locations/{location}/queues/{queue_name}" + ) + if isinstance(task_queue, Queue): + task_queue.name = full_queue_name + elif isinstance(task_queue, dict): + task_queue["name"] = full_queue_name + else: + raise AirflowException("Unable to set queue_name.") + full_location_path = f"projects/{project_id}/locations/{location}" + return client.create_queue( + request={"parent": full_location_path, "queue": task_queue}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def update_queue( + self, + task_queue: Queue, + project_id: str, + location: Optional[str] = None, + queue_name: Optional[str] = None, + update_mask: Optional[FieldMask] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Queue: + """ + Updates a queue in Cloud Tasks. + + :param task_queue: The task queue to update. + This method creates the queue if it does not exist and updates the queue if + it does exist. The queue's name must be specified. 
+ :type task_queue: dict or google.cloud.tasks_v2.types.Queue + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param location: (Optional) The location name in which the queue will be updated. + If provided, it will be used to construct the full queue path. + :type location: str + :param queue_name: (Optional) The queue's name. + If provided, it will be used to construct the full queue path. + :type queue_name: str + :param update_mask: A mast used to specify which fields of the queue are being updated. + If empty, then all fields will be updated. + If a dict is provided, it must be of the same form as the protobuf message. + :type update_mask: dict or google.protobuf.field_mask_pb2.FieldMask + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Queue + """ + client = self.get_conn() + + if queue_name and location: + full_queue_name = ( + f"projects/{project_id}/locations/{location}/queues/{queue_name}" + ) + if isinstance(task_queue, Queue): + task_queue.name = full_queue_name + elif isinstance(task_queue, dict): + task_queue["name"] = full_queue_name + else: + raise AirflowException("Unable to set queue_name.") + return client.update_queue( + request={"queue": task_queue, "update_mask": update_mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_queue( + self, + location: str, + queue_name: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Queue: + """ + Gets a queue from Cloud Tasks. + + :param location: The location name in which the queue was created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Queue + """ + client = self.get_conn() + + full_queue_name = ( + f"projects/{project_id}/locations/{location}/queues/{queue_name}" + ) + return client.get_queue( + request={"name": full_queue_name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def list_queues( + self, + location: str, + project_id: str, + results_filter: Optional[str] = None, + page_size: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[Queue]: + """ + Lists queues from Cloud Tasks. + + :param location: The location name in which the queues were created. + :type location: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param results_filter: (Optional) Filter used to specify a subset of queues. + :type results_filter: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + client = self.get_conn() + + full_location_path = f"projects/{project_id}/locations/{location}" + queues = client.list_queues( + request={ + "parent": full_location_path, + "filter": results_filter, + "page_size": page_size, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + return list(queues) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_queue( + self, + location: str, + queue_name: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + Deletes a queue from Cloud Tasks, even if it has tasks in it. + + :param location: The location name in which the queue will be deleted. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
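A sketch of basic queue management with the hook above; passing an empty dict for task_queue relies on server-side defaults, and the hook fills in the fully qualified queue name. The location and all names are placeholders:

    from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook

    hook = CloudTasksHook(gcp_conn_id="google_cloud_default")
    queue = hook.create_queue(
        location="us-central1",
        task_queue={},
        queue_name="my-queue",
        project_id="my-project",
    )
    # Returns a list of Queue objects for the location.
    queues = hook.list_queues(location="us-central1", project_id="my-project")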
+ :type metadata: sequence[tuple[str, str]]] + """ + client = self.get_conn() + + full_queue_name = ( + f"projects/{project_id}/locations/{location}/queues/{queue_name}" + ) + client.delete_queue( + request={"name": full_queue_name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def purge_queue( + self, + location: str, + queue_name: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[Queue]: + """ + Purges a queue by deleting all of its tasks from Cloud Tasks. + + :param location: The location name in which the queue will be purged. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + client = self.get_conn() + + full_queue_name = ( + f"projects/{project_id}/locations/{location}/queues/{queue_name}" + ) + return client.purge_queue( + request={"name": full_queue_name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def pause_queue( + self, + location: str, + queue_name: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[Queue]: + """ + Pauses a queue in Cloud Tasks. + + :param location: The location name in which the queue will be paused. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + client = self.get_conn() + + full_queue_name = ( + f"projects/{project_id}/locations/{location}/queues/{queue_name}" + ) + return client.pause_queue( + request={"name": full_queue_name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def resume_queue( + self, + location: str, + queue_name: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[Queue]: + """ + Resumes a queue in Cloud Tasks. + + :param location: The location name in which the queue will be resumed. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + client = self.get_conn() + + full_queue_name = ( + f"projects/{project_id}/locations/{location}/queues/{queue_name}" + ) + return client.resume_queue( + request={"name": full_queue_name}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_task( + self, + location: str, + queue_name: str, + task: Union[Dict, Task], + project_id: str, + task_name: Optional[str] = None, + response_view: Optional = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Task: + """ + Creates a task in Cloud Tasks. + + :param location: The location name in which the task will be created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task: The task to add. + If a dict is provided, it must be of the same form as the protobuf message Task. + :type task: dict or google.cloud.tasks_v2.types.Task + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param task_name: (Optional) The task's name. + If provided, it will be used to construct the full task path. + :type task_name: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Task + """ + client = self.get_conn() + + if task_name: + full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}" + if isinstance(task, Task): + task.name = full_task_name + elif isinstance(task, dict): + task["name"] = full_task_name + else: + raise AirflowException("Unable to set task_name.") + full_queue_name = ( + f"projects/{project_id}/locations/{location}/queues/{queue_name}" + ) + return client.create_task( + request={ + "parent": full_queue_name, + "task": task, + "response_view": response_view, + }, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_task( + self, + location: str, + queue_name: str, + task_name: str, + project_id: str, + response_view: Optional = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Task: + """ + Gets a task from Cloud Tasks. + + :param location: The location name in which the task was created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task_name: The task's name. + :type task_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Task + """ + client = self.get_conn() + + full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}" + return client.get_task( + request={"name": full_task_name, "response_view": response_view}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) + + @GoogleBaseHook.fallback_to_default_project_id + def list_tasks( + self, + location: str, + queue_name: str, + project_id: str, + response_view: Optional = None, + page_size: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> List[Task]: + """ + Lists the tasks in Cloud Tasks. + + :param location: The location name in which the tasks were created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.Task.View + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. 
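A hedged sketch of creating and then force-running a task with the same hook; the HTTP target, task name, and other identifiers are placeholders, and the dict literal relies on ``create_task`` accepting the protobuf-equivalent dict form described above::

    from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook

    hook = CloudTasksHook(gcp_conn_id="google_cloud_default")

    # Dict form of google.cloud.tasks_v2.types.Task (placeholder endpoint).
    task = {
        "http_request": {
            "http_method": "POST",
            "url": "https://example.com/task-handler",
        }
    }

    hook.create_task(
        location="us-central1",
        queue_name="my-queue",
        task=task,
        project_id="my-project",
        task_name="my-task",  # optional; omit to let the API pick a name
    )

    # Dispatch immediately instead of waiting for the queue schedule.
    hook.run_task(
        location="us-central1",
        queue_name="my-queue",
        task_name="my-task",
        project_id="my-project",
    )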
+        :type page_size: int
+        :param retry: (Optional) A retry object used to retry requests.
+            If None is specified, requests will not be retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: (Optional) The amount of time, in seconds, to wait for the request
+            to complete. Note that if retry is specified, the timeout applies to each
+            individual attempt.
+        :type timeout: float
+        :param metadata: (Optional) Additional metadata that is provided to the method.
+        :type metadata: sequence[tuple[str, str]]
+        :rtype: list[google.cloud.tasks_v2.types.Task]
+        """
+        client = self.get_conn()
+        full_queue_name = (
+            f"projects/{project_id}/locations/{location}/queues/{queue_name}"
+        )
+        tasks = client.list_tasks(
+            request={
+                "parent": full_queue_name,
+                "response_view": response_view,
+                "page_size": page_size,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
+        return list(tasks)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_task(
+        self,
+        location: str,
+        queue_name: str,
+        task_name: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> None:
+        """
+        Deletes a task from Cloud Tasks.
+
+        :param location: The location name in which the task will be deleted.
+        :type location: str
+        :param queue_name: The queue's name.
+        :type queue_name: str
+        :param task_name: The task's name.
+        :type task_name: str
+        :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :type project_id: str
+        :param retry: (Optional) A retry object used to retry requests.
+            If None is specified, requests will not be retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: (Optional) The amount of time, in seconds, to wait for the request
+            to complete. Note that if retry is specified, the timeout applies to each
+            individual attempt.
+        :type timeout: float
+        :param metadata: (Optional) Additional metadata that is provided to the method.
+        :type metadata: sequence[tuple[str, str]]
+        """
+        client = self.get_conn()
+
+        full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
+        client.delete_task(
+            request={"name": full_task_name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def run_task(
+        self,
+        location: str,
+        queue_name: str,
+        task_name: str,
+        project_id: str,
+        response_view: Optional = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> Task:
+        """
+        Forces a task to run in Cloud Tasks.
+
+        :param location: The location name in which the task was created.
+        :type location: str
+        :param queue_name: The queue's name.
+        :type queue_name: str
+        :param task_name: The task's name.
+        :type task_name: str
+        :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :type project_id: str
+        :param response_view: (Optional) This field specifies which subset of the Task will
+            be returned.
+        :type response_view: google.cloud.tasks_v2.Task.View
+        :param retry: (Optional) A retry object used to retry requests.
+            If None is specified, requests will not be retried.
+ :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :rtype: google.cloud.tasks_v2.types.Task + """ + client = self.get_conn() + + full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}" + return client.run_task( + request={"name": full_task_name, "response_view": response_view}, + retry=retry, + timeout=timeout, + metadata=metadata or (), + ) diff --git a/reference/providers/google/cloud/hooks/text_to_speech.py b/reference/providers/google/cloud/hooks/text_to_speech.py new file mode 100644 index 0000000..35130db --- /dev/null +++ b/reference/providers/google/cloud/hooks/text_to_speech.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Text to Speech Hook.""" +from typing import Dict, Optional, Sequence, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.retry import Retry +from google.cloud.texttospeech_v1 import TextToSpeechClient +from google.cloud.texttospeech_v1.types import ( + AudioConfig, + SynthesisInput, + SynthesizeSpeechResponse, + VoiceSelectionParams, +) + + +class CloudTextToSpeechHook(GoogleBaseHook): + """ + Hook for Google Cloud Text to Speech API. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client = None # type: Optional[TextToSpeechClient] + + def get_conn(self) -> TextToSpeechClient: + """ + Retrieves connection to Cloud Text to Speech. + + :return: Google Cloud Text to Speech client object. + :rtype: google.cloud.texttospeech_v1.TextToSpeechClient + """ + if not self._client: + # pylint: disable=unexpected-keyword-arg + self._client = TextToSpeechClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + # pylint: enable=unexpected-keyword-arg + + return self._client + + @GoogleBaseHook.quota_retry() + def synthesize_speech( + self, + input_data: Union[Dict, SynthesisInput], + voice: Union[Dict, VoiceSelectionParams], + audio_config: Union[Dict, AudioConfig], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + ) -> SynthesizeSpeechResponse: + """ + Synthesizes text input + + :param input_data: text input to be synthesized. See more: + https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.SynthesisInput + :type input_data: dict or google.cloud.texttospeech_v1.types.SynthesisInput + :param voice: configuration of voice to be used in synthesis. See more: + https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.VoiceSelectionParams + :type voice: dict or google.cloud.texttospeech_v1.types.VoiceSelectionParams + :param audio_config: configuration of the synthesized audio. See more: + https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.AudioConfig + :type audio_config: dict or google.cloud.texttospeech_v1.types.AudioConfig + :param retry: (Optional) A retry object used to retry requests. If None is specified, + requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to complete. + Note that if retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :return: SynthesizeSpeechResponse See more: + https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.SynthesizeSpeechResponse + :rtype: object + """ + client = self.get_conn() + self.log.info("Synthesizing input: %s", input_data) + # pylint: disable=unexpected-keyword-arg + return client.synthesize_speech( + input_=input_data, + voice=voice, + audio_config=audio_config, + retry=retry, + timeout=timeout, + ) + # pylint: enable=unexpected-keyword-arg diff --git a/reference/providers/google/cloud/hooks/translate.py b/reference/providers/google/cloud/hooks/translate.py new file mode 100644 index 0000000..18d6710 --- /dev/null +++ b/reference/providers/google/cloud/hooks/translate.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
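A short usage sketch for the ``CloudTextToSpeechHook.synthesize_speech`` method defined above (illustrative; the dicts mirror the ``SynthesisInput``, ``VoiceSelectionParams``, and ``AudioConfig`` messages, and the output path is a placeholder)::

    from airflow.providers.google.cloud.hooks.text_to_speech import CloudTextToSpeechHook

    hook = CloudTextToSpeechHook(gcp_conn_id="google_cloud_default")

    response = hook.synthesize_speech(
        input_data={"text": "Hello from Airflow"},
        voice={"language_code": "en-US", "ssml_gender": "NEUTRAL"},
        audio_config={"audio_encoding": "LINEAR16"},
    )

    # The response carries the raw audio bytes.
    with open("/tmp/output.wav", "wb") as audio_file:
        audio_file.write(response.audio_content)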
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Translate Hook.""" +from typing import List, Optional, Sequence, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.cloud.translate_v2 import Client + + +class CloudTranslateHook(GoogleBaseHook): + """ + Hook for Google Cloud translate APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client = None # type: Optional[Client] + + def get_conn(self) -> Client: + """ + Retrieves connection to Cloud Translate + + :return: Google Cloud Translate client object. + :rtype: google.cloud.translate_v2.Client + """ + if not self._client: + self._client = Client( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._client + + @GoogleBaseHook.quota_retry() + def translate( + self, + values: Union[str, List[str]], + target_language: str, + format_: Optional[str] = None, + source_language: Optional[str] = None, + model: Optional[Union[str, List[str]]] = None, + ) -> dict: + """Translate a string or list of strings. + + See https://cloud.google.com/translate/docs/translating-text + + :type values: str or list + :param values: String or list of strings to translate. + :type target_language: str + :param target_language: The language to translate results into. This + is required by the API and defaults to + the target language of the current instance. + :type format_: str + :param format_: (Optional) One of ``text`` or ``html``, to specify + if the input text is plain text or HTML. + :type source_language: str or None + :param source_language: (Optional) The language of the text to + be translated. + :type model: str or None + :param model: (Optional) The model used to translate the text, such + as ``'base'`` or ``'nmt'``. + :rtype: str or list + :returns: A list of dictionaries for each queried value. Each + dictionary typically contains three keys (though not + all will be present in all cases) + + * ``detectedSourceLanguage``: The detected language (as an + ISO 639-1 language code) of the text. + + * ``translatedText``: The translation of the text into the + target language. + + * ``input``: The corresponding input value. + + * ``model``: The model used to translate the text. + + If only a single value is passed, then only a single + dictionary will be returned. + :raises: :class:`~exceptions.ValueError` if the number of + values and translations differ. 
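        Example (illustrative; the connection ID, text, and language codes are placeholder assumptions, and valid Google Cloud credentials are needed for the call to succeed)::

            from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook

            hook = CloudTranslateHook(gcp_conn_id="google_cloud_default")
            result = hook.translate(
                values="The quick brown fox",
                target_language="de",
                format_="text",
                source_language="en",
                model="nmt",
            )
            # A single input yields a single dict.
            print(result["translatedText"])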
+ """ + client = self.get_conn() + + return client.translate( + values=values, + target_language=target_language, + format_=format_, + source_language=source_language, + model=model, + ) diff --git a/reference/providers/google/cloud/hooks/video_intelligence.py b/reference/providers/google/cloud/hooks/video_intelligence.py new file mode 100644 index 0000000..bd6d189 --- /dev/null +++ b/reference/providers/google/cloud/hooks/video_intelligence.py @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Video Intelligence Hook.""" +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.operation import Operation +from google.api_core.retry import Retry +from google.cloud.videointelligence_v1 import VideoIntelligenceServiceClient +from google.cloud.videointelligence_v1.types import VideoContext + + +class CloudVideoIntelligenceHook(GoogleBaseHook): + """ + Hook for Google Cloud Video Intelligence APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._conn = None + + def get_conn(self) -> VideoIntelligenceServiceClient: + """ + Returns Gcp Video Intelligence Service client + + :rtype: google.cloud.videointelligence_v1.VideoIntelligenceServiceClient + """ + if not self._conn: + self._conn = VideoIntelligenceServiceClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._conn + + @GoogleBaseHook.quota_retry() + def annotate_video( + self, + input_uri: Optional[str] = None, + input_content: Optional[bytes] = None, + features: Optional[List[VideoIntelligenceServiceClient.enums.Feature]] = None, + video_context: Union[Dict, VideoContext] = None, + output_uri: Optional[str] = None, + location: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Performs video annotation. + + :param input_uri: Input video location. Currently, only Google Cloud Storage URIs are supported, + which must be specified in the following format: ``gs://bucket-id/object-id``. + :type input_uri: str + :param input_content: The video data bytes. + If unset, the input video(s) should be specified via ``input_uri``. + If set, ``input_uri`` should be unset. + :type input_content: bytes + :param features: Requested video annotation features. + :type features: list[google.cloud.videointelligence_v1.VideoIntelligenceServiceClient.enums.Feature] + :param output_uri: Optional, location where the output (in JSON format) should be stored. Currently, + only Google Cloud Storage URIs are supported, which must be specified in the following format: + ``gs://bucket-id/object-id``. + :type output_uri: str + :param video_context: Optional, Additional video context and/or feature-specific parameters. + :type video_context: dict or google.cloud.videointelligence_v1.types.VideoContext + :param location: Optional, cloud region where annotation should take place. Supported cloud regions: + us-east1, us-west1, europe-west1, asia-east1. + If no region is specified, a region will be determined based on video file location. + :type location: str + :param retry: Retry object used to determine when/if to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: Optional, The amount of time, in seconds, to wait for the request to complete. + Note that if retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Optional, Additional metadata that is provided to the method. 
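A hedged sketch of calling ``annotate_video``; the bucket paths are placeholders, and the Feature enum is referenced through the same client attribute the signature above uses, which assumes the client version this hook was written against::

    from airflow.providers.google.cloud.hooks.video_intelligence import (
        CloudVideoIntelligenceHook,
    )
    from google.cloud.videointelligence_v1 import VideoIntelligenceServiceClient

    hook = CloudVideoIntelligenceHook(gcp_conn_id="google_cloud_default")

    operation = hook.annotate_video(
        input_uri="gs://my-bucket/input.mp4",
        features=[VideoIntelligenceServiceClient.enums.Feature.LABEL_DETECTION],
        output_uri="gs://my-bucket/annotations.json",
    )
    # annotate_video returns a long-running operation; block until it finishes.
    result = operation.result()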
+ :type metadata: seq[tuple[str, str]] + """ + client = self.get_conn() + return client.annotate_video( + input_uri=input_uri, + input_content=input_content, + features=features, + video_context=video_context, + output_uri=output_uri, + location_id=location, + retry=retry, + timeout=timeout, + metadata=metadata, + ) diff --git a/reference/providers/google/cloud/hooks/vision.py b/reference/providers/google/cloud/hooks/vision.py new file mode 100644 index 0000000..6deb76b --- /dev/null +++ b/reference/providers/google/cloud/hooks/vision.py @@ -0,0 +1,766 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Vision Hook.""" + +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.retry import Retry +from google.cloud.vision_v1 import ImageAnnotatorClient, ProductSearchClient +from google.cloud.vision_v1.types import ( + AnnotateImageRequest, + FieldMask, + Image, + Product, + ProductSet, + ReferenceImage, +) +from google.protobuf.json_format import MessageToDict + +ERR_DIFF_NAMES = """The {label} name provided in the object ({explicit_name}) is different + than the name created from the input parameters ({constructed_name}). Please either: + 1) Remove the {label} name, + 2) Remove the location and {id_label} parameters, + 3) Unify the {label} name and input parameters. + """ + +ERR_UNABLE_TO_CREATE = """Unable to determine the {label} name. Please either set the name directly + in the {label} object or provide the `location` and `{id_label}` parameters. + """ + + +class NameDeterminer: + """Helper class to determine entity name.""" + + def __init__( + self, label: str, id_label: str, get_path: Callable[[str, str, str], str] + ) -> None: + self.label = label + self.id_label = id_label + self.get_path = get_path + + def get_entity_with_name( + self, + entity: Any, + entity_id: Optional[str], + location: Optional[str], + project_id: str, + ) -> Any: + """ + Check if entity has the `name` attribute set: + * If so, no action is taken. + + * If not, and the name can be constructed from other parameters provided, it is created and filled in + the entity. + + * If both the entity's 'name' attribute is set and the name can be constructed from other parameters + provided: + + * If they are the same - no action is taken + + * if they are different - an exception is thrown. 
+ + + :param entity: Entity + :type entity: any + :param entity_id: Entity id + :type entity_id: str + :param location: Location + :type location: str + :param project_id: The id of Google Cloud Vision project. + :type project_id: str + :return: The same entity or entity with new name + :rtype: str + :raises: AirflowException + """ + entity = deepcopy(entity) + explicit_name = getattr(entity, "name") + if location and entity_id: + # Necessary parameters to construct the name are present. Checking for conflict with explicit name + constructed_name = self.get_path(project_id, location, entity_id) + if not explicit_name: + entity.name = constructed_name + return entity + + if explicit_name != constructed_name: + raise AirflowException( + ERR_DIFF_NAMES.format( + label=self.label, + explicit_name=explicit_name, + constructed_name=constructed_name, + id_label=self.id_label, + ) + ) + + # Not enough parameters to construct the name. Trying to use the name from Product / ProductSet. + if explicit_name: + return entity + else: + raise AirflowException( + ERR_UNABLE_TO_CREATE.format(label=self.label, id_label=self.id_label) + ) + + +class CloudVisionHook(GoogleBaseHook): + """ + Hook for Google Cloud Vision APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + """ + + product_name_determiner = NameDeterminer( + "Product", "product_id", ProductSearchClient.product_path + ) + product_set_name_determiner = NameDeterminer( + "ProductSet", "productset_id", ProductSearchClient.product_set_path + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self._client = None + + def get_conn(self) -> ProductSearchClient: + """ + Retrieves connection to Cloud Vision. + + :return: Google Cloud Vision client object. + :rtype: google.cloud.vision_v1.ProductSearchClient + """ + if not self._client: + self._client = ProductSearchClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._client + + @cached_property + def annotator_client(self) -> ImageAnnotatorClient: + """ + Creates ImageAnnotatorClient. + + :return: Google Image Annotator client object. 
+ :rtype: google.cloud.vision_v1.ImageAnnotatorClient + """ + return ImageAnnotatorClient(credentials=self._get_credentials()) + + @staticmethod + def _check_for_error(response: Dict) -> None: + if "error" in response: + raise AirflowException(response) + + @GoogleBaseHook.fallback_to_default_project_id + def create_product_set( + self, + location: str, + product_set: Union[dict, ProductSet], + project_id: str, + product_set_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> str: + """ + For the documentation see: + :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator` + """ + client = self.get_conn() + parent = ProductSearchClient.location_path(project_id, location) + self.log.info("Creating a new ProductSet under the parent: %s", parent) + response = client.create_product_set( + parent=parent, + product_set=product_set, + product_set_id=product_set_id, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + self.log.info("ProductSet created: %s", response.name if response else "") + self.log.debug("ProductSet created:\n%s", response) + + if not product_set_id: + # Product set id was generated by the API + product_set_id = self._get_autogenerated_id(response) + self.log.info( + "Extracted autogenerated ProductSet ID from the response: %s", + product_set_id, + ) + + return product_set_id + + @GoogleBaseHook.fallback_to_default_project_id + def get_product_set( + self, + location: str, + product_set_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> dict: + """ + For the documentation see: + :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator` + """ + client = self.get_conn() + name = ProductSearchClient.product_set_path( + project_id, location, product_set_id + ) + self.log.info("Retrieving ProductSet: %s", name) + response = client.get_product_set( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + self.log.info("ProductSet retrieved.") + self.log.debug("ProductSet retrieved:\n%s", response) + return MessageToDict(response) + + @GoogleBaseHook.fallback_to_default_project_id + def update_product_set( + self, + product_set: Union[dict, ProductSet], + project_id: str, + location: Optional[str] = None, + product_set_id: Optional[str] = None, + update_mask: Union[dict, FieldMask] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> dict: + """ + For the documentation see: + :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator` + """ + client = self.get_conn() + product_set = self.product_set_name_determiner.get_entity_with_name( + product_set, product_set_id, location, project_id + ) + self.log.info("Updating ProductSet: %s", product_set.name) + response = client.update_product_set( + product_set=product_set, + update_mask=update_mask, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + self.log.info("ProductSet updated: %s", response.name if response else "") + self.log.debug("ProductSet updated:\n%s", response) + return MessageToDict(response) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_product_set( + self, + location: str, + product_set_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, 
+ metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + For the documentation see: + :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator` + """ + client = self.get_conn() + name = ProductSearchClient.product_set_path( + project_id, location, product_set_id + ) + self.log.info("Deleting ProductSet: %s", name) + client.delete_product_set( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + self.log.info("ProductSet with the name [%s] deleted.", name) + + @GoogleBaseHook.fallback_to_default_project_id + def create_product( + self, + location: str, + product: Union[dict, Product], + project_id: str, + product_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + For the documentation see: + :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator` + """ + client = self.get_conn() + parent = ProductSearchClient.location_path(project_id, location) + self.log.info("Creating a new Product under the parent: %s", parent) + response = client.create_product( + parent=parent, + product=product, + product_id=product_id, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + self.log.info("Product created: %s", response.name if response else "") + self.log.debug("Product created:\n%s", response) + + if not product_id: + # Product id was generated by the API + product_id = self._get_autogenerated_id(response) + self.log.info( + "Extracted autogenerated Product ID from the response: %s", product_id + ) + + return product_id + + @GoogleBaseHook.fallback_to_default_project_id + def get_product( + self, + location: str, + product_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + For the documentation see: + :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator` + """ + client = self.get_conn() + name = ProductSearchClient.product_path(project_id, location, product_id) + self.log.info("Retrieving Product: %s", name) + response = client.get_product( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + self.log.info("Product retrieved.") + self.log.debug("Product retrieved:\n%s", response) + return MessageToDict(response) + + @GoogleBaseHook.fallback_to_default_project_id + def update_product( + self, + product: Union[dict, Product], + project_id: str, + location: Optional[str] = None, + product_id: Optional[str] = None, + update_mask: Optional[Dict[str, FieldMask]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ): + """ + For the documentation see: + :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator` + """ + client = self.get_conn() + product = self.product_name_determiner.get_entity_with_name( + product, product_id, location, project_id + ) + self.log.info("Updating ProductSet: %s", product.name) + response = client.update_product( + product=product, + update_mask=update_mask, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + self.log.info("Product updated: %s", response.name if response else "") + self.log.debug("Product updated:\n%s", response) + return MessageToDict(response) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_product( + self, + location: str, + product_id: str, + 
project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + For the documentation see: + :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator` + """ + client = self.get_conn() + name = ProductSearchClient.product_path(project_id, location, product_id) + self.log.info("Deleting ProductSet: %s", name) + client.delete_product( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + self.log.info("Product with the name [%s] deleted:", name) + + @GoogleBaseHook.fallback_to_default_project_id + def create_reference_image( + self, + location: str, + product_id: str, + reference_image: Union[dict, ReferenceImage], + project_id: str, + reference_image_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> str: + """ + For the documentation see: + :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator` + """ + client = self.get_conn() + self.log.info("Creating ReferenceImage") + parent = ProductSearchClient.product_path( + project=project_id, location=location, product=product_id + ) + + response = client.create_reference_image( + parent=parent, + reference_image=reference_image, + reference_image_id=reference_image_id, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + self.log.info("ReferenceImage created: %s", response.name if response else "") + self.log.debug("ReferenceImage created:\n%s", response) + + if not reference_image_id: + # Reference image id was generated by the API + reference_image_id = self._get_autogenerated_id(response) + self.log.info( + "Extracted autogenerated ReferenceImage ID from the response: %s", + reference_image_id, + ) + + return reference_image_id + + @GoogleBaseHook.fallback_to_default_project_id + def delete_reference_image( + self, + location: str, + product_id: str, + reference_image_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> dict: + """ + For the documentation see: + :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDeleteReferenceImageOperator` + """ + client = self.get_conn() + self.log.info("Deleting ReferenceImage") + name = ProductSearchClient.reference_image_path( + project=project_id, + location=location, + product=product_id, + reference_image=reference_image_id, + ) + # pylint: disable=assignment-from-no-return + response = client.delete_reference_image( + name=name, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + self.log.info("ReferenceImage with the name [%s] deleted.", name) + return MessageToDict(response) + + @GoogleBaseHook.fallback_to_default_project_id + def add_product_to_product_set( + self, + product_set_id: str, + product_id: str, + project_id: str, + location: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + For the documentation see: + :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionAddProductToProductSetOperator` + """ + client = self.get_conn() + + product_name = ProductSearchClient.product_path( + project_id, location, product_id + ) + product_set_name = ProductSearchClient.product_set_path( + project_id, location, product_set_id + ) + + 
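A product-search sketch tying these methods together (illustrative; IDs, location, and display names are placeholders, and the dict arguments rely on the ``Union[dict, ...]`` parameters accepted above)::

    from airflow.providers.google.cloud.hooks.vision import CloudVisionHook

    hook = CloudVisionHook(gcp_conn_id="google_cloud_default")

    # Both create calls return the (possibly autogenerated) resource IDs.
    product_set_id = hook.create_product_set(
        location="europe-west1",
        product_set={"display_name": "shoes"},
        project_id="my-project",
    )
    product_id = hook.create_product(
        location="europe-west1",
        product={"display_name": "sneaker", "product_category": "apparel-v2"},
        project_id="my-project",
    )
    hook.add_product_to_product_set(
        product_set_id=product_set_id,
        product_id=product_id,
        project_id="my-project",
        location="europe-west1",
    )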
self.log.info( + "Add Product[name=%s] to Product Set[name=%s]", + product_name, + product_set_name, + ) + + client.add_product_to_product_set( + name=product_set_name, + product=product_name, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + self.log.info("Product added to Product Set") + + @GoogleBaseHook.fallback_to_default_project_id + def remove_product_from_product_set( + self, + product_set_id: str, + product_id: str, + project_id: str, + location: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: + """ + For the documentation see: + :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionRemoveProductFromProductSetOperator` # pylint: disable=line-too-long # noqa + """ + client = self.get_conn() + + product_name = ProductSearchClient.product_path( + project_id, location, product_id + ) + product_set_name = ProductSearchClient.product_set_path( + project_id, location, product_set_id + ) + + self.log.info( + "Remove Product[name=%s] from Product Set[name=%s]", + product_name, + product_set_name, + ) + + client.remove_product_from_product_set( + name=product_set_name, + product=product_name, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + self.log.info("Product removed from Product Set") + + def annotate_image( + self, + request: Union[dict, AnnotateImageRequest], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + ) -> Dict: + """ + For the documentation see: + :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator` + """ + client = self.annotator_client + + self.log.info("Annotating image") + + # pylint: disable=no-member + response = client.annotate_image(request=request, retry=retry, timeout=timeout) + + self.log.info("Image annotated") + + return MessageToDict(response) + + @GoogleBaseHook.quota_retry() + def batch_annotate_images( + self, + requests: Union[List[dict], List[AnnotateImageRequest]], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + ) -> dict: + """ + For the documentation see: + :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator` + """ + client = self.annotator_client + + self.log.info("Annotating images") + + response = client.batch_annotate_images( + requests=requests, retry=retry, timeout=timeout # pylint: disable=no-member + ) + + self.log.info("Images annotated") + + return MessageToDict(response) + + @GoogleBaseHook.quota_retry() + def text_detection( + self, + image: Union[dict, Image], + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + additional_properties: Optional[Dict] = None, + ) -> dict: + """ + For the documentation see: + :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDetectTextOperator` + """ + client = self.annotator_client + + self.log.info("Detecting text") + + if additional_properties is None: + additional_properties = {} + + response = client.text_detection( # pylint: disable=no-member + image=image, + max_results=max_results, + retry=retry, + timeout=timeout, + **additional_properties, + ) + response = MessageToDict(response) + self._check_for_error(response) + + self.log.info("Text detection finished") + + return response + + @GoogleBaseHook.quota_retry() + def document_text_detection( + self, + image: Union[dict, Image], + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + 
timeout: Optional[float] = None, + additional_properties: Optional[dict] = None, + ) -> dict: + """ + For the documentation see: + :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator` + """ + client = self.annotator_client + + self.log.info("Detecting document text") + + if additional_properties is None: + additional_properties = {} + + response = client.document_text_detection( # pylint: disable=no-member + image=image, + max_results=max_results, + retry=retry, + timeout=timeout, + **additional_properties, + ) + response = MessageToDict(response) + self._check_for_error(response) + + self.log.info("Document text detection finished") + + return response + + @GoogleBaseHook.quota_retry() + def label_detection( + self, + image: Union[dict, Image], + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + additional_properties: Optional[dict] = None, + ) -> dict: + """ + For the documentation see: + :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDetectImageLabelsOperator` + """ + client = self.annotator_client + + self.log.info("Detecting labels") + + if additional_properties is None: + additional_properties = {} + + response = client.label_detection( # pylint: disable=no-member + image=image, + max_results=max_results, + retry=retry, + timeout=timeout, + **additional_properties, + ) + response = MessageToDict(response) + self._check_for_error(response) + + self.log.info("Labels detection finished") + + return response + + @GoogleBaseHook.quota_retry() + def safe_search_detection( + self, + image: Union[dict, Image], + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + additional_properties: Optional[dict] = None, + ) -> dict: + """ + For the documentation see: + :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDetectImageSafeSearchOperator` + """ + client = self.annotator_client + + self.log.info("Detecting safe search") + + if additional_properties is None: + additional_properties = {} + + response = client.safe_search_detection( # pylint: disable=no-member + image=image, + max_results=max_results, + retry=retry, + timeout=timeout, + **additional_properties, + ) + response = MessageToDict(response) + self._check_for_error(response) + + self.log.info("Safe search detection finished") + return response + + @staticmethod + def _get_autogenerated_id(response) -> str: + try: + name = response.name + except AttributeError as e: + raise AirflowException( + f"Unable to get name from response... [{response}]\n{e}" + ) + if "/" not in name: + raise AirflowException(f"Unable to get id from name... [{name}]") + return name.rsplit("/", 1)[1] diff --git a/reference/providers/google/cloud/hooks/workflows.py b/reference/providers/google/cloud/hooks/workflows.py new file mode 100644 index 0000000..cde98d0 --- /dev/null +++ b/reference/providers/google/cloud/hooks/workflows.py @@ -0,0 +1,416 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
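A detection sketch for the annotator methods defined above (illustrative; the image URI is a placeholder, and the dict form of ``Image`` with a ``source.image_uri`` is assumed to be accepted, as the type hints suggest)::

    from airflow.providers.google.cloud.hooks.vision import CloudVisionHook

    hook = CloudVisionHook(gcp_conn_id="google_cloud_default")

    image = {"source": {"image_uri": "gs://my-bucket/receipt.png"}}

    # Each call returns a plain dict (MessageToDict output) that the hook has
    # already checked for an "error" key.
    text = hook.text_detection(image=image, max_results=10)
    labels = hook.label_detection(image=image)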
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, Optional, Sequence, Tuple, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from google.api_core.operation import Operation +from google.api_core.retry import Retry + +# pylint: disable=no-name-in-module +from google.cloud.workflows.executions_v1beta import Execution, ExecutionsClient +from google.cloud.workflows.executions_v1beta.services.executions.pagers import ( + ListExecutionsPager, +) +from google.cloud.workflows_v1beta import Workflow, WorkflowsClient +from google.cloud.workflows_v1beta.services.workflows.pagers import ListWorkflowsPager +from google.protobuf.field_mask_pb2 import FieldMask + +# pylint: enable=no-name-in-module + + +class WorkflowsHook(GoogleBaseHook): + """ + Hook for Google GCP APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + """ + + def get_workflows_client(self) -> WorkflowsClient: + """Returns WorkflowsClient.""" + return WorkflowsClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + + def get_executions_client(self) -> ExecutionsClient: + """Returns ExecutionsClient.""" + return ExecutionsClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_workflow( + self, + workflow: Dict, + workflow_id: str, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Creates a new workflow. If a workflow with the specified name + already exists in the specified project and location, the long + running operation will return + [ALREADY_EXISTS][google.rpc.Code.ALREADY_EXISTS] error. + + :param workflow: Required. Workflow to be created. + :type workflow: Dict + :param workflow_id: Required. The ID of the workflow to be created. + :type workflow_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_workflows_client() + parent = f"projects/{project_id}/locations/{location}" + return client.create_workflow( + request={ + "parent": parent, + "workflow": workflow, + "workflow_id": workflow_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_workflow( + self, + workflow_id: str, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Workflow: + """ + Gets details of a single Workflow. + + :param workflow_id: Required. The ID of the workflow to be created. + :type workflow_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_workflows_client() + name = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}" + return client.get_workflow( + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata + ) + + def update_workflow( + self, + workflow: Union[Dict, Workflow], + update_mask: Optional[FieldMask] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Updates an existing workflow. + Running this method has no impact on already running + executions of the workflow. A new revision of the + workflow may be created as a result of a successful + update operation. In that case, such revision will be + used in new workflow executions. + + :param workflow: Required. Workflow to be created. + :type workflow: Dict + :param update_mask: List of fields to be updated. If not present, + the entire workflow will be updated. + :type update_mask: FieldMask + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_workflows_client() + return client.update_workflow( + request={"workflow": workflow, "update_mask": update_mask}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_workflow( + self, + workflow_id: str, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Operation: + """ + Deletes a workflow with the specified name. 
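A hedged sketch of creating and then fetching a workflow with this hook; the inline YAML definition, IDs, and project are placeholder assumptions::

    from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook

    hook = WorkflowsHook(gcp_conn_id="google_cloud_default")

    workflow = {
        "source_contents": "main:\n  steps:\n    - hello:\n        return: 'Hello'\n"
    }
    operation = hook.create_workflow(
        workflow=workflow,
        workflow_id="my-workflow",
        location="us-central1",
        project_id="my-project",
    )
    operation.result()  # create_workflow returns a long-running operation

    created = hook.get_workflow(
        workflow_id="my-workflow",
        location="us-central1",
        project_id="my-project",
    )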
+ This method also cancels and deletes all running + executions of the workflow. + + :param workflow_id: Required. The ID of the workflow to be created. + :type workflow_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_workflows_client() + name = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}" + return client.delete_workflow( + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + def list_workflows( + self, + location: str, + project_id: str, + filter_: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ListWorkflowsPager: + """ + Lists Workflows in a given project and location. + The default order is not specified. + + :param filter_: Filter to restrict results to specific workflows. + :type filter_: str + :param order_by: Comma-separated list of fields that that + specify the order of the results. Default sorting order for a field is ascending. + To specify descending order for a field, append a "desc" suffix. + If not specified, the results will be returned in an unspecified order. + :type order_by: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_workflows_client() + parent = f"projects/{project_id}/locations/{location}" + + return client.list_workflows( + request={"parent": parent, "filter": filter_, "order_by": order_by}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_execution( + self, + workflow_id: str, + location: str, + project_id: str, + execution: Dict, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Execution: + """ + Creates a new execution using the latest revision of + the given workflow. + + :param execution: Required. Input parameters of the execution represented as a dictionary. + :type execution: Dict + :param workflow_id: Required. The ID of the workflow. + :type workflow_id: str + :param project_id: Required. 
The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_executions_client() + parent = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}" + return client.create_execution( + request={"parent": parent, "execution": execution}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_execution( + self, + workflow_id: str, + execution_id: str, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Execution: + """ + Returns an execution for the given ``workflow_id`` and ``execution_id``. + + :param workflow_id: Required. The ID of the workflow. + :type workflow_id: str + :param execution_id: Required. The ID of the execution. + :type execution_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_executions_client() + name = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}/executions/{execution_id}" + return client.get_execution( + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_execution( + self, + workflow_id: str, + execution_id: str, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> Execution: + """ + Cancels an execution using the given ``workflow_id`` and ``execution_id``. + + :param workflow_id: Required. The ID of the workflow. + :type workflow_id: str + :param execution_id: Required. The ID of the execution. + :type execution_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_executions_client() + name = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}/executions/{execution_id}" + return client.cancel_execution( + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + def list_executions( + self, + workflow_id: str, + location: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> ListExecutionsPager: + """ + Returns a list of executions which belong to the + workflow with the given name. The method returns + executions of all workflow revisions. Returned + executions are ordered by their start time (newest + first). + + :param workflow_id: Required. The ID of the workflow to be created. + :type workflow_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + metadata = metadata or () + client = self.get_executions_client() + parent = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}" + return client.list_executions( + request={"parent": parent}, retry=retry, timeout=timeout, metadata=metadata + ) diff --git a/reference/providers/google/cloud/log/__init__.py b/reference/providers/google/cloud/log/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/cloud/log/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/cloud/log/gcs_task_handler.py b/reference/providers/google/cloud/log/gcs_task_handler.py new file mode 100644 index 0000000..420b911 --- /dev/null +++ b/reference/providers/google/cloud/log/gcs_task_handler.py @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +from typing import Collection, Optional + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow import version +from airflow.providers.google.cloud.utils.credentials_provider import ( + get_credentials_and_project_id, +) +from airflow.utils.log.file_task_handler import FileTaskHandler +from airflow.utils.log.logging_mixin import LoggingMixin +from google.api_core.client_info import ClientInfo +from google.cloud import storage + +_DEFAULT_SCOPESS = frozenset( + [ + "https://www.googleapis.com/auth/devstorage.read_write", + ] +) + + +class GCSTaskHandler(FileTaskHandler, LoggingMixin): + """ + GCSTaskHandler is a python log handler that handles and reads + task instance logs. It extends airflow FileTaskHandler and + uploads to and reads from GCS remote storage. Upon log reading + failure, it reads from host machine's local disk. + + :param base_log_folder: Base log folder to place logs. + :type base_log_folder: str + :param gcs_log_folder: Path to a remote location where logs will be saved. It must have the prefix + ``gs://``. For example: ``gs://bucket/remote/log/location`` + :type gcs_log_folder: str + :param filename_template: template filename string + :type filename_template: str + :param gcp_key_path: Path to Google Cloud Service Account file (JSON). Mutually exclusive with + gcp_keyfile_dict. + If omitted, authorization based on `the Application Default Credentials + `__ will + be used. + :type gcp_key_path: str + :param gcp_keyfile_dict: Dictionary of keyfile parameters. Mutually exclusive with gcp_key_path. + :type gcp_keyfile_dict: dict + :param gcp_scopes: Comma-separated string containing OAuth2 scopes + :type gcp_scopes: str + :param project_id: Project ID to read the secrets from. If not passed, the project ID from credentials + will be used. 
+ :type project_id: str + """ + + def __init__( + self, + *, + base_log_folder: str, + gcs_log_folder: str, + filename_template: str, + gcp_key_path: Optional[str] = None, + gcp_keyfile_dict: Optional[dict] = None, + # See: https://github.com/PyCQA/pylint/issues/2377 + gcp_scopes: Optional[ + Collection[str] + ] = _DEFAULT_SCOPESS, # pylint: disable=unsubscriptable-object + project_id: Optional[str] = None, + ): + super().__init__(base_log_folder, filename_template) + self.remote_base = gcs_log_folder + self.log_relative_path = "" + self._hook = None + self.closed = False + self.upload_on_close = True + self.gcp_key_path = gcp_key_path + self.gcp_keyfile_dict = gcp_keyfile_dict + self.scopes = gcp_scopes + self.project_id = project_id + + @cached_property + def client(self) -> storage.Client: + """Returns GCS Client.""" + credentials, project_id = get_credentials_and_project_id( + key_path=self.gcp_key_path, + keyfile_dict=self.gcp_keyfile_dict, + scopes=self.scopes, + disable_logging=True, + ) + return storage.Client( + credentials=credentials, + client_info=ClientInfo( + client_library_version="airflow_v" + version.version + ), + project=self.project_id if self.project_id else project_id, + ) + + def set_context(self, ti): + super().set_context(ti) + # Log relative path is used to construct local and remote + # log path to upload log files into GCS and read from the + # remote location. + self.log_relative_path = self._render_filename(ti, ti.try_number) + self.upload_on_close = not ti.raw + + def close(self): + """Close and upload local log file to remote storage GCS.""" + # When application exit, system shuts down all handlers by + # calling close method. Here we check if logger is already + # closed to prevent uploading the log to remote storage multiple + # times when `logging.shutdown` is called. + if self.closed: + return + + super().close() + + if not self.upload_on_close: + return + + local_loc = os.path.join(self.local_base, self.log_relative_path) + remote_loc = os.path.join(self.remote_base, self.log_relative_path) + if os.path.exists(local_loc): + # read log and remove old logs to get just the latest additions + with open(local_loc) as logfile: + log = logfile.read() + self.gcs_write(log, remote_loc) + + # Mark closed so we don't double write if close is called twice + self.closed = True + + def _read(self, ti, try_number, metadata=None): + """ + Read logs of given task instance and try_number from GCS. + If failed, read the log from task instance host machine. + + :param ti: task instance object + :param try_number: task instance try_number to read logs from + :param metadata: log metadata, + can be used for steaming log reading and auto-tailing. + """ + # Explicitly getting log relative path is necessary as the given + # task instance might be different than task instance passed in + # in set_context method. 
+ log_relative_path = self._render_filename(ti, try_number) + remote_loc = os.path.join(self.remote_base, log_relative_path) + + try: + blob = storage.Blob.from_string(remote_loc, self.client) + remote_log = blob.download_as_bytes().decode() + log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n" + return log, {"end_of_log": True} + except Exception as e: # pylint: disable=broad-except + log = f"*** Unable to read remote log from {remote_loc}\n*** {str(e)}\n\n" + self.log.error(log) + local_log, metadata = super()._read(ti, try_number) + log += local_log + return log, metadata + + def gcs_write(self, log, remote_log_location): + """ + Writes the log to the remote_log_location. Fails silently if no log + was created. + + :param log: the log to write to the remote_log_location + :type log: str + :param remote_log_location: the log's location in remote storage + :type remote_log_location: str (path) + """ + try: + blob = storage.Blob.from_string(remote_log_location, self.client) + old_log = blob.download_as_bytes().decode() + log = "\n".join([old_log, log]) if old_log else log + except Exception as e: # pylint: disable=broad-except + if ( + not hasattr(e, "resp") or e.resp.get("status") != "404" + ): # pylint: disable=no-member + log = f"*** Previous log discarded: {str(e)}\n\n" + log + self.log.info("Previous log discarded: %s", e) + + try: + blob = storage.Blob.from_string(remote_log_location, self.client) + blob.upload_from_string(log, content_type="text/plain") + except Exception as e: # pylint: disable=broad-except + self.log.error("Could not write logs to %s: %s", remote_log_location, e) diff --git a/reference/providers/google/cloud/log/stackdriver_task_handler.py b/reference/providers/google/cloud/log/stackdriver_task_handler.py new file mode 100644 index 0000000..6c0f0e3 --- /dev/null +++ b/reference/providers/google/cloud/log/stackdriver_task_handler.py @@ -0,0 +1,397 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
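+# Usage sketch (illustrative only, not part of the upstream module): because the class
+# defined below is a standard ``logging.Handler``, it can be attached to Airflow's task
+# logger directly. The key path and log name used here are assumptions for the example.
+#
+#   import logging
+#
+#   handler = StackdriverTaskHandler(
+#       gcp_key_path="/files/service-account.json",  # omit to fall back to Application Default Credentials
+#       name="airflow-tasks",
+#   )
+#   logging.getLogger("airflow.task").addHandler(handler)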
+"""Handler that integrates with Stackdriver""" +import logging +from typing import Collection, Dict, List, Optional, Tuple, Type +from urllib.parse import urlencode + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow import version +from airflow.models import TaskInstance +from airflow.providers.google.cloud.utils.credentials_provider import ( + get_credentials_and_project_id, +) +from google.api_core.gapic_v1.client_info import ClientInfo +from google.auth.credentials import Credentials +from google.cloud import logging as gcp_logging +from google.cloud.logging import Resource +from google.cloud.logging.handlers.transports import ( + BackgroundThreadTransport, + Transport, +) +from google.cloud.logging_v2.services.logging_service_v2 import LoggingServiceV2Client +from google.cloud.logging_v2.types import ListLogEntriesRequest, ListLogEntriesResponse + +DEFAULT_LOGGER_NAME = "airflow" +_GLOBAL_RESOURCE = Resource(type="global", labels={}) + +_DEFAULT_SCOPESS = frozenset( + [ + "https://www.googleapis.com/auth/logging.read", + "https://www.googleapis.com/auth/logging.write", + ] +) + + +class StackdriverTaskHandler(logging.Handler): + """Handler that directly makes Stackdriver logging API calls. + + This is a Python standard ``logging`` handler using that can be used to + route Python standard logging messages directly to the Stackdriver + Logging API. + + It can also be used to save logs for executing tasks. To do this, you should set as a handler with + the name "tasks". In this case, it will also be used to read the log for display in Web UI. + + This handler supports both an asynchronous and synchronous transport. + + :param gcp_key_path: Path to Google Cloud Credential JSON file. + If omitted, authorization based on `the Application Default Credentials + `__ will + be used. + :type gcp_key_path: str + :param scopes: OAuth scopes for the credentials, + :type scopes: Sequence[str] + :param name: the name of the custom log in Stackdriver Logging. Defaults + to 'airflow'. The name of the Python logger will be represented + in the ``python_logger`` field. + :type name: str + :param transport: Class for creating new transport objects. It should + extend from the base :class:`google.cloud.logging.handlers.Transport` type and + implement :meth`google.cloud.logging.handlers.Transport.send`. Defaults to + :class:`google.cloud.logging.handlers.BackgroundThreadTransport`. The other + option is :class:`google.cloud.logging.handlers.SyncTransport`. + :type transport: :class:`type` + :param re# (Optional) Monitored resource of the entry, defaults + to the global resource type. + :type re# :class:`~google.cloud.logging.resource.Resource` + :param labels: (Optional) Mapping of labels for the entry. 
+    :type labels: dict
+    """
+
+    LABEL_TASK_ID = "task_id"
+    LABEL_DAG_ID = "dag_id"
+    LABEL_EXECUTION_DATE = "execution_date"
+    LABEL_TRY_NUMBER = "try_number"
+    LOG_VIEWER_BASE_URL = "https://console.cloud.google.com/logs/viewer"
+    LOG_NAME = "Google Stackdriver"
+
+    def __init__(
+        self,
+        gcp_key_path: Optional[str] = None,
+        # See: https://github.com/PyCQA/pylint/issues/2377
+        scopes: Optional[
+            Collection[str]
+        ] = _DEFAULT_SCOPESS,  # pylint: disable=unsubscriptable-object
+        name: str = DEFAULT_LOGGER_NAME,
+        transport: Type[Transport] = BackgroundThreadTransport,
+        resource: Resource = _GLOBAL_RESOURCE,
+        labels: Optional[Dict[str, str]] = None,
+    ):
+        super().__init__()
+        self.gcp_key_path: Optional[str] = gcp_key_path
+        # See: https://github.com/PyCQA/pylint/issues/2377
+        self.scopes: Optional[
+            Collection[str]
+        ] = scopes  # pylint: disable=unsubscriptable-object
+        self.name: str = name
+        self.transport_type: Type[Transport] = transport
+        self.resource: Resource = resource
+        self.labels: Optional[Dict[str, str]] = labels
+        self.task_instance_labels: Optional[Dict[str, str]] = {}
+        self.task_instance_hostname = "default-hostname"
+
+    @cached_property
+    def _credentials_and_project(self) -> Tuple[Credentials, str]:
+        credentials, project = get_credentials_and_project_id(
+            key_path=self.gcp_key_path, scopes=self.scopes, disable_logging=True
+        )
+        return credentials, project
+
+    @property
+    def _client(self) -> gcp_logging.Client:
+        """The Cloud Library API client"""
+        credentials, project = self._credentials_and_project
+        client = gcp_logging.Client(
+            credentials=credentials,
+            project=project,
+            client_info=ClientInfo(
+                client_library_version="airflow_v" + version.version
+            ),
+        )
+        return client
+
+    @property
+    def _logging_service_client(self) -> LoggingServiceV2Client:
+        """The Cloud logging service v2 client."""
+        credentials, _ = self._credentials_and_project
+        client = LoggingServiceV2Client(
+            credentials=credentials,
+            client_info=ClientInfo(
+                client_library_version="airflow_v" + version.version
+            ),
+        )
+        return client
+
+    @cached_property
+    def _transport(self) -> Transport:
+        """Object responsible for sending data to Stackdriver"""
+        return self.transport_type(self._client, self.name)
+
+    def emit(self, record: logging.LogRecord) -> None:
+        """Actually log the specified logging record.
+
+        :param record: The record to be logged.
+        :type record: logging.LogRecord
+        """
+        message = self.format(record)
+        labels: Optional[Dict[str, str]]
+        if self.labels and self.task_instance_labels:
+            labels = {}
+            labels.update(self.labels)
+            labels.update(self.task_instance_labels)
+        elif self.labels:
+            labels = self.labels
+        elif self.task_instance_labels:
+            labels = self.task_instance_labels
+        else:
+            labels = None
+        self._transport.send(record, message, resource=self.resource, labels=labels)
+
+    def set_context(self, task_instance: TaskInstance) -> None:
+        """
+        Configures the logger to add information about the current task
+
+        :param task_instance: Currently executed task
+        :type task_instance: :class:`airflow.models.TaskInstance`
+        """
+        self.task_instance_labels = self._task_instance_to_labels(task_instance)
+        self.task_instance_hostname = task_instance.hostname
+
+    def read(
+        self,
+        task_instance: TaskInstance,
+        try_number: Optional[int] = None,
+        metadata: Optional[Dict] = None,
+    ) -> Tuple[List[Tuple[Tuple[str, str]]], List[Dict[str, str]]]:
+        """
+        Read logs of given task instance from Stackdriver logging.
+ + :param task_instance: task instance object + :type task_instance: :class:`airflow.models.TaskInstance` + :param try_number: task instance try_number to read logs from. If None + it returns all logs + :type try_number: Optional[int] + :param metadata: log metadata. It is used for steaming log reading and auto-tailing. + :type metadata: Dict + :return: a tuple of ( + list of (one element tuple with two element tuple - hostname and logs) + and list of metadata) + :rtype: Tuple[List[Tuple[Tuple[str, str]]], List[Dict[str, str]]] + """ + if try_number is not None and try_number < 1: + logs = f"Error fetching the logs. Try number {try_number} is invalid." + return [((self.task_instance_hostname, logs),)], [{"end_of_log": "true"}] + + if not metadata: + metadata = {} + + ti_labels = self._task_instance_to_labels(task_instance) + + if try_number is not None: + ti_labels[self.LABEL_TRY_NUMBER] = str(try_number) + else: + del ti_labels[self.LABEL_TRY_NUMBER] + + log_filter = self._prepare_log_filter(ti_labels) + next_page_token = metadata.get("next_page_token", None) + all_pages = "download_logs" in metadata and metadata["download_logs"] + + messages, end_of_log, next_page_token = self._read_logs( + log_filter, next_page_token, all_pages + ) + + new_metadata = {"end_of_log": end_of_log} + + if next_page_token: + new_metadata["next_page_token"] = next_page_token + + return [((self.task_instance_hostname, messages),)], [new_metadata] + + def _prepare_log_filter(self, ti_labels: Dict[str, str]) -> str: + """ + Prepares the filter that chooses which log entries to fetch. + + More information: + https://cloud.google.com/logging/docs/reference/v2/rest/v2/entries/list#body.request_body.FIELDS.filter + https://cloud.google.com/logging/docs/view/advanced-queries + + :param ti_labels: Task Instance's labels that will be used to search for logs + :type: Dict[str, str] + :return: logs filter + """ + + def escape_label_key(key: str) -> str: + return f'"{key}"' if "." in key else key + + def escale_label_value(value: str) -> str: + escaped_value = value.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped_value}"' + + _, project = self._credentials_and_project + log_filters = [ + f"resource.type={escale_label_value(self.resource.type)}", + f'logName="projects/{project}/logs/{self.name}"', + ] + + for key, value in self.resource.labels.items(): + log_filters.append( + f"resource.labels.{escape_label_key(key)}={escale_label_value(value)}" + ) + + for key, value in ti_labels.items(): + log_filters.append( + f"labels.{escape_label_key(key)}={escale_label_value(value)}" + ) + return "\n".join(log_filters) + + def _read_logs( + self, log_filter: str, next_page_token: Optional[str], all_pages: bool + ) -> Tuple[str, bool, Optional[str]]: + """ + Sends requests to the Stackdriver service and downloads logs. + + :param log_filter: Filter specifying the logs to be downloaded. + :type log_filter: str + :param next_page_token: The token of the page from which the log download will start. + If None is passed, it will start from the first page. + :param all_pages: If True is passed, all subpages will be downloaded. 
Otherwise, only the first + page will be downloaded + :return: A token that contains the following items: + * string with logs + * Boolean value describing whether there are more logs, + * token of the next page + :rtype: Tuple[str, bool, str] + """ + messages = [] + new_messages, next_page_token = self._read_single_logs_page( + log_filter=log_filter, + page_token=next_page_token, + ) + messages.append(new_messages) + if all_pages: + while next_page_token: + new_messages, next_page_token = self._read_single_logs_page( + log_filter=log_filter, page_token=next_page_token + ) + messages.append(new_messages) + if not messages: + break + + end_of_log = True + next_page_token = None + else: + end_of_log = not bool(next_page_token) + return "\n".join(messages), end_of_log, next_page_token + + def _read_single_logs_page( + self, log_filter: str, page_token: Optional[str] = None + ) -> Tuple[str, str]: + """ + Sends requests to the Stackdriver service and downloads single pages with logs. + + :param log_filter: Filter specifying the logs to be downloaded. + :type log_filter: str + :param page_token: The token of the page to be downloaded. If None is passed, the first page will be + downloaded. + :type page_token: str + :return: Downloaded logs and next page token + :rtype: Tuple[str, str] + """ + _, project = self._credentials_and_project + request = ListLogEntriesRequest( + resource_names=[f"projects/{project}"], + filter=log_filter, + page_token=page_token, + order_by="timestamp asc", + page_size=1000, + ) + response = self._logging_service_client.list_log_entries(request=request) + page: ListLogEntriesResponse = next(response.pages) + messages = [] + for entry in page.entries: + if "message" in entry.json_payload: + messages.append(entry.json_payload["message"]) + return "\n".join(messages), page.next_page_token + + @classmethod + def _task_instance_to_labels(cls, ti: TaskInstance) -> Dict[str, str]: + return { + cls.LABEL_TASK_ID: ti.task_id, + cls.LABEL_DAG_ID: ti.dag_id, + cls.LABEL_EXECUTION_DATE: str(ti.execution_date.isoformat()), + cls.LABEL_TRY_NUMBER: str(ti.try_number), + } + + @property + def log_name(self): + """Return log name.""" + return self.LOG_NAME + + @cached_property + def _resource_path(self): + segments = [self.resource.type] + + for key, value in self.resource.labels: + segments += [key] + segments += [value] + + return "/".join(segments) + + def get_external_log_url(self, task_instance: TaskInstance, try_number: int) -> str: + """ + Creates an address for an external log collecting service. + :param task_instance: task instance object + :type: task_instance: TaskInstance + :param try_number: task instance try_number to read logs from. 
+ :type try_number: Optional[int] + :return: URL to the external log collection service + :rtype: str + """ + _, project_id = self._credentials_and_project + + ti_labels = self._task_instance_to_labels(task_instance) + ti_labels[self.LABEL_TRY_NUMBER] = str(try_number) + + log_filter = self._prepare_log_filter(ti_labels) + + url_query_string = { + "project": project_id, + "interval": "NO_LIMIT", + "resource": self._resource_path, + "advancedFilter": log_filter, + } + + url = f"{self.LOG_VIEWER_BASE_URL}?{urlencode(url_query_string)}" + return url + + def close(self) -> None: + self._transport.flush() diff --git a/reference/providers/google/cloud/operators/__init__.py b/reference/providers/google/cloud/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/cloud/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/cloud/operators/automl.py b/reference/providers/google/cloud/operators/automl.py new file mode 100644 index 0000000..63b8c7b --- /dev/null +++ b/reference/providers/google/cloud/operators/automl.py @@ -0,0 +1,1284 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# pylint: disable=too-many-lines +"""This module contains Google AutoML operators.""" +import ast +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.utils.decorators import apply_defaults +from google.api_core.retry import Retry +from google.cloud.automl_v1beta1 import ( + BatchPredictResult, + ColumnSpec, + Dataset, + Model, + PredictResponse, + TableSpec, +) + +MetaData = Sequence[Tuple[str, str]] + + +class AutoMLTrainModelOperator(BaseOperator): + """ + Creates Google Cloud AutoML model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLTrainModelOperator` + + :param model: Model definition. 
+ :type model: dict + :param project_id: ID of the Google Cloud project where model will be created if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "model", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + model: dict, + location: str, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model = model + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Creating model.") + operation = hook.create_model( + model=self.model, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = Model.to_dict(operation.result()) + model_id = hook.extract_object_id(result) + self.log.info("Model created: %s", model_id) + + self.xcom_push(context, key="model_id", value=model_id) + return result + + +class AutoMLPredictOperator(BaseOperator): + """ + Runs prediction operation on Google Cloud AutoML. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLPredictOperator` + + :param model_id: Name of the model requested to serve the batch prediction. + :type model_id: str + :param payload: Name od the model used for the prediction. + :type payload: dict + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param params: Additional domain-specific parameters for the predictions. 
+ :type params: Optional[Dict[str, str]] + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + model_id: str, + location: str, + payload: dict, + params: Optional[Dict[str, str]] = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model_id = model_id + self.params = params # type: ignore + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.payload = payload + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + result = hook.predict( + model_id=self.model_id, + payload=self.payload, + location=self.location, + project_id=self.project_id, + params=self.params, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return PredictResponse.to_dict(result) + + +class AutoMLBatchPredictOperator(BaseOperator): + """ + Perform a batch prediction on Google Cloud AutoML. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLBatchPredictOperator` + + :param project_id: ID of the Google Cloud project where model will be created if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param model_id: Name of the model_id requested to serve the batch prediction. + :type model_id: str + :param input_config: Required. The input configuration for batch prediction. + If a dict is provided, it must be of the same form as the protobuf message + `google.cloud.automl_v1beta1.types.BatchPredictInputConfig` + :type input_config: Union[dict, ~google.cloud.automl_v1beta1.types.BatchPredictInputConfig] + :param output_config: Required. The Configuration specifying where output predictions should be + written. 
If a dict is provided, it must be of the same form as the protobuf message + `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig` + :type output_config: Union[dict, ~google.cloud.automl_v1beta1.types.BatchPredictOutputConfig] + :param prediction_params: Additional domain-specific parameters for the predictions, + any string must be up to 25000 characters long. + :type prediction_params: Optional[Dict[str, str]] + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "model_id", + "input_config", + "output_config", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + model_id: str, + input_config: dict, + output_config: dict, + location: str, + project_id: Optional[str] = None, + prediction_params: Optional[Dict[str, str]] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model_id = model_id + self.location = location + self.project_id = project_id + self.prediction_params = prediction_params + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.input_config = input_config + self.output_config = output_config + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Fetch batch prediction.") + operation = hook.batch_predict( + model_id=self.model_id, + input_config=self.input_config, + output_config=self.output_config, + project_id=self.project_id, + location=self.location, + params=self.prediction_params, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = BatchPredictResult.to_dict(operation.result()) + self.log.info("Batch prediction ready.") + return result + + +class AutoMLCreateDatasetOperator(BaseOperator): + """ + Creates a Google Cloud AutoML dataset. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLCreateDatasetOperator` + + :param dataset: The dataset to create. If a dict is provided, it must be of the + same form as the protobuf message Dataset. + :type dataset: Union[dict, Dataset] + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param params: Additional domain-specific parameters for the predictions. + :type params: Optional[Dict[str, str]] + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "dataset", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + dataset: dict, + location: str, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.dataset = dataset + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Creating dataset") + result = hook.create_dataset( + dataset=self.dataset, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = Dataset.to_dict(result) + dataset_id = hook.extract_object_id(result) + self.log.info("Creating completed. Dataset id: %s", dataset_id) + + self.xcom_push(context, key="dataset_id", value=dataset_id) + return result + + +class AutoMLImportDataOperator(BaseOperator): + """ + Imports data to a Google Cloud AutoML dataset. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLImportDataOperator` + + :param dataset_id: ID of dataset to be updated. + :type dataset_id: str + :param input_config: The desired input location and its domain specific semantics, if any. + If a dict is provided, it must be of the same form as the protobuf message InputConfig. + :type input_config: dict + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param params: Additional domain-specific parameters for the predictions. + :type params: Optional[Dict[str, str]] + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "dataset_id", + "input_config", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + dataset_id: str, + location: str, + input_config: dict, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.dataset_id = dataset_id + self.input_config = input_config + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Importing dataset") + operation = hook.import_data( + dataset_id=self.dataset_id, + input_config=self.input_config, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + operation.result() + self.log.info("Import completed") + + +class AutoMLTablesListColumnSpecsOperator(BaseOperator): + """ + Lists column specs in a table. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLTablesListColumnSpecsOperator` + + :param dataset_id: Name of the dataset. + :type dataset_id: str + :param table_spec_id: table_spec_id for path builder. + :type table_spec_id: str + :param field_mask: Mask specifying which fields to read. If a dict is provided, it must be of the same + form as the protobuf message `google.cloud.automl_v1beta1.types.FieldMask` + :type field_mask: Union[dict, google.cloud.automl_v1beta1.types.FieldMask] + :param filter_: Filter expression, see go/filtering. + :type filter_: str + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per + resource, this parameter does not affect the return value. If page + streaming is performed per page, this determines the maximum number + of resources in a page. + :type page_size: int + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "dataset_id", + "table_spec_id", + "field_mask", + "filter_", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + dataset_id: str, + table_spec_id: str, + location: str, + field_mask: Optional[dict] = None, + filter_: Optional[str] = None, + page_size: Optional[int] = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_id = dataset_id + self.table_spec_id = table_spec_id + self.field_mask = field_mask + self.filter_ = filter_ + self.page_size = page_size + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Requesting column specs.") + page_iterator = hook.list_column_specs( + dataset_id=self.dataset_id, + table_spec_id=self.table_spec_id, + field_mask=self.field_mask, + filter_=self.filter_, + page_size=self.page_size, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = [ColumnSpec.to_dict(spec) for spec in page_iterator] + self.log.info("Columns specs obtained.") + + return result + + +class AutoMLTablesUpdateDatasetOperator(BaseOperator): + """ + Updates a dataset. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLTablesUpdateDatasetOperator` + + :param dataset: The dataset which replaces the resource on the server. + If a dict is provided, it must be of the same form as the protobuf message Dataset. + :type dataset: Union[dict, Dataset] + :param update_mask: The update mask applies to the resource. If a dict is provided, it must + be of the same form as the protobuf message FieldMask. + :type update_mask: Union[dict, FieldMask] + :param location: The location of the project. + :type location: str + :param params: Additional domain-specific parameters for the predictions. + :type params: Optional[Dict[str, str]] + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. 
+ :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "dataset", + "update_mask", + "location", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + dataset: dict, + location: str, + update_mask: Optional[dict] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.dataset = dataset + self.update_mask = update_mask + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Updating AutoML dataset %s.", self.dataset["name"]) + result = hook.update_dataset( + dataset=self.dataset, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Dataset updated.") + return Dataset.to_dict(result) + + +class AutoMLGetModelOperator(BaseOperator): + """ + Get Google Cloud AutoML model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLGetModelOperator` + + :param model_id: Name of the model requested to serve the prediction. + :type model_id: str + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param params: Additional domain-specific parameters for the predictions. + :type params: Optional[Dict[str, str]] + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + model_id: str, + location: str, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model_id = model_id + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + result = hook.get_model( + model_id=self.model_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Model.to_dict(result) + + +class AutoMLDeleteModelOperator(BaseOperator): + """ + Delete Google Cloud AutoML model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLDeleteModelOperator` + + :param model_id: Name of the model requested to serve the prediction. + :type model_id: str + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param params: Additional domain-specific parameters for the predictions. + :type params: Optional[Dict[str, str]] + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + model_id: str, + location: str, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model_id = model_id + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + operation = hook.delete_model( + model_id=self.model_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + operation.result() + + +class AutoMLDeployModelOperator(BaseOperator): + """ + Deploys a model. If a model is already deployed, deploying it with the same parameters + has no effect. Deploying with different parameters (as e.g. changing node_number) will + reset the deployment state without pausing the model_id’s availability. + + Only applicable for Text Classification, Image Object Detection and Tables; all other + domains manage deployment automatically. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLDeployModelOperator` + + :param model_id: Name of the model to be deployed. + :type model_id: str + :param image_detection_metadata: Model deployment metadata specific to Image Object Detection. + If a dict is provided, it must be of the same form as the protobuf message + ImageObjectDetectionModelDeploymentMetadata + :type image_detection_metadata: dict + :param project_id: ID of the Google Cloud project where model is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param params: Additional domain-specific parameters for the predictions. + :type params: Optional[Dict[str, str]] + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + model_id: str, + location: str, + project_id: Optional[str] = None, + image_detection_metadata: Optional[dict] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.model_id = model_id + self.image_detection_metadata = image_detection_metadata + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Deploying model_id %s", self.model_id) + + operation = hook.deploy_model( + model_id=self.model_id, + location=self.location, + project_id=self.project_id, + image_detection_metadata=self.image_detection_metadata, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + operation.result() + self.log.info("Model deployed.") + + +class AutoMLTablesListTableSpecsOperator(BaseOperator): + """ + Lists table specs in a dataset. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLTablesListTableSpecsOperator` + + :param dataset_id: Name of the dataset. + :type dataset_id: str + :param filter_: Filter expression, see go/filtering. + :type filter_: str + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per + resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number + of resources in a page. + :type page_size: int + :param project_id: ID of the Google Cloud project if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "dataset_id", + "filter_", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + dataset_id: str, + location: str, + page_size: Optional[int] = None, + filter_: Optional[str] = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_id = dataset_id + self.filter_ = filter_ + self.page_size = page_size + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Requesting table specs for %s.", self.dataset_id) + page_iterator = hook.list_table_specs( + dataset_id=self.dataset_id, + filter_=self.filter_, + page_size=self.page_size, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = [TableSpec.to_dict(spec) for spec in page_iterator] + self.log.info(result) + self.log.info("Table specs obtained.") + return result + + +class AutoMLListDatasetOperator(BaseOperator): + """ + Lists AutoML Datasets in project. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLListDatasetOperator` + + :param project_id: ID of the Google Cloud project where datasets are located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Requesting datasets") + page_iterator = hook.list_datasets( + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = [Dataset.to_dict(dataset) for dataset in page_iterator] + self.log.info("Datasets obtained.") + + self.xcom_push( + context, + key="dataset_id_list", + value=[hook.extract_object_id(d) for d in result], + ) + return result + + +class AutoMLDeleteDatasetOperator(BaseOperator): + """ + Deletes a dataset and all of its contents. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AutoMLDeleteDatasetOperator` + + :param dataset_id: Name of the dataset_id, list of dataset_id or string of dataset_id + coma separated to be deleted. + :type dataset_id: Union[str, List[str]] + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :type project_id: str + :param location: The location of the project. + :type location: str + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "dataset_id", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + dataset_id: Union[str, List[str]], + location: str, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.dataset_id = dataset_id + self.location = location + self.project_id = project_id + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @staticmethod + def _parse_dataset_id(dataset_id: Union[str, List[str]]) -> List[str]: + if not isinstance(dataset_id, str): + return dataset_id + try: + return ast.literal_eval(dataset_id) + except (SyntaxError, ValueError): + return dataset_id.split(",") + + def execute(self, context): + hook = CloudAutoMLHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + dataset_id_list = self._parse_dataset_id(self.dataset_id) + for dataset_id in dataset_id_list: + self.log.info("Deleting dataset %s", dataset_id) + hook.delete_dataset( + dataset_id=dataset_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Dataset deleted.") diff --git a/reference/providers/google/cloud/operators/bigquery.py b/reference/providers/google/cloud/operators/bigquery.py new file mode 100644 index 0000000..a6804dc --- /dev/null +++ b/reference/providers/google/cloud/operators/bigquery.py @@ -0,0 +1,2265 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
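+
+# Example (illustrative sketch only; the DAG id, dataset and table names below are
+# placeholders rather than anything defined in this module): the operators in this
+# reference file are normally imported from the provider package and wired into a DAG:
+#
+#     from datetime import datetime
+#     from airflow import DAG
+#     from airflow.providers.google.cloud.operators.bigquery import (
+#         BigQueryCheckOperator,
+#         BigQueryGetDataOperator,
+#     )
+#
+#     with DAG("bq_example", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
+#         check_rows = BigQueryCheckOperator(
+#             task_id="check_rows",
+#             sql="SELECT COUNT(*) FROM my_dataset.my_table",
+#             use_legacy_sql=False,
+#         )
+#         fetch_rows = BigQueryGetDataOperator(
+#             task_id="fetch_rows",
+#             dataset_id="my_dataset",
+#             table_id="my_table",
+#             max_results=10,
+#         )
+#         check_rows >> fetch_rows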
+ +# pylint: disable=too-many-lines +"""This module contains Google BigQuery operators.""" +import enum +import hashlib +import json +import re +import uuid +import warnings +from datetime import datetime +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + SupportsAbs, + Union, +) + +import attr +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.models.taskinstance import TaskInstance +from airflow.operators.sql import ( + SQLCheckOperator, + SQLIntervalCheckOperator, + SQLValueCheckOperator, +) +from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob +from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url +from airflow.utils.decorators import apply_defaults +from google.api_core.exceptions import Conflict +from google.cloud.bigquery import TableReference + +BIGQUERY_JOB_DETAILS_LINK_FMT = "https://console.cloud.google.com/bigquery?j={job_id}" + +_DEPRECATION_MSG = "The bigquery_conn_id parameter has been deprecated. You should pass the gcp_conn_id parameter." + + +class BigQueryUIColors(enum.Enum): + """Hex colors for BigQuery operators""" + + CHECK = "#C0D7FF" + QUERY = "#A1BBFF" + TABLE = "#81A0FF" + DATASET = "#5F86FF" + + +class BigQueryConsoleLink(BaseOperatorLink): + """Helper class for constructing BigQuery link.""" + + name = "BigQuery Console" + + def get_link(self, operator, dttm): + ti = TaskInstance(task=operator, execution_date=dttm) + job_id = ti.xcom_pull(task_ids=operator.task_id, key="job_id") + return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else "" + + +@attr.s(auto_attribs=True) +class BigQueryConsoleIndexableLink(BaseOperatorLink): + """Helper class for constructing BigQuery link.""" + + index: int = attr.ib() + + @property + def name(self) -> str: + return f"BigQuery Console #{self.index + 1}" + + def get_link(self, operator: BaseOperator, dttm: datetime): + ti = TaskInstance(task=operator, execution_date=dttm) + job_ids = ti.xcom_pull(task_ids=operator.task_id, key="job_id") + if not job_ids: + return None + if len(job_ids) < self.index: + return None + job_id = job_ids[self.index] + return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) + + +class _BigQueryDbHookMixin: + def get_db_hook(self) -> BigQueryHook: + """Get BigQuery DB Hook""" + return BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + use_legacy_sql=self.use_legacy_sql, + location=self.location, + impersonation_chain=self.impersonation_chain, + labels=self.labels, + ) + + +class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator): + """ + Performs checks against BigQuery. The ``BigQueryCheckOperator`` expects + a sql query that will return a single row. Each value on that + first row is evaluated using python ``bool`` casting. If any of the + values return ``False`` the check is failed and errors out. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryCheckOperator` + + Note that Python bool casting evals the following as ``False``: + + * ``False`` + * ``0`` + * Empty string (``""``) + * Empty list (``[]``) + * Empty dictionary or set (``{}``) + + Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if + the count ``== 0``. 
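+
+    A minimal usage sketch (illustrative only; the dataset, table and task names are
+    placeholders): ::
+
+        check_count = BigQueryCheckOperator(
+            task_id="check_count",
+            sql="SELECT COUNT(*) FROM my_dataset.my_table",
+            use_legacy_sql=False,
+        )
+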
You can craft much more complex query that could, + for instance, check that the table has the same number of rows as + the source table upstream, or that the count of today's partition is + greater than yesterday's partition, or that a set of metrics are less + than 3 standard deviation for the 7 day average. + + This operator can be used as a data quality check in your pipeline, and + depending on where you put it in your DAG, you have the choice to + stop the critical path, preventing from + publishing dubious data, or on the side and receive email alerts + without stopping the progress of the DAG. + + :param sql: the sql to be executed + :type sql: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param use_legacy_sql: Whether to use legacy SQL (true) + or standard SQL (false). + :type use_legacy_sql: bool + :param location: The geographic location of the job. See details at: + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :type location: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + :param labels: a dictionary containing labels for the table, passed to BigQuery + :type labels: dict + """ + + template_fields = ( + "sql", + "gcp_conn_id", + "impersonation_chain", + "labels", + ) + template_ext = (".sql",) + ui_color = BigQueryUIColors.CHECK.value + + @apply_defaults + def __init__( + self, + *, + sql: str, + gcp_conn_id: str = "google_cloud_default", + bigquery_conn_id: Optional[str] = None, + use_legacy_sql: bool = True, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + labels: Optional[dict] = None, + **kwargs, + ) -> None: + super().__init__(sql=sql, **kwargs) + if bigquery_conn_id: + warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.gcp_conn_id = gcp_conn_id + self.sql = sql + self.use_legacy_sql = use_legacy_sql + self.location = location + self.impersonation_chain = impersonation_chain + self.labels = labels + + +class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator): + """ + Performs a simple value check using sql code. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryValueCheckOperator` + + :param sql: the sql to be executed + :type sql: str + :param use_legacy_sql: Whether to use legacy SQL (true) + or standard SQL (false). + :type use_legacy_sql: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. 
+ This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param location: The geographic location of the job. See details at: + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :type location: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + :param labels: a dictionary containing labels for the table, passed to BigQuery + :type labels: dict + """ + + template_fields = ( + "sql", + "gcp_conn_id", + "pass_value", + "impersonation_chain", + "labels", + ) + template_ext = (".sql",) + ui_color = BigQueryUIColors.CHECK.value + + @apply_defaults + def __init__( + self, + *, + sql: str, + pass_value: Any, + tolerance: Any = None, + gcp_conn_id: str = "google_cloud_default", + bigquery_conn_id: Optional[str] = None, + use_legacy_sql: bool = True, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + labels: Optional[dict] = None, + **kwargs, + ) -> None: + super().__init__(sql=sql, pass_value=pass_value, tolerance=tolerance, **kwargs) + + if bigquery_conn_id: + warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.location = location + self.gcp_conn_id = gcp_conn_id + self.use_legacy_sql = use_legacy_sql + self.impersonation_chain = impersonation_chain + self.labels = labels + + +class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperator): + """ + Checks that the values of metrics given as SQL expressions are within + a certain tolerance of the ones from days_back before. + + This method constructs a query like so :: + + SELECT {metrics_threshold_dict_key} FROM {table} + WHERE {date_filter_column}= + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryIntervalCheckOperator` + + :param table: the table name + :type table: str + :param days_back: number of days between ds and the ds we want to check + against. Defaults to 7 days + :type days_back: int + :param metrics_thresholds: a dictionary of ratios indexed by metrics, for + example 'COUNT(*)': 1.5 would require a 50 percent or less difference + between the current day, and the prior days_back. + :type metrics_thresholds: dict + :param use_legacy_sql: Whether to use legacy SQL (true) + or standard SQL (false). + :type use_legacy_sql: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param location: The geographic location of the job. 
See details at: + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :type location: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + :param labels: a dictionary containing labels for the table, passed to BigQuery + :type labels: dict + """ + + template_fields = ( + "table", + "gcp_conn_id", + "sql1", + "sql2", + "impersonation_chain", + "labels", + ) + ui_color = BigQueryUIColors.CHECK.value + + @apply_defaults + def __init__( + self, + *, + table: str, + metrics_thresholds: dict, + date_filter_column: str = "ds", + days_back: SupportsAbs[int] = -7, + gcp_conn_id: str = "google_cloud_default", + bigquery_conn_id: Optional[str] = None, + use_legacy_sql: bool = True, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + labels: Optional[Dict] = None, + **kwargs, + ) -> None: + super().__init__( + table=table, + metrics_thresholds=metrics_thresholds, + date_filter_column=date_filter_column, + days_back=days_back, + **kwargs, + ) + + if bigquery_conn_id: + warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.gcp_conn_id = gcp_conn_id + self.use_legacy_sql = use_legacy_sql + self.location = location + self.impersonation_chain = impersonation_chain + self.labels = labels + + +class BigQueryGetDataOperator(BaseOperator): + """ + Fetches the data from a BigQuery table (alternatively fetch data for selected columns) + and returns data in a python list. The number of elements in the returned list will + be equal to the number of rows fetched. Each element in the list will again be a list + where element would represent the columns values for that row. + + **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]`` + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryGetDataOperator` + + .. note:: + If you pass fields to ``selected_fields`` which are in different order than the + order of columns already in + BQ table, the data will still be in the order of BQ table. + For example if the BQ table has 3 columns as + ``[A,B,C]`` and you pass 'B,A' in the ``selected_fields`` + the data would still be of the form ``'A,B'``. + + **Example**: :: + + get_data = BigQueryGetDataOperator( + task_id='get_data_from_bq', + dataset_id='test_dataset', + table_id='Transaction_partitions', + max_results=100, + selected_fields='DATE', + gcp_conn_id='airflow-conn-id' + ) + + :param dataset_id: The dataset ID of the requested table. (templated) + :type dataset_id: str + :param table_id: The table ID of the requested table. (templated) + :type table_id: str + :param max_results: The maximum number of records (rows) to be fetched + from the table. (templated) + :type max_results: int + :param selected_fields: List of fields to return (comma-separated). If + unspecified, all fields are returned. 
+ :type selected_fields: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param location: The location used for the operation. + :type location: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "dataset_id", + "table_id", + "max_results", + "selected_fields", + "impersonation_chain", + ) + ui_color = BigQueryUIColors.QUERY.value + + @apply_defaults + def __init__( + self, + *, + dataset_id: str, + table_id: str, + max_results: int = 100, + selected_fields: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = bigquery_conn_id + + self.dataset_id = dataset_id + self.table_id = table_id + self.max_results = int(max_results) + self.selected_fields = selected_fields + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.location = location + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> list: + self.log.info( + "Fetching Data from %s.%s max results: %s", + self.dataset_id, + self.table_id, + self.max_results, + ) + + hook = BigQueryHook( + bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + rows = hook.list_rows( + dataset_id=self.dataset_id, + table_id=self.table_id, + max_results=self.max_results, + selected_fields=self.selected_fields, + location=self.location, + ) + + self.log.info("Total extracted rows: %s", len(rows)) + + table_data = [row.values() for row in rows] + return table_data + + +# pylint: disable=too-many-instance-attributes +class BigQueryExecuteQueryOperator(BaseOperator): + """ + Executes BigQuery SQL queries in a specific BigQuery database. + This operator does not assert idempotency. + + :param sql: the sql code to be executed (templated) + :type sql: Can receive a str representing a sql statement, + a list of str (sql statements), or reference to a template file. + Template reference are recognized by str ending in '.sql'. + :param destination_dataset_table: A dotted + ``(.|:).
`` that, if set, will store the results + of the query. (templated) + :type destination_dataset_table: str + :param write_disposition: Specifies the action that occurs if the destination table + already exists. (default: 'WRITE_EMPTY') + :type write_disposition: str + :param create_disposition: Specifies whether the job is allowed to create new tables. + (default: 'CREATE_IF_NEEDED') + :type create_disposition: str + :param allow_large_results: Whether to allow large results. + :type allow_large_results: bool + :param flatten_results: If true and query uses legacy SQL dialect, flattens + all nested and repeated fields in the query results. ``allow_large_results`` + must be ``true`` if this is set to ``false``. For standard SQL queries, this + flag is ignored and results are never flattened. + :type flatten_results: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param udf_config: The User Defined Function configuration for the query. + See https://cloud.google.com/bigquery/user-defined-functions for details. + :type udf_config: list + :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false). + :type use_legacy_sql: bool + :param maximum_billing_tier: Positive integer that serves as a multiplier + of the basic price. + Defaults to None, in which case it uses the value set in the project. + :type maximum_billing_tier: int + :param maximum_bytes_billed: Limits the bytes billed for this job. + Queries that will have bytes billed beyond this limit will fail + (without incurring a charge). If unspecified, this will be + set to your project default. + :type maximum_bytes_billed: float + :param api_resource_configs: a dictionary that contain params + 'configuration' applied for Google BigQuery Jobs API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs + for example, {'query': {'useQueryCache': False}}. You could use it + if you need to provide some params that are not supported by BigQueryOperator + like args. + :type api_resource_configs: dict + :param schema_update_options: Allows the schema of the destination + table to be updated as a side effect of the load job. + :type schema_update_options: Optional[Union[list, tuple, set]] + :param query_params: a list of dictionary containing query parameter types and + values, passed to BigQuery. The structure of dictionary should look like + 'queryParameters' in Google BigQuery Jobs API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs. + For example, [{ 'name': 'corpus', 'parameterType': { 'type': 'STRING' }, + 'parameterValue': { 'value': 'romeoandjuliet' } }]. (templated) + :type query_params: list + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + :param priority: Specifies a priority for the query. + Possible values include INTERACTIVE and BATCH. + The default value is INTERACTIVE. + :type priority: str + :param time_partitioning: configure optional time partitioning fields i.e. 
+ partition by field, type and expiration as per API specifications. + :type time_partitioning: dict + :param cluster_fields: Request that the result of this query be stored sorted + by one or more columns. BigQuery supports clustering for both partitioned and + non-partitioned tables. The order of columns given determines the sort order. + :type cluster_fields: list[str] + :param location: The geographic location of the job. Required except for + US and EU. See details at + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :type location: str + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "sql", + "destination_dataset_table", + "labels", + "query_params", + "impersonation_chain", + ) + template_ext = (".sql",) + ui_color = BigQueryUIColors.QUERY.value + + @property + def operator_extra_links(self): + """Return operator extra links""" + if isinstance(self.sql, str): + return (BigQueryConsoleLink(),) + return (BigQueryConsoleIndexableLink(i) for i, _ in enumerate(self.sql)) + + # pylint: disable=too-many-arguments, too-many-locals + @apply_defaults + def __init__( + self, + *, + sql: Union[str, Iterable], + destination_dataset_table: Optional[str] = None, + write_disposition: str = "WRITE_EMPTY", + allow_large_results: Optional[bool] = False, + flatten_results: Optional[bool] = None, + gcp_conn_id: str = "google_cloud_default", + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + udf_config: Optional[list] = None, + use_legacy_sql: bool = True, + maximum_billing_tier: Optional[int] = None, + maximum_bytes_billed: Optional[float] = None, + create_disposition: str = "CREATE_IF_NEEDED", + schema_update_options: Optional[Union[list, tuple, set]] = None, + query_params: Optional[list] = None, + labels: Optional[dict] = None, + priority: str = "INTERACTIVE", + time_partitioning: Optional[dict] = None, + api_resource_configs: Optional[dict] = None, + cluster_fields: Optional[List[str]] = None, + location: Optional[str] = None, + encryption_configuration: Optional[dict] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + ) + gcp_conn_id = bigquery_conn_id + + warnings.warn( + "This operator is deprecated. 
Please use `BigQueryInsertJobOperator`.", + DeprecationWarning, + ) + + self.sql = sql + self.destination_dataset_table = destination_dataset_table + self.write_disposition = write_disposition + self.create_disposition = create_disposition + self.allow_large_results = allow_large_results + self.flatten_results = flatten_results + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.udf_config = udf_config + self.use_legacy_sql = use_legacy_sql + self.maximum_billing_tier = maximum_billing_tier + self.maximum_bytes_billed = maximum_bytes_billed + self.schema_update_options = schema_update_options + self.query_params = query_params + self.labels = labels + self.priority = priority + self.time_partitioning = time_partitioning + self.api_resource_configs = api_resource_configs + self.cluster_fields = cluster_fields + self.location = location + self.encryption_configuration = encryption_configuration + self.hook = None # type: Optional[BigQueryHook] + self.impersonation_chain = impersonation_chain + + def execute(self, context): + if self.hook is None: + self.log.info("Executing: %s", self.sql) + self.hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + use_legacy_sql=self.use_legacy_sql, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + if isinstance(self.sql, str): + job_id = self.hook.run_query( + sql=self.sql, + destination_dataset_table=self.destination_dataset_table, + write_disposition=self.write_disposition, + allow_large_results=self.allow_large_results, + flatten_results=self.flatten_results, + udf_config=self.udf_config, + maximum_billing_tier=self.maximum_billing_tier, + maximum_bytes_billed=self.maximum_bytes_billed, + create_disposition=self.create_disposition, + query_params=self.query_params, + labels=self.labels, + schema_update_options=self.schema_update_options, + priority=self.priority, + time_partitioning=self.time_partitioning, + api_resource_configs=self.api_resource_configs, + cluster_fields=self.cluster_fields, + encryption_configuration=self.encryption_configuration, + ) + elif isinstance(self.sql, Iterable): + job_id = [ + self.hook.run_query( + sql=s, + destination_dataset_table=self.destination_dataset_table, + write_disposition=self.write_disposition, + allow_large_results=self.allow_large_results, + flatten_results=self.flatten_results, + udf_config=self.udf_config, + maximum_billing_tier=self.maximum_billing_tier, + maximum_bytes_billed=self.maximum_bytes_billed, + create_disposition=self.create_disposition, + query_params=self.query_params, + labels=self.labels, + schema_update_options=self.schema_update_options, + priority=self.priority, + time_partitioning=self.time_partitioning, + api_resource_configs=self.api_resource_configs, + cluster_fields=self.cluster_fields, + encryption_configuration=self.encryption_configuration, + ) + for s in self.sql + ] + else: + raise AirflowException( + f"argument 'sql' of type {type(str)} is neither a string nor an iterable" + ) + context["task_instance"].xcom_push(key="job_id", value=job_id) + + def on_kill(self) -> None: + super().on_kill() + if self.hook is not None: + self.log.info("Cancelling running query") + self.hook.cancel_query() + + +class BigQueryCreateEmptyTableOperator(BaseOperator): + """ + Creates a new, empty table in the specified BigQuery dataset, + optionally with schema. + + The schema to be used for the BigQuery table may be specified in one of + two ways. 
You may either directly pass the schema fields in, or you may + point the operator to a Google Cloud Storage object name. The object in + Google Cloud Storage must be a JSON file with the schema fields in it. + You can also create a table without schema. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryCreateEmptyTableOperator` + + :param project_id: The project to create the table into. (templated) + :type project_id: str + :param dataset_id: The dataset to create the table into. (templated) + :type dataset_id: str + :param table_id: The Name of the table to be created. (templated) + :type table_id: str + :param table_re# Table resource as described in documentation: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#Table + If provided all other parameters are ignored. + :type table_re# Dict[str, Any] + :param schema_fields: If set, the schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema + + **Example**: :: + + schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] + + :type schema_fields: list + :param gcs_schema_object: Full path to the JSON file containing + schema (templated). For + example: ``gs://test-bucket/dir1/dir2/employee_schema.json`` + :type gcs_schema_object: str + :param time_partitioning: configure optional time partitioning fields i.e. + partition by field, type and expiration as per API specifications. + + .. seealso:: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#timePartitioning + :type time_partitioning: dict + :param bigquery_conn_id: [Optional] The connection ID used to connect to Google Cloud and + interact with the Bigquery service. + :type bigquery_conn_id: str + :param google_cloud_storage_conn_id: [Optional] The connection ID used to connect to Google Cloud. + and interact with the Google Cloud Storage service. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param labels: a dictionary containing labels for the table, passed to BigQuery + + **Example (with schema JSON in GCS)**: :: + + CreateTable = BigQueryCreateEmptyTableOperator( + task_id='BigQueryCreateEmptyTableOperator_task', + dataset_id='ODS', + table_id='Employees', + project_id='internal-gcp-project', + gcs_schema_object='gs://schema-bucket/employee_schema.json', + bigquery_conn_id='airflow-conn-id', + google_cloud_storage_conn_id='airflow-conn-id' + ) + + **Corresponding Schema file** (``employee_schema.json``): :: + + [ + { + "mode": "NULLABLE", + "name": "emp_name", + "type": "STRING" + }, + { + "mode": "REQUIRED", + "name": "salary", + "type": "INTEGER" + } + ] + + **Example (with schema in the DAG)**: :: + + CreateTable = BigQueryCreateEmptyTableOperator( + task_id='BigQueryCreateEmptyTableOperator_task', + dataset_id='ODS', + table_id='Employees', + project_id='internal-gcp-project', + schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}], + bigquery_conn_id='airflow-conn-id-account', + google_cloud_storage_conn_id='airflow-conn-id' + ) + :type labels: dict + :param view: [Optional] A dictionary containing definition for the view. 
+        If set, it will create a view instead of a table:
+
+        .. seealso::
+            https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#ViewDefinition
+    :type view: dict
+    :param materialized_view: [Optional] The materialized view definition.
+    :type materialized_view: dict
+    :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).
+        **Example**: ::
+
+            encryption_configuration = {
+                "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key"
+            }
+    :type encryption_configuration: dict
+    :param location: The location used for the operation.
+    :type location: str
+    :param cluster_fields: [Optional] The fields used for clustering.
+        BigQuery supports clustering for both partitioned and
+        non-partitioned tables.
+
+        .. seealso::
+            https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#clustering.fields
+    :type cluster_fields: list
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    :param exists_ok: If ``True``, ignore "already exists" errors when creating the table.
+    :type exists_ok: bool
+    """
+
+    template_fields = (
+        "dataset_id",
+        "table_id",
+        "project_id",
+        "gcs_schema_object",
+        "labels",
+        "view",
+        "materialized_view",
+        "impersonation_chain",
+    )
+    template_fields_renderers = {"table_resource": "json", "materialized_view": "json"}
+    ui_color = BigQueryUIColors.TABLE.value
+
+    # pylint: disable=too-many-arguments
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        dataset_id: str,
+        table_id: str,
+        table_resource: Optional[Dict[str, Any]] = None,
+        project_id: Optional[str] = None,
+        schema_fields: Optional[List] = None,
+        gcs_schema_object: Optional[str] = None,
+        time_partitioning: Optional[Dict] = None,
+        bigquery_conn_id: str = "google_cloud_default",
+        google_cloud_storage_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        labels: Optional[Dict] = None,
+        view: Optional[Dict] = None,
+        materialized_view: Optional[Dict] = None,
+        encryption_configuration: Optional[Dict] = None,
+        location: Optional[str] = None,
+        cluster_fields: Optional[List[str]] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        exists_ok: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.project_id = project_id
+        self.dataset_id = dataset_id
+        self.table_id = table_id
+        self.schema_fields = schema_fields
+        self.gcs_schema_object = gcs_schema_object
+        self.bigquery_conn_id = bigquery_conn_id
+        self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
+        self.delegate_to = delegate_to
+        self.time_partitioning = {} if time_partitioning is None else time_partitioning
+        self.labels = labels
+        self.view = view
+        self.materialized_view = materialized_view
+        self.encryption_configuration = encryption_configuration
+        self.location = location
+        self.cluster_fields = cluster_fields
+        self.table_resource = table_resource
+        self.impersonation_chain = impersonation_chain
+        self.exists_ok = exists_ok
+
+    def execute(self, context) -> None:
+        bq_hook = BigQueryHook(
+            gcp_conn_id=self.bigquery_conn_id,
+            delegate_to=self.delegate_to,
+            location=self.location,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        if not self.schema_fields and self.gcs_schema_object:
+            gcs_bucket, gcs_object = _parse_gcs_url(self.gcs_schema_object)
+            gcs_hook = GCSHook(
+                gcp_conn_id=self.google_cloud_storage_conn_id,
+                delegate_to=self.delegate_to,
+                impersonation_chain=self.impersonation_chain,
+            )
+            schema_fields = json.loads(gcs_hook.download(gcs_bucket, gcs_object))
+        else:
+            schema_fields = self.schema_fields
+
+        try:
+            self.log.info("Creating table")
+            table = bq_hook.create_empty_table(
+                project_id=self.project_id,
+                dataset_id=self.dataset_id,
+                table_id=self.table_id,
+                schema_fields=schema_fields,
+                time_partitioning=self.time_partitioning,
+                cluster_fields=self.cluster_fields,
+                labels=self.labels,
+                view=self.view,
+                materialized_view=self.materialized_view,
+                encryption_configuration=self.encryption_configuration,
+                table_resource=self.table_resource,
+                exists_ok=self.exists_ok,
+            )
+            self.log.info(
+                "Table %s.%s.%s created successfully",
+                table.project,
+                table.dataset_id,
+                table.table_id,
+            )
+        except Conflict:
+            self.log.info("Table %s.%s already exists.", self.dataset_id, self.table_id)
+
+
+# pylint: disable=too-many-instance-attributes
+class BigQueryCreateExternalTableOperator(BaseOperator):
+    """
+    Creates a new external table in the dataset with the data from Google Cloud
+    Storage.
+
+    The schema to be used for the BigQuery table may be specified in one of
+    two ways. You may either directly pass the schema fields in, or you may
+    point the operator to a Google Cloud Storage object name. The object in
+    Google Cloud Storage must be a JSON file with the schema fields in it.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryCreateExternalTableOperator`
+
+    :param bucket: The bucket to point the external table to. (templated)
+    :type bucket: str
+    :param source_objects: List of Google Cloud Storage URIs to point
+        table to. If source_format is 'DATASTORE_BACKUP', the list must only contain a single URI.
+    :type source_objects: list
+    :param destination_project_dataset_table: The dotted ``(<project>.)<dataset>.<table>``
+        BigQuery table to load data into (templated). If ``<project>`` is not
+        included, project will be the project defined in the connection json.
+    :type destination_project_dataset_table: str
+    :param schema_fields: If set, the schema field list as defined here:
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema
+
+        **Example**: ::
+
+            schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
+                           {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}]
+
+        Should not be set when source_format is 'DATASTORE_BACKUP'.
+    :param table_resource: Table resource as described in documentation:
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#Table
+        If provided all other parameters are ignored. External schema from object will be resolved.
+    :type table_resource: Dict[str, Any]
+    :type schema_fields: list
+    :param schema_object: If set, a GCS object path pointing to a .json file that
+        contains the schema for the table. (templated)
+    :type schema_object: str
+    :param source_format: File format of the data.
+    :type source_format: str
+    :param compression: [Optional] The compression type of the data source.
+        Possible values include GZIP and NONE.
+        The default value is NONE.
+        This setting is ignored for Google Cloud Bigtable,
+        Google Cloud Datastore backups and Avro formats.
+    :type compression: str
+    :param skip_leading_rows: Number of rows to skip when loading from a CSV.
+    :type skip_leading_rows: int
+    :param field_delimiter: The delimiter to use for the CSV.
+    :type field_delimiter: str
+    :param max_bad_records: The maximum number of bad records that BigQuery can
+        ignore when running the job.
+    :type max_bad_records: int
+    :param quote_character: The value that is used to quote data sections in a CSV file.
+    :type quote_character: str
+    :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not (false).
+    :type allow_quoted_newlines: bool
+    :param allow_jagged_rows: Accept rows that are missing trailing optional columns.
+        The missing values are treated as nulls. If false, records with missing trailing
+        columns are treated as bad records, and if there are too many bad records, an
+        invalid error is returned in the job result. Only applicable to CSV, ignored
+        for other formats.
+    :type allow_jagged_rows: bool
+    :param bigquery_conn_id: (Optional) The connection ID used to connect to Google Cloud and
+        interact with the Bigquery service.
+    :type bigquery_conn_id: str
+    :param google_cloud_storage_conn_id: (Optional) The connection ID used to connect to Google Cloud
+        and interact with the Google Cloud Storage service.
+    :type google_cloud_storage_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param src_fmt_configs: configure optional fields specific to the source format
+    :type src_fmt_configs: dict
+    :param labels: a dictionary containing labels for the table, passed to BigQuery
+    :type labels: dict
+    :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).
+        **Example**: ::
+
+            encryption_configuration = {
+                "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key"
+            }
+    :type encryption_configuration: dict
+    :param location: The location used for the operation.
+    :type location: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        "bucket",
+        "source_objects",
+        "schema_object",
+        "destination_project_dataset_table",
+        "labels",
+        "table_resource",
+        "impersonation_chain",
+    )
+    template_fields_renderers = {"table_resource": "json"}
+    ui_color = BigQueryUIColors.TABLE.value
+
+    # pylint: disable=too-many-arguments,too-many-locals
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        bucket: str,
+        source_objects: List,
+        destination_project_dataset_table: str,
+        table_resource: Optional[Dict[str, Any]] = None,
+        schema_fields: Optional[List] = None,
+        schema_object: Optional[str] = None,
+        source_format: str = "CSV",
+        compression: str = "NONE",
+        skip_leading_rows: int = 0,
+        field_delimiter: str = ",",
+        max_bad_records: int = 0,
+        quote_character: Optional[str] = None,
+        allow_quoted_newlines: bool = False,
+        allow_jagged_rows: bool = False,
+        bigquery_conn_id: str = "google_cloud_default",
+        google_cloud_storage_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        src_fmt_configs: Optional[dict] = None,
+        labels: Optional[Dict] = None,
+        encryption_configuration: Optional[Dict] = None,
+        location: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        # GCS config
+        self.bucket = bucket
+        self.source_objects = source_objects
+        self.schema_object = schema_object
+
+        # BQ config
+        kwargs_passed = any(
+            [
+                destination_project_dataset_table,
+                schema_fields,
+                source_format,
+                compression,
+                skip_leading_rows,
+                field_delimiter,
+                max_bad_records,
+                quote_character,
+                allow_quoted_newlines,
+                allow_jagged_rows,
+                src_fmt_configs,
+                labels,
+                encryption_configuration,
+            ]
+        )
+
+        if not table_resource:
+            warnings.warn(
+                "Passing table parameters via keyword arguments will be deprecated. "
+                "Please provide the table definition using the `table_resource` parameter. "
+                "You can still use an external `schema_object`. ",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        if table_resource and kwargs_passed:
+            raise ValueError(
+                "You provided both `table_resource` and exclusive keyword arguments."
+ ) + + self.table_resource = table_resource + self.destination_project_dataset_table = destination_project_dataset_table + self.schema_fields = schema_fields + self.source_format = source_format + self.compression = compression + self.skip_leading_rows = skip_leading_rows + self.field_delimiter = field_delimiter + self.max_bad_records = max_bad_records + self.quote_character = quote_character + self.allow_quoted_newlines = allow_quoted_newlines + self.allow_jagged_rows = allow_jagged_rows + + self.bigquery_conn_id = bigquery_conn_id + self.google_cloud_storage_conn_id = google_cloud_storage_conn_id + self.delegate_to = delegate_to + + self.src_fmt_configs = src_fmt_configs or {} + self.labels = labels + self.encryption_configuration = encryption_configuration + self.location = location + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + bq_hook = BigQueryHook( + gcp_conn_id=self.bigquery_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + if ( + not self.schema_fields + and self.schema_object + and self.source_format != "DATASTORE_BACKUP" + ): + gcs_hook = GCSHook( + gcp_conn_id=self.google_cloud_storage_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + schema_fields = json.loads( + gcs_hook.download(self.bucket, self.schema_object) + ) + else: + schema_fields = self.schema_fields + + if schema_fields and self.table_re# + self.table_resource["externalDataConfiguration"]["schema"] = schema_fields + + if self.table_re# + tab_ref = TableReference.from_string(self.destination_project_dataset_table) + bq_hook.create_empty_table( + table_resource=self.table_resource, + project_id=tab_ref.project, + table_id=tab_ref.table_id, + dataset_id=tab_ref.dataset_id, + ) + else: + source_uris = [ + f"gs://{self.bucket}/{source_object}" + for source_object in self.source_objects + ] + + bq_hook.create_external_table( + external_project_dataset_table=self.destination_project_dataset_table, + schema_fields=schema_fields, + source_uris=source_uris, + source_format=self.source_format, + compression=self.compression, + skip_leading_rows=self.skip_leading_rows, + field_delimiter=self.field_delimiter, + max_bad_records=self.max_bad_records, + quote_character=self.quote_character, + allow_quoted_newlines=self.allow_quoted_newlines, + allow_jagged_rows=self.allow_jagged_rows, + src_fmt_configs=self.src_fmt_configs, + labels=self.labels, + encryption_configuration=self.encryption_configuration, + ) + + +class BigQueryDeleteDatasetOperator(BaseOperator): + """ + This operator deletes an existing dataset from your Project in Big query. + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/delete + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryDeleteDatasetOperator` + + :param project_id: The project id of the dataset. + :type project_id: str + :param dataset_id: The dataset to be deleted. + :type dataset_id: str + :param delete_contents: (Optional) Whether to force the deletion even if the dataset is not empty. + Will delete all tables (if any) in the dataset if set to True. + Will raise HttpError 400: "{dataset_id} is still in use" if set to False and dataset is not empty. + The default value is False. + :type delete_contents: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. 
+ :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + **Example**: :: + + delete_temp_data = BigQueryDeleteDatasetOperator( + dataset_id='temp-dataset', + project_id='temp-project', + delete_contents=True, # Force the deletion of the dataset as well as its tables (if any). + gcp_conn_id='_my_gcp_conn_', + task_id='Deletetemp', + dag=dag) + """ + + template_fields = ( + "dataset_id", + "project_id", + "impersonation_chain", + ) + ui_color = BigQueryUIColors.DATASET.value + + @apply_defaults + def __init__( + self, + *, + dataset_id: str, + project_id: Optional[str] = None, + delete_contents: bool = False, + gcp_conn_id: str = "google_cloud_default", + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = bigquery_conn_id + + self.dataset_id = dataset_id + self.project_id = project_id + self.delete_contents = delete_contents + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + super().__init__(**kwargs) + + def execute(self, context) -> None: + self.log.info("Dataset id: %s Project id: %s", self.dataset_id, self.project_id) + + bq_hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + bq_hook.delete_dataset( + project_id=self.project_id, + dataset_id=self.dataset_id, + delete_contents=self.delete_contents, + ) + + +class BigQueryCreateEmptyDatasetOperator(BaseOperator): + """ + This operator is used to create new dataset for your Project in BigQuery. + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryCreateEmptyDatasetOperator` + + :param project_id: The name of the project where we want to create the dataset. + :type project_id: str + :param dataset_id: The id of dataset. Don't need to provide, if datasetId in dataset_reference. + :type dataset_id: str + :param location: The geographic location where the dataset should reside. + :type location: str + :param dataset_reference: Dataset reference that could be provided with request body. 
+        More info:
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource
+    :type dataset_reference: dict
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
+        This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
+    :type bigquery_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    :param exists_ok: If ``True``, ignore "already exists" errors when creating the dataset.
+    :type exists_ok: bool
+    **Example**: ::
+
+        create_new_dataset = BigQueryCreateEmptyDatasetOperator(
+            dataset_id='new-dataset',
+            project_id='my-project',
+            dataset_reference={"friendlyName": "New Dataset"},
+            gcp_conn_id='_my_gcp_conn_',
+            task_id='newDatasetCreator',
+            dag=dag)
+    """
+
+    template_fields = (
+        "dataset_id",
+        "project_id",
+        "dataset_reference",
+        "impersonation_chain",
+    )
+    template_fields_renderers = {"dataset_reference": "json"}
+    ui_color = BigQueryUIColors.DATASET.value
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        dataset_id: Optional[str] = None,
+        project_id: Optional[str] = None,
+        dataset_reference: Optional[Dict] = None,
+        location: Optional[str] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        bigquery_conn_id: Optional[str] = None,
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        exists_ok: bool = False,
+        **kwargs,
+    ) -> None:
+
+        if bigquery_conn_id:
+            warnings.warn(
+                "The bigquery_conn_id parameter has been deprecated. 
You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = bigquery_conn_id + + self.dataset_id = dataset_id + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.dataset_reference = dataset_reference if dataset_reference else {} + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.exists_ok = exists_ok + + super().__init__(**kwargs) + + def execute(self, context) -> None: + bq_hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + try: + bq_hook.create_empty_dataset( + project_id=self.project_id, + dataset_id=self.dataset_id, + dataset_reference=self.dataset_reference, + location=self.location, + exists_ok=self.exists_ok, + ) + except Conflict: + dataset_id = self.dataset_reference.get("datasetReference", {}).get( + "datasetId", self.dataset_id + ) + self.log.info("Dataset %s already exists.", dataset_id) + + +class BigQueryGetDatasetOperator(BaseOperator): + """ + This operator is used to return the dataset specified by dataset_id. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryGetDatasetOperator` + + :param dataset_id: The id of dataset. Don't need to provide, + if datasetId in dataset_reference. + :type dataset_id: str + :param project_id: The name of the project where we want to create the dataset. + Don't need to provide, if projectId in dataset_reference. + :type project_id: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
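+
+        **Example** (illustrative service account names): ::
+
+            impersonation_chain="transfer-sa@my-project.iam.gserviceaccount.com"
+            # or, as a chained list of accounts:
+            impersonation_chain=[
+                "delegating-sa@my-project.iam.gserviceaccount.com",
+                "target-sa@my-project.iam.gserviceaccount.com",
+            ]
+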
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: dataset + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + """ + + template_fields = ( + "dataset_id", + "project_id", + "impersonation_chain", + ) + ui_color = BigQueryUIColors.DATASET.value + + @apply_defaults + def __init__( + self, + *, + dataset_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + self.dataset_id = dataset_id + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def execute(self, context): + bq_hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Start getting dataset: %s:%s", self.project_id, self.dataset_id) + + dataset = bq_hook.get_dataset( + dataset_id=self.dataset_id, project_id=self.project_id + ) + return dataset.to_api_repr() + + +class BigQueryGetDatasetTablesOperator(BaseOperator): + """ + This operator retrieves the list of tables in the specified dataset. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryGetDatasetTablesOperator` + + :param dataset_id: the dataset ID of the requested dataset. + :type dataset_id: str + :param project_id: (Optional) the project of the requested dataset. If None, + self.project_id will be used. + :type project_id: str + :param max_results: (Optional) the maximum number of tables to return. + :type max_results: int + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        "dataset_id",
+        "project_id",
+        "impersonation_chain",
+    )
+    ui_color = BigQueryUIColors.DATASET.value
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        dataset_id: str,
+        project_id: Optional[str] = None,
+        max_results: Optional[int] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        self.dataset_id = dataset_id
+        self.project_id = project_id
+        self.max_results = max_results
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        super().__init__(**kwargs)
+
+    def execute(self, context):
+        bq_hook = BigQueryHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        return bq_hook.get_dataset_tables(
+            dataset_id=self.dataset_id,
+            project_id=self.project_id,
+            max_results=self.max_results,
+        )
+
+
+class BigQueryPatchDatasetOperator(BaseOperator):
+    """
+    This operator is used to patch dataset for your Project in BigQuery.
+    It only replaces fields that are provided in the submitted dataset resource.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryPatchDatasetOperator`
+
+    :param dataset_id: The id of dataset. Don't need to provide,
+        if datasetId in dataset_reference.
+    :type dataset_id: str
+    :param dataset_resource: Dataset resource that will be provided with request body.
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource
+    :type dataset_resource: dict
+    :param project_id: The name of the project where we want to create the dataset.
+        Don't need to provide, if projectId in dataset_reference.
+    :type project_id: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: dataset
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource
+    """
+
+    template_fields = (
+        "dataset_id",
+        "project_id",
+        "impersonation_chain",
+    )
+    template_fields_renderers = {"dataset_resource": "json"}
+    ui_color = BigQueryUIColors.DATASET.value
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        dataset_id: str,
+        dataset_resource: dict,
+        project_id: Optional[str] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+
+        warnings.warn(
+            "This operator is deprecated. Please use BigQueryUpdateDatasetOperator.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        self.dataset_id = dataset_id
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.dataset_resource = dataset_resource
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        super().__init__(**kwargs)
+
+    def execute(self, context):
+        bq_hook = BigQueryHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        return bq_hook.patch_dataset(
+            dataset_id=self.dataset_id,
+            dataset_resource=self.dataset_resource,
+            project_id=self.project_id,
+        )
+
+
+class BigQueryUpdateTableOperator(BaseOperator):
+    """
+    This operator is used to update table for your Project in BigQuery.
+    Use ``fields`` to specify which fields of table to update. If a field
+    is listed in ``fields`` and is ``None`` in table, it will be deleted.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryUpdateTableOperator`
+
+    :param dataset_id: The id of dataset. Don't need to provide,
+        if datasetId in table_reference.
+    :type dataset_id: str
+    :param table_id: The id of table. Don't need to provide,
+        if tableId in table_reference.
+    :type table_id: str
+    :param table_resource: Table resource that will be provided with request body.
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#resource
+    :type table_resource: Dict[str, Any]
+    :param fields: The fields of ``table`` to change, spelled as the Table
+        properties (e.g. "friendly_name").
+    :type fields: List[str]
+    :param project_id: The name of the project where we want to create the table.
+        Don't need to provide, if projectId in table_reference.
+    :type project_id: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: table
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#resource
+    """
+
+    template_fields = (
+        "dataset_id",
+        "table_id",
+        "project_id",
+        "impersonation_chain",
+    )
+    template_fields_renderers = {"table_resource": "json"}
+    ui_color = BigQueryUIColors.TABLE.value
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        table_resource: Dict[str, Any],
+        fields: Optional[List[str]] = None,
+        dataset_id: Optional[str] = None,
+        table_id: Optional[str] = None,
+        project_id: Optional[str] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        self.dataset_id = dataset_id
+        self.table_id = table_id
+        self.project_id = project_id
+        self.fields = fields
+        self.gcp_conn_id = gcp_conn_id
+        self.table_resource = table_resource
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        super().__init__(**kwargs)
+
+    def execute(self, context):
+        bq_hook = BigQueryHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        return bq_hook.update_table(
+            table_resource=self.table_resource,
+            fields=self.fields,
+            dataset_id=self.dataset_id,
+            table_id=self.table_id,
+            project_id=self.project_id,
+        )
+
+
+class BigQueryUpdateDatasetOperator(BaseOperator):
+    """
+    This operator is used to update dataset for your Project in BigQuery.
+    Use ``fields`` to specify which fields of dataset to update. If a field
+    is listed in ``fields`` and is ``None`` in dataset, it will be deleted.
+    If no ``fields`` are provided then all fields of provided ``dataset_resource``
+    will be used.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryUpdateDatasetOperator`
+
+    :param dataset_id: The id of dataset. Don't need to provide,
+        if datasetId in dataset_reference.
+    :type dataset_id: str
+    :param dataset_resource: Dataset resource that will be provided with request body.
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource
+    :type dataset_resource: Dict[str, Any]
+    :param fields: The properties of dataset to change (e.g. "friendly_name").
+    :type fields: Sequence[str]
+    :param project_id: The name of the project where we want to create the dataset.
+        Don't need to provide, if projectId in dataset_reference.
+    :type project_id: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: dataset
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource
+    """
+
+    template_fields = (
+        "dataset_id",
+        "project_id",
+        "impersonation_chain",
+    )
+    template_fields_renderers = {"dataset_resource": "json"}
+    ui_color = BigQueryUIColors.DATASET.value
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        dataset_resource: Dict[str, Any],
+        fields: Optional[List[str]] = None,
+        dataset_id: Optional[str] = None,
+        project_id: Optional[str] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        self.dataset_id = dataset_id
+        self.project_id = project_id
+        self.fields = fields
+        self.gcp_conn_id = gcp_conn_id
+        self.dataset_resource = dataset_resource
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        super().__init__(**kwargs)
+
+    def execute(self, context):
+        bq_hook = BigQueryHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            impersonation_chain=self.impersonation_chain,
+        )
+        fields = self.fields or list(self.dataset_resource.keys())
+
+        dataset = bq_hook.update_dataset(
+            dataset_resource=self.dataset_resource,
+            project_id=self.project_id,
+            dataset_id=self.dataset_id,
+            fields=fields,
+        )
+        return dataset.to_api_repr()
+
+
+class BigQueryDeleteTableOperator(BaseOperator):
+    """
+    Deletes BigQuery tables
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryDeleteTableOperator`
+
+    :param deletion_dataset_table: A dotted
+        ``(<project>.|<project>:)<dataset>.<table>`` that indicates which table
+        will be deleted. (templated)
+    :type deletion_dataset_table: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
+        This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
+    :type bigquery_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param ignore_if_missing: if True, then return success even if the
+        requested table does not exist.
+    :type ignore_if_missing: bool
+    :param location: The location used for the operation.
+    :type location: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        "deletion_dataset_table",
+        "impersonation_chain",
+    )
+    ui_color = BigQueryUIColors.TABLE.value
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        deletion_dataset_table: str,
+        gcp_conn_id: str = "google_cloud_default",
+        bigquery_conn_id: Optional[str] = None,
+        delegate_to: Optional[str] = None,
+        ignore_if_missing: bool = False,
+        location: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        if bigquery_conn_id:
+            warnings.warn(
+                "The bigquery_conn_id parameter has been deprecated. You should pass "
+                "the gcp_conn_id parameter.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            gcp_conn_id = bigquery_conn_id
+
+        self.deletion_dataset_table = deletion_dataset_table
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.ignore_if_missing = ignore_if_missing
+        self.location = location
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context) -> None:
+        self.log.info("Deleting: %s", self.deletion_dataset_table)
+        hook = BigQueryHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            location=self.location,
+            impersonation_chain=self.impersonation_chain,
+        )
+        hook.delete_table(
+            table_id=self.deletion_dataset_table, not_found_ok=self.ignore_if_missing
+        )
+
+
+class BigQueryUpsertTableOperator(BaseOperator):
+    """
+    Upsert BigQuery table
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryUpsertTableOperator`
+
+    :param dataset_id: A dotted
+        ``(<project>.|<project>:)<dataset>`` that indicates which dataset
+        will be updated. (templated)
+    :type dataset_id: str
+    :param table_resource: a table resource. see
+        https://cloud.google.com/bigquery/docs/reference/v2/tables#resource
+    :type table_resource: dict
+    :param project_id: The name of the project where we want to update the dataset.
+        Don't need to provide, if projectId in dataset_reference.
+    :type project_id: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
+        This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
+    :type bigquery_conn_id: str
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have domain-wide
+        delegation enabled.
+    :type delegate_to: str
+    :param location: The location used for the operation.
+    :type location: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        "dataset_id",
+        "table_resource",
+        "impersonation_chain",
+    )
+    template_fields_renderers = {"table_resource": "json"}
+    ui_color = BigQueryUIColors.TABLE.value
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        dataset_id: str,
+        table_resource: dict,
+        project_id: Optional[str] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        bigquery_conn_id: Optional[str] = None,
+        delegate_to: Optional[str] = None,
+        location: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        if bigquery_conn_id:
+            warnings.warn(
+                "The bigquery_conn_id parameter has been deprecated. You should pass "
+                "the gcp_conn_id parameter.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            gcp_conn_id = bigquery_conn_id
+
+        self.dataset_id = dataset_id
+        self.table_resource = table_resource
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.location = location
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context) -> None:
+        self.log.info(
+            "Upserting Dataset: %s with table_resource: %s",
+            self.dataset_id,
+            self.table_resource,
+        )
+        hook = BigQueryHook(
+            bigquery_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            location=self.location,
+            impersonation_chain=self.impersonation_chain,
+        )
+        hook.run_table_upsert(
+            dataset_id=self.dataset_id,
+            table_resource=self.table_resource,
+            project_id=self.project_id,
+        )
+
+
+# pylint: disable=too-many-arguments
+class BigQueryInsertJobOperator(BaseOperator):
+    """
+    Executes a BigQuery job. Waits for the job to complete and returns job id.
+    This operator works in the following way:
+
+    - it calculates a unique hash of the job using job's configuration or uuid if ``force_rerun`` is True
+    - creates ``job_id`` in form of
+      ``[provided_job_id | airflow_{dag_id}_{task_id}_{exec_date}]_{uniqueness_suffix}``
+    - submits a BigQuery job using the ``job_id``
+    - if job with given id already exists then it tries to reattach to the job if it's not done and its
+      state is in ``reattach_states``. If the job is done the operator will raise ``AirflowException``.
+
+    Using ``force_rerun`` will submit a new job every time without attaching to already existing ones.
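+
+    For example, a minimal ``configuration`` for a standard SQL query job could look
+    like the following (project, dataset and table names are illustrative): ::
+
+        configuration = {
+            "query": {
+                "query": "SELECT COUNT(*) FROM `my-project.my_dataset.my_table`",
+                "useLegacySql": False,
+            }
+        }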
+ + For job definition see here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryInsertJobOperator` + + + :param configuration: The configuration parameter maps directly to BigQuery's + configuration field in the job object. For more details see + https://cloud.google.com/bigquery/docs/reference/v2/jobs + :type configuration: Dict[str, Any] + :param job_id: The ID of the job. It will be suffixed with hash of job configuration + unless ``force_rerun`` is True. + The ID must contain only letters (a-z, A-Z), numbers (0-9), underscores (_), or + dashes (-). The maximum length is 1,024 characters. If not provided then uuid will + be generated. + :type job_id: str + :param force_rerun: If True then operator will use hash of uuid as job id suffix + :type force_rerun: bool + :param reattach_states: Set of BigQuery job's states in case of which we should reattach + to the job. Should be other than final states. + :param project_id: Google Cloud Project where the job is running + :type project_id: str + :param location: location the job is running + :type location: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + :param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called + :type cancel_on_kill: bool + """ + + template_fields = ( + "configuration", + "job_id", + "impersonation_chain", + ) + template_ext = (".json",) + template_fields_renderers = {"configuration": "json"} + ui_color = BigQueryUIColors.QUERY.value + + @apply_defaults + def __init__( + self, + configuration: Dict[str, Any], + project_id: Optional[str] = None, + location: Optional[str] = None, + job_id: Optional[str] = None, + force_rerun: bool = True, + reattach_states: Optional[Set[str]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + cancel_on_kill: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.configuration = configuration + self.location = location + self.job_id = job_id + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.force_rerun = force_rerun + self.reattach_states: Set[str] = reattach_states or set() + self.impersonation_chain = impersonation_chain + self.cancel_on_kill = cancel_on_kill + self.hook: Optional[BigQueryHook] = None + + def prepare_template(self) -> None: + # If .json is passed then we have to read the file + if isinstance(self.configuration, str) and self.configuration.endswith(".json"): + with open(self.configuration) as file: + self.configuration = json.loads(file.read()) + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + # Submit a new job + job = hook.insert_job( + configuration=self.configuration, + project_id=self.project_id, + location=self.location, + job_id=job_id, + ) + # Start the job and wait for it to complete and get the result. + job.result() + return job + + @staticmethod + def _handle_job_error(job: BigQueryJob) -> None: + if job.error_result: + raise AirflowException( + f"BigQuery job {job.job_id} failed: {job.error_result}" + ) + + def _job_id(self, context): + if self.force_rerun: + hash_base = str(uuid.uuid4()) + else: + hash_base = json.dumps(self.configuration, sort_keys=True) + + uniqueness_suffix = hashlib.md5(hash_base.encode()).hexdigest() + + if self.job_id: + return f"{self.job_id}_{uniqueness_suffix}" + + exec_date = context["execution_date"].isoformat() + job_id = f"airflow_{self.dag_id}_{self.task_id}_{exec_date}_{uniqueness_suffix}" + return re.sub(r"[:\-+.]", "_", job_id) + + def execute(self, context: Any): + hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + self.hook = hook + + job_id = self._job_id(context) + + try: + job = self._submit_job(hook, job_id) + self._handle_job_error(job) + except Conflict: + # If the job already exists retrieve it + job = hook.get_job( + project_id=self.project_id, + location=self.location, + job_id=job_id, + ) + if job.state in self.reattach_states: + # We are reattaching to a job + job.result() + self._handle_job_error(job) + else: + # Same job configuration so we need force_rerun + raise AirflowException( + f"Job with id: {job_id} already exists and is in {job.state} state. If you " + f"want to force rerun it consider setting `force_rerun=True`." 
+ f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" + ) + + self.job_id = job.job_id + return job.job_id + + def on_kill(self) -> None: + if self.job_id and self.cancel_on_kill: + self.hook.cancel_job( # type: ignore[union-attr] + job_id=self.job_id, project_id=self.project_id, location=self.location + ) diff --git a/reference/providers/google/cloud/operators/bigquery_dts.py b/reference/providers/google/cloud/operators/bigquery_dts.py new file mode 100644 index 0000000..1ee91c5 --- /dev/null +++ b/reference/providers/google/cloud/operators/bigquery_dts.py @@ -0,0 +1,301 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google BigQuery Data Transfer Service operators.""" +from typing import Optional, Sequence, Tuple, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.bigquery_dts import ( + BiqQueryDataTransferServiceHook, + get_object_id, +) +from airflow.utils.decorators import apply_defaults +from google.api_core.retry import Retry +from google.cloud.bigquery_datatransfer_v1 import ( + StartManualTransferRunsResponse, + TransferConfig, +) + + +class BigQueryCreateDataTransferOperator(BaseOperator): + """ + Creates a new data transfer configuration. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryCreateDataTransferOperator` + + :param transfer_config: Data transfer configuration to create. + :type transfer_config: dict + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the Google Cloud connection + is used. + :type project_id: str + :param authorization_code: authorization code to use with this transfer configuration. + This is required if new credentials are needed. + :type authorization_code: Optional[str] + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "transfer_config", + "project_id", + "authorization_code", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + transfer_config: dict, + project_id: Optional[str] = None, + authorization_code: Optional[str] = None, + retry: Retry = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id="google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.transfer_config = transfer_config + self.authorization_code = authorization_code + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = BiqQueryDataTransferServiceHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Creating DTS transfer config") + response = hook.create_transfer_config( + project_id=self.project_id, + transfer_config=self.transfer_config, + authorization_code=self.authorization_code, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = TransferConfig.to_dict(response) + self.log.info("Created DTS transfer config %s", get_object_id(result)) + self.xcom_push(context, key="transfer_config_id", value=get_object_id(result)) + return result + + +class BigQueryDeleteDataTransferConfigOperator(BaseOperator): + """ + Deletes transfer configuration. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryDeleteDataTransferConfigOperator` + + :param transfer_config_id: Id of transfer config to be used. + :type transfer_config_id: str + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "transfer_config_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + transfer_config_id: str, + project_id: Optional[str] = None, + retry: Retry = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id="google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.transfer_config_id = transfer_config_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = BiqQueryDataTransferServiceHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.delete_transfer_config( + transfer_config_id=self.transfer_config_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator): + """ + Start manual transfer runs to be executed now with schedule_time equal + to current time. The transfer runs can be created for a time range where + the run_time is between start_time (inclusive) and end_time + (exclusive), or for a specific run_time. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryDataTransferServiceStartTransferRunsOperator` + + :param transfer_config_id: Id of transfer config to be used. + :type transfer_config_id: str + :param requested_time_range: Time range for the transfer runs that should be started. + If a dict is provided, it must be of the same form as the protobuf + message `~google.cloud.bigquery_datatransfer_v1.types.TimeRange` + :type requested_time_range: Union[dict, ~google.cloud.bigquery_datatransfer_v1.types.TimeRange] + :param requested_run_time: Specific run_time for a transfer run to be started. The + requested_run_time must not be in the future. If a dict is provided, it + must be of the same form as the protobuf message + `~google.cloud.bigquery_datatransfer_v1.types.Timestamp` + :type requested_run_time: Union[dict, ~google.cloud.bigquery_datatransfer_v1.types.Timestamp] + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID used to connect to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "transfer_config_id", + "project_id", + "requested_time_range", + "requested_run_time", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + transfer_config_id: str, + project_id: Optional[str] = None, + requested_time_range: Optional[dict] = None, + requested_run_time: Optional[dict] = None, + retry: Retry = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id="google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.transfer_config_id = transfer_config_id + self.requested_time_range = requested_time_range + self.requested_run_time = requested_run_time + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = BiqQueryDataTransferServiceHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Submitting manual transfer for %s", self.transfer_config_id) + response = hook.start_manual_transfer_runs( + transfer_config_id=self.transfer_config_id, + requested_time_range=self.requested_time_range, + requested_run_time=self.requested_run_time, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = StartManualTransferRunsResponse.to_dict(response) + run_id = get_object_id(result["runs"][0]) + self.xcom_push(context, key="run_id", value=run_id) + self.log.info("Transfer run %s submitted successfully.", run_id) + return result diff --git a/reference/providers/google/cloud/operators/bigtable.py b/reference/providers/google/cloud/operators/bigtable.py new file mode 100644 index 0000000..0955cfe --- /dev/null +++ b/reference/providers/google/cloud/operators/bigtable.py @@ -0,0 +1,670 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Google Cloud Bigtable operators.""" +import enum +from typing import Dict, Iterable, List, Optional, Sequence, Union + +import google.api_core.exceptions +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.bigtable import BigtableHook +from airflow.utils.decorators import apply_defaults +from google.cloud.bigtable.column_family import GarbageCollectionRule +from google.cloud.bigtable_admin_v2 import enums + + +class BigtableValidationMixin: + """Common class for Cloud Bigtable operators for validating required fields.""" + + REQUIRED_ATTRIBUTES = [] # type: Iterable[str] + + def _validate_inputs(self): + for attr_name in self.REQUIRED_ATTRIBUTES: + if not getattr(self, attr_name): + raise AirflowException(f"Empty parameter: {attr_name}") + + +class BigtableCreateInstanceOperator(BaseOperator, BigtableValidationMixin): + """ + Creates a new Cloud Bigtable instance. + If the Cloud Bigtable instance with the given ID exists, the operator does not + compare its configuration + and immediately succeeds. No changes are made to the existing instance. + + For more details about instance creation have a look at the reference: + https://googleapis.github.io/google-cloud-python/latest/bigtable/instance.html#google.cloud.bigtable.instance.Instance.create + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigtableCreateInstanceOperator` + + :type instance_id: str + :param instance_id: The ID of the Cloud Bigtable instance to create. + :type main_cluster_id: str + :param main_cluster_id: The ID for main cluster for the new instance. + :type main_cluster_zone: str + :param main_cluster_zone: The zone for main cluster + See https://cloud.google.com/bigtable/docs/locations for more details. + :type project_id: str + :param project_id: Optional, the ID of the Google Cloud project. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type replica_clusters: List[Dict[str, str]] + :param replica_clusters: (optional) A list of replica clusters for the new + instance. Each cluster dictionary contains an id and a zone. + Example: [{"id": "replica-1", "zone": "us-west1-a"}] + :type replica_cluster_id: str + :param replica_cluster_id: (deprecated) The ID for replica cluster for the new + instance. + :type replica_cluster_zone: str + :param replica_cluster_zone: (deprecated) The zone for replica cluster. + :type instance_type: enum.IntEnum + :param instance_type: (optional) The type of the instance. + :type instance_display_name: str + :param instance_display_name: (optional) Human-readable name of the instance. Defaults + to ``instance_id``. + :type instance_labels: dict + :param instance_labels: (optional) Dictionary of labels to associate + with the instance. + :type cluster_nodes: int + :param cluster_nodes: (optional) Number of nodes for cluster. + :type cluster_storage_type: enum.IntEnum + :param cluster_storage_type: (optional) The type of storage. + :type timeout: int + :param timeout: (optional) timeout (in seconds) for instance creation. + If None is not specified, Operator will wait indefinitely. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + REQUIRED_ATTRIBUTES: Iterable[str] = ( + "instance_id", + "main_cluster_id", + "main_cluster_zone", + ) + template_fields: Iterable[str] = [ + "project_id", + "instance_id", + "main_cluster_id", + "main_cluster_zone", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, # pylint: disable=too-many-arguments + instance_id: str, + main_cluster_id: str, + main_cluster_zone: str, + project_id: Optional[str] = None, + replica_clusters: Optional[List[Dict[str, str]]] = None, + replica_cluster_id: Optional[str] = None, + replica_cluster_zone: Optional[str] = None, + instance_display_name: Optional[str] = None, + instance_type: Optional[enums.Instance.Type] = None, + instance_labels: Optional[Dict] = None, + cluster_nodes: Optional[int] = None, + cluster_storage_type: Optional[enums.StorageType] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.project_id = project_id + self.instance_id = instance_id + self.main_cluster_id = main_cluster_id + self.main_cluster_zone = main_cluster_zone + self.replica_clusters = replica_clusters + self.replica_cluster_id = replica_cluster_id + self.replica_cluster_zone = replica_cluster_zone + self.instance_display_name = instance_display_name + self.instance_type = instance_type + self.instance_labels = instance_labels + self.cluster_nodes = cluster_nodes + self.cluster_storage_type = cluster_storage_type + self.timeout = timeout + self._validate_inputs() + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def execute(self, context) -> None: + hook = BigtableHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + instance = hook.get_instance( + project_id=self.project_id, instance_id=self.instance_id + ) + if instance: + # Based on Instance.__eq__ instance with the same ID and client is + # considered as equal. + self.log.info( + "The instance '%s' already exists in this project. Consider it as created", + self.instance_id, + ) + return + try: + hook.create_instance( + project_id=self.project_id, + instance_id=self.instance_id, + main_cluster_id=self.main_cluster_id, + main_cluster_zone=self.main_cluster_zone, + replica_clusters=self.replica_clusters, + replica_cluster_id=self.replica_cluster_id, + replica_cluster_zone=self.replica_cluster_zone, + instance_display_name=self.instance_display_name, + instance_type=self.instance_type, + instance_labels=self.instance_labels, + cluster_nodes=self.cluster_nodes, + cluster_storage_type=self.cluster_storage_type, + timeout=self.timeout, + ) + except google.api_core.exceptions.GoogleAPICallError as e: + self.log.error("An error occurred. 
Exiting.") + raise e + + +class BigtableUpdateInstanceOperator(BaseOperator, BigtableValidationMixin): + """ + Updates an existing Cloud Bigtable instance. + + For more details about instance creation have a look at the reference: + https://googleapis.dev/python/bigtable/latest/instance.html#google.cloud.bigtable.instance.Instance.update + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigtableUpdateInstanceOperator` + + :type instance_id: str + :param instance_id: The ID of the Cloud Bigtable instance to update. + :type project_id: str + :param project_id: Optional, the ID of the Google Cloud project. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type instance_display_name: str + :param instance_display_name: (optional) Human-readable name of the instance. + :type instance_type: enums.Instance.Type or enum.IntEnum + :param instance_type: (optional) The type of the instance. + :type instance_labels: dict + :param instance_labels: (optional) Dictionary of labels to associate + with the instance. + :type timeout: int + :param timeout: (optional) timeout (in seconds) for instance update. + If None is not specified, Operator will wait indefinitely. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + REQUIRED_ATTRIBUTES: Iterable[str] = ["instance_id"] + template_fields: Iterable[str] = [ + "project_id", + "instance_id", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + project_id: Optional[str] = None, + instance_display_name: Optional[str] = None, + instance_type: Optional[Union[enums.Instance.Type, enum.IntEnum]] = None, + instance_labels: Optional[Dict] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.project_id = project_id + self.instance_id = instance_id + self.instance_display_name = instance_display_name + self.instance_type = instance_type + self.instance_labels = instance_labels + self.timeout = timeout + self._validate_inputs() + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def execute(self, context) -> None: + hook = BigtableHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + instance = hook.get_instance( + project_id=self.project_id, instance_id=self.instance_id + ) + if not instance: + raise AirflowException( + f"Dependency: instance '{self.instance_id}' does not exist." 
+ ) + + try: + hook.update_instance( + project_id=self.project_id, + instance_id=self.instance_id, + instance_display_name=self.instance_display_name, + instance_type=self.instance_type, + instance_labels=self.instance_labels, + timeout=self.timeout, + ) + except google.api_core.exceptions.GoogleAPICallError as e: + self.log.error("An error occurred. Exiting.") + raise e + + +class BigtableDeleteInstanceOperator(BaseOperator, BigtableValidationMixin): + """ + Deletes the Cloud Bigtable instance, including its clusters and all related tables. + + For more details about deleting instance have a look at the reference: + https://googleapis.github.io/google-cloud-python/latest/bigtable/instance.html#google.cloud.bigtable.instance.Instance.delete + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigtableDeleteInstanceOperator` + + :type instance_id: str + :param instance_id: The ID of the Cloud Bigtable instance to delete. + :param project_id: Optional, the ID of the Google Cloud project. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + REQUIRED_ATTRIBUTES = ("instance_id",) # type: Iterable[str] + template_fields = [ + "project_id", + "instance_id", + "impersonation_chain", + ] # type: Iterable[str] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.project_id = project_id + self.instance_id = instance_id + self._validate_inputs() + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def execute(self, context) -> None: + hook = BigtableHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + hook.delete_instance( + project_id=self.project_id, instance_id=self.instance_id + ) + except google.api_core.exceptions.NotFound: + self.log.info( + "The instance '%s' does not exist in project '%s'. Consider it as deleted", + self.instance_id, + self.project_id, + ) + except google.api_core.exceptions.GoogleAPICallError as e: + self.log.error("An error occurred. Exiting.") + raise e + + +class BigtableCreateTableOperator(BaseOperator, BigtableValidationMixin): + """ + Creates the table in the Cloud Bigtable instance. + + For more details about creating table have a look at the reference: + https://googleapis.github.io/google-cloud-python/latest/bigtable/table.html#google.cloud.bigtable.table.Table.create + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigtableCreateTableOperator` + + :type instance_id: str + :param instance_id: The ID of the Cloud Bigtable instance that will + hold the new table. + :type table_id: str + :param table_id: The ID of the table to be created. + :type project_id: str + :param project_id: Optional, the ID of the Google Cloud project. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type initial_split_keys: list + :param initial_split_keys: (Optional) list of row keys in bytes that will be used to + initially split the table into several tablets. + :type column_families: dict + :param column_families: (Optional) A map columns to create. + The key is the column_id str and the value is a + :class:`google.cloud.bigtable.column_family.GarbageCollectionRule` + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + REQUIRED_ATTRIBUTES = ("instance_id", "table_id") # type: Iterable[str] + template_fields = [ + "project_id", + "instance_id", + "table_id", + "impersonation_chain", + ] # type: Iterable[str] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + table_id: str, + project_id: Optional[str] = None, + initial_split_keys: Optional[List] = None, + column_families: Optional[Dict[str, GarbageCollectionRule]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.project_id = project_id + self.instance_id = instance_id + self.table_id = table_id + self.initial_split_keys = initial_split_keys or [] + self.column_families = column_families or {} + self._validate_inputs() + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def _compare_column_families(self, hook, instance) -> bool: + table_column_families = hook.get_column_families_for_table( + instance, self.table_id + ) + if set(table_column_families.keys()) != set(self.column_families.keys()): + self.log.error( + "Table '%s' has different set of Column Families", self.table_id + ) + self.log.error("Expected: %s", self.column_families.keys()) + self.log.error("Actual: %s", table_column_families.keys()) + return False + + for key in table_column_families: + # There is difference in structure between local Column Families + # and remote ones + # Local `self.column_families` is dict with column_id as key + # and GarbageCollectionRule as value. + # Remote `table_column_families` is list of ColumnFamily objects. 
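+ # (Editor's illustration, not part of the upstream file.) For a local definition
+ # such as column_families={"cf1": MaxVersionsGCRule(1)} (MaxVersionsGCRule comes
+ # from google.cloud.bigtable.column_family), the check below compares the remote
+ # rule, table_column_families["cf1"].gc_rule, with MaxVersionsGCRule(1) for equality.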
+ # For more information about ColumnFamily please refer to the documentation: + # https://googleapis.github.io/google-cloud-python/latest/bigtable/column-family.html#google.cloud.bigtable.column_family.ColumnFamily + if table_column_families[key].gc_rule != self.column_families[key]: + self.log.error( + "Column Family '%s' differs for table '%s'.", key, self.table_id + ) + return False + return True + + def execute(self, context) -> None: + hook = BigtableHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + instance = hook.get_instance( + project_id=self.project_id, instance_id=self.instance_id + ) + if not instance: + raise AirflowException( + "Dependency: instance '{}' does not exist in project '{}'.".format( + self.instance_id, self.project_id + ) + ) + try: + hook.create_table( + instance=instance, + table_id=self.table_id, + initial_split_keys=self.initial_split_keys, + column_families=self.column_families, + ) + except google.api_core.exceptions.AlreadyExists: + if not self._compare_column_families(hook, instance): + raise AirflowException( + f"Table '{self.table_id}' already exists with different Column Families." + ) + self.log.info( + "The table '%s' already exists. Consider it as created", self.table_id + ) + + +class BigtableDeleteTableOperator(BaseOperator, BigtableValidationMixin): + """ + Deletes the Cloud Bigtable table. + + For more details about deleting table have a look at the reference: + https://googleapis.github.io/google-cloud-python/latest/bigtable/table.html#google.cloud.bigtable.table.Table.delete + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigtableDeleteTableOperator` + + :type instance_id: str + :param instance_id: The ID of the Cloud Bigtable instance. + :type table_id: str + :param table_id: The ID of the table to be deleted. + :type project_id: str + :param project_id: Optional, the ID of the Google Cloud project. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type app_profile_id: str + :param app_profile_id: Application profile. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + REQUIRED_ATTRIBUTES = ("instance_id", "table_id") # type: Iterable[str] + template_fields = [ + "project_id", + "instance_id", + "table_id", + "impersonation_chain", + ] # type: Iterable[str] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + table_id: str, + project_id: Optional[str] = None, + app_profile_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.project_id = project_id + self.instance_id = instance_id + self.table_id = table_id + self.app_profile_id = app_profile_id + self._validate_inputs() + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def execute(self, context) -> None: + hook = BigtableHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + instance = hook.get_instance( + project_id=self.project_id, instance_id=self.instance_id + ) + if not instance: + raise AirflowException( + f"Dependency: instance '{self.instance_id}' does not exist." + ) + + try: + hook.delete_table( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + ) + except google.api_core.exceptions.NotFound: + # It's OK if table doesn't exists. + self.log.info( + "The table '%s' no longer exists. Consider it as deleted", self.table_id + ) + except google.api_core.exceptions.GoogleAPICallError as e: + self.log.error("An error occurred. Exiting.") + raise e + + +class BigtableUpdateClusterOperator(BaseOperator, BigtableValidationMixin): + """ + Updates a Cloud Bigtable cluster. + + For more details about updating a Cloud Bigtable cluster, + have a look at the reference: + https://googleapis.github.io/google-cloud-python/latest/bigtable/cluster.html#google.cloud.bigtable.cluster.Cluster.update + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigtableUpdateClusterOperator` + + :type instance_id: str + :param instance_id: The ID of the Cloud Bigtable instance. + :type cluster_id: str + :param cluster_id: The ID of the Cloud Bigtable cluster to update. + :type nodes: int + :param nodes: The desired number of nodes for the Cloud Bigtable cluster. + :type project_id: str + :param project_id: Optional, the ID of the Google Cloud project. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    REQUIRED_ATTRIBUTES = ("instance_id", "cluster_id", "nodes")  # type: Iterable[str]
+    template_fields = [
+        "project_id",
+        "instance_id",
+        "cluster_id",
+        "nodes",
+        "impersonation_chain",
+    ]  # type: Iterable[str]
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        instance_id: str,
+        cluster_id: str,
+        nodes: int,
+        project_id: Optional[str] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        self.project_id = project_id
+        self.instance_id = instance_id
+        self.cluster_id = cluster_id
+        self.nodes = nodes
+        self._validate_inputs()
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        super().__init__(**kwargs)
+
+    def execute(self, context) -> None:
+        hook = BigtableHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        instance = hook.get_instance(
+            project_id=self.project_id, instance_id=self.instance_id
+        )
+        if not instance:
+            raise AirflowException(
+                f"Dependency: instance '{self.instance_id}' does not exist."
+            )
+
+        try:
+            hook.update_cluster(
+                instance=instance, cluster_id=self.cluster_id, nodes=self.nodes
+            )
+        except google.api_core.exceptions.NotFound:
+            raise AirflowException(
+                "Dependency: cluster '{}' does not exist for instance '{}'.".format(
+                    self.cluster_id, self.instance_id
+                )
+            )
+        except google.api_core.exceptions.GoogleAPICallError as e:
+            self.log.error("An error occurred. Exiting.")
+            raise e
diff --git a/reference/providers/google/cloud/operators/cloud_build.py b/reference/providers/google/cloud/operators/cloud_build.py
new file mode 100644
index 0000000..7a4921d
--- /dev/null
+++ b/reference/providers/google/cloud/operators/cloud_build.py
@@ -0,0 +1,258 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operators that integrate with Google Cloud Build service."""
+import json
+import re
+from copy import deepcopy
+from typing import Any, Dict, Optional, Sequence, Union
+from urllib.parse import unquote, urlparse
+
+try:
+    import airflow.utils.yaml as yaml
+except ImportError:
+    import yaml
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.google.cloud.hooks.cloud_build import CloudBuildHook
+from airflow.utils.decorators import apply_defaults
+
+REGEX_REPO_PATH = re.compile(r"^/p/(?P<project_id>[^/]+)/r/(?P<repo_name>[^/]+)")
+
+
+class BuildProcessor:
+    """
+    Processes build configurations to add additional functionality to support the use of operators.
+
+    The following improvements are made:
+
+    * It is required to provide the source and only one type can be given,
+    * It is possible to provide the source as a URL instead of a dict (see the illustrative example below).
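+
+    For illustration only (an editor's sketch, not part of the upstream docstring),
+    a body with a string ``repoSource`` would be normalized roughly like this:
+
+    .. code-block:: python
+
+        body = {
+            "source": {
+                "repoSource": "https://source.developers.google.com/p/airflow-project/r/airflow-repo#branch-name"
+            }
+        }
+        body = BuildProcessor(body=body).process_body()
+        # body["source"]["repoSource"] is now:
+        # {"projectId": "airflow-project", "repoName": "airflow-repo", "branchName": "branch-name"}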
+ + :param body: The request body. + See: https://cloud.google.com/cloud-build/docs/api/reference/rest/v1/projects.builds + :type body: dict + """ + + def __init__(self, body: dict) -> None: + self.body = deepcopy(body) + + def _verify_source(self) -> None: + is_storage = "storageSource" in self.body["source"] + is_repo = "repoSource" in self.body["source"] + + sources_count = sum([is_storage, is_repo]) + + if sources_count != 1: + raise AirflowException( + "The source could not be determined. Please choose one data source from: " + "storageSource and repoSource." + ) + + def _reformat_source(self) -> None: + self._reformat_repo_source() + self._reformat_storage_source() + + def _reformat_repo_source(self) -> None: + if "repoSource" not in self.body["source"]: + return + + source = self.body["source"]["repoSource"] + + if not isinstance(source, str): + return + + self.body["source"]["repoSource"] = self._convert_repo_url_to_dict(source) + + def _reformat_storage_source(self) -> None: + if "storageSource" not in self.body["source"]: + return + + source = self.body["source"]["storageSource"] + + if not isinstance(source, str): + return + + self.body["source"]["storageSource"] = self._convert_storage_url_to_dict(source) + + def process_body(self) -> dict: + """ + Processes the body passed in the constructor + + :return: the body. + :type: dict + """ + if "source" in self.body: + self._verify_source() + self._reformat_source() + return self.body + + @staticmethod + def _convert_repo_url_to_dict(source): + """ + Convert url to repository in Google Cloud Source to a format supported by the API + + Example valid input: + + .. code-block:: none + + https://source.developers.google.com/p/airflow-project/r/airflow-repo#branch-name + + """ + url_parts = urlparse(source) + + match = REGEX_REPO_PATH.search(url_parts.path) + + if ( + url_parts.scheme != "https" + or url_parts.hostname != "source.developers.google.com" + or not match + ): + raise AirflowException( + "Invalid URL. You must pass the URL in the format: " + "https://source.developers.google.com/p/airflow-project/r/airflow-repo#branch-name" + ) + + project_id = unquote(match.group("project_id")) + repo_name = unquote(match.group("repo_name")) + + source_dict = { + "projectId": project_id, + "repoName": repo_name, + "branchName": "master", + } + + if url_parts.fragment: + source_dict["branchName"] = url_parts.fragment + + return source_dict + + @staticmethod + def _convert_storage_url_to_dict(storage_url: str) -> Dict[str, Any]: + """ + Convert url to object in Google Cloud Storage to a format supported by the API + + Example valid input: + + .. code-block:: none + + gs://bucket-name/object-name.tar.gz + + """ + url_parts = urlparse(storage_url) + + if ( + url_parts.scheme != "gs" + or not url_parts.hostname + or not url_parts.path + or url_parts.path == "/" + ): + raise AirflowException( + "Invalid URL. You must pass the URL in the format: " + "gs://bucket-name/object-name.tar.gz#24565443" + ) + + source_dict = {"bucket": url_parts.hostname, "object": url_parts.path[1:]} + + if url_parts.fragment: + source_dict["generation"] = url_parts.fragment + + return source_dict + + +class CloudBuildCreateBuildOperator(BaseOperator): + """ + Starts a build with the specified configuration. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudBuildCreateBuildOperator` + + :param body: The build config with instructions to perform with CloudBuild. 
+ Can be a dictionary or path to a file type like YAML or JSON. + See: https://cloud.google.com/cloud-build/docs/api/reference/rest/v1/projects.builds + :type body: dict or string + :param project_id: ID of the Google Cloud project if None then + default project_id is used. + :type project_id: str + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (for example v1 or v1beta1). + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + template_ext = [".yml", ".yaml", ".json"] + + @apply_defaults + def __init__( + self, + *, + body: Union[dict, str], + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.body = body + # Not template fields to keep original value + self.body_raw = body + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self._validate_inputs() + self.impersonation_chain = impersonation_chain + + def prepare_template(self) -> None: + # if no file is specified, skip + if not isinstance(self.body_raw, str): + return + with open(self.body_raw) as file: + if any(self.body_raw.endswith(ext) for ext in [".yaml", ".yml"]): + self.body = yaml.load(file.read(), Loader=yaml.FullLoader) + if self.body_raw.endswith(".json"): + self.body = json.loads(file.read()) + + def _validate_inputs(self) -> None: + if not self.body: + raise AirflowException("The required parameter 'body' is missing") + + def execute(self, context): + hook = CloudBuildHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + body = BuildProcessor(body=self.body).process_body() + return hook.create_build(body=body, project_id=self.project_id) diff --git a/reference/providers/google/cloud/operators/cloud_memorystore.py b/reference/providers/google/cloud/operators/cloud_memorystore.py new file mode 100644 index 0000000..998acba --- /dev/null +++ b/reference/providers/google/cloud/operators/cloud_memorystore.py @@ -0,0 +1,1735 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operators for Google Cloud Memorystore service""" +from typing import Dict, Optional, Sequence, Tuple, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.cloud_memorystore import ( + CloudMemorystoreHook, + CloudMemorystoreMemcachedHook, +) +from airflow.utils.decorators import apply_defaults +from google.api_core.retry import Retry +from google.cloud.memcache_v1beta2.types import cloud_memcache +from google.cloud.redis_v1 import ( + FailoverInstanceRequest, + InputConfig, + Instance, + OutputConfig, +) +from google.protobuf.field_mask_pb2 import FieldMask + + +class CloudMemorystoreCreateInstanceOperator(BaseOperator): + """ + Creates a Redis instance based on the specified tier and memory size. + + By default, the instance is accessible from the project's `default network + `__. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreCreateInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: Required. The logical name of the Redis instance in the customer project with the + following restrictions: + + - Must contain only lowercase letters, numbers, and hyphens. + - Must start with a letter. + - Must be between 1-40 characters. + - Must end with a number or a letter. + - Must be unique within the customer project / location + :type instance_id: str + :param instance: Required. A Redis [Instance] resource + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.Instance` + :type instance: Union[Dict, google.cloud.redis_v1.types.Instance] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "instance_id", + "instance", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + instance_id: str, + instance: Union[Dict, Instance], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.instance_id = instance_id + self.instance = instance + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + hook = CloudMemorystoreHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + result = hook.create_instance( + location=self.location, + instance_id=self.instance_id, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Instance.to_dict(result) + + +class CloudMemorystoreDeleteInstanceOperator(BaseOperator): + """ + Deletes a specific Redis instance. Instance stops serving and data is deleted. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreDeleteInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "instance", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + instance: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.instance = instance + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudMemorystoreHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.delete_instance( + location=self.location, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreExportInstanceOperator(BaseOperator): + """ + Export Redis instance data into a Redis RDB format file in Cloud Storage. + + Redis will continue serving during this operation. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreExportInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param output_config: Required. Specify data to be exported. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.OutputConfig` + :type output_config: Union[Dict, google.cloud.redis_v1.types.OutputConfig] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "instance", + "output_config", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + instance: str, + output_config: Union[Dict, OutputConfig], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.instance = instance + self.output_config = output_config + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudMemorystoreHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + + hook.export_instance( + location=self.location, + instance=self.instance, + output_config=self.output_config, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreFailoverInstanceOperator(BaseOperator): + """ + Initiates a failover of the master node to current replica node for a specific STANDARD tier Cloud + Memorystore for Redis instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreFailoverInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param data_protection_mode: Optional. Available data protection modes that the user can choose. If it's + unspecified, data protection mode will be LIMITED_DATA_LOSS by default. + :type data_protection_mode: google.cloud.redis_v1.gapic.enums.FailoverInstanceRequest.DataProtectionMode + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "instance", + "data_protection_mode", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + instance: str, + data_protection_mode: FailoverInstanceRequest.DataProtectionMode, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.instance = instance + self.data_protection_mode = data_protection_mode + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudMemorystoreHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.failover_instance( + location=self.location, + instance=self.instance, + data_protection_mode=self.data_protection_mode, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreGetInstanceOperator(BaseOperator): + """ + Gets the details of a specific Redis instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreGetInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "instance", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + instance: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.instance = instance + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + hook = CloudMemorystoreHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + result = hook.get_instance( + location=self.location, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Instance.to_dict(result) + + +class CloudMemorystoreImportOperator(BaseOperator): + """ + Import a Redis RDB snapshot file from Cloud Storage into a Redis instance. + + Redis may stop serving during this operation. Instance state will be IMPORTING for entire operation. When + complete, the instance will contain only data from the imported file. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreImportOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param input_config: Required. Specify data to be imported. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.InputConfig` + :type input_config: Union[Dict, google.cloud.redis_v1.types.InputConfig] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "instance", + "input_config", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + instance: str, + input_config: Union[Dict, InputConfig], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.instance = instance + self.input_config = input_config + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudMemorystoreHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.import_instance( + location=self.location, + instance=self.instance, + input_config=self.input_config, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreListInstancesOperator(BaseOperator): + """ + Lists all Redis instances owned by a project in either the specified location (region) or all locations. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreListInstancesOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + If it is specified as ``-`` (wildcard), then all regions available to the project are + queried, and the results are aggregated. + :type location: str + :param page_size: The maximum number of resources contained in the underlying API response. If page + streaming is performed per- resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number of resources in a page. + :type page_size: int + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "page_size", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + page_size: int, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.page_size = page_size + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + hook = CloudMemorystoreHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + result = hook.list_instances( + location=self.location, + page_size=self.page_size, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + instances = [Instance.to_dict(a) for a in result] + return instances + + +class CloudMemorystoreUpdateInstanceOperator(BaseOperator): + """ + Updates the metadata and configuration of a specific Redis instance. + + :param update_mask: Required. Mask of fields to update. At least one path must be supplied in this field. + The elements of the repeated paths field may only include these fields from ``Instance``: + + - ``displayName`` + - ``labels`` + - ``memorySizeGb`` + - ``redisConfig`` + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.FieldMask` + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreUpdateInstanceOperator` + + :type update_mask: Union[Dict, google.cloud.redis_v1.types.FieldMask] + :param instance: Required. Update description. Only fields specified in update_mask are updated. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.Instance` + :type instance: Union[Dict, google.cloud.redis_v1.types.Instance] + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Redis instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "update_mask", + "instance", + "location", + "instance_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + update_mask: Union[Dict, FieldMask], + instance: Union[Dict, Instance], + location: Optional[str] = None, + instance_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.update_mask = update_mask + self.instance = instance + self.location = location + self.instance_id = instance_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudMemorystoreHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.update_instance( + update_mask=self.update_mask, + instance=self.instance, + location=self.location, + instance_id=self.instance_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreScaleInstanceOperator(BaseOperator): + """ + Updates the metadata and configuration of a specific Redis instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreScaleInstanceOperator` + + :param memory_size_gb: Redis memory size in GiB. + :type memory_size_gb: int + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Redis instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "memory_size_gb", + "location", + "instance_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + memory_size_gb: int, + location: Optional[str] = None, + instance_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.memory_size_gb = memory_size_gb + self.location = location + self.instance_id = instance_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudMemorystoreHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + + hook.update_instance( + update_mask={"paths": ["memory_size_gb"]}, + instance={"memory_size_gb": self.memory_size_gb}, + location=self.location, + instance_id=self.instance_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreCreateInstanceAndImportOperator(BaseOperator): + """ + Creates a Redis instance based on the specified tier and memory size and import a Redis RDB snapshot file + from Cloud Storage into a this instance. + + By default, the instance is accessible from the project's `default network + `__. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreCreateInstanceAndImportOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: Required. The logical name of the Redis instance in the customer project with the + following restrictions: + + - Must contain only lowercase letters, numbers, and hyphens. + - Must start with a letter. + - Must be between 1-40 characters. + - Must end with a number or a letter. + - Must be unique within the customer project / location + :type instance_id: str + :param instance: Required. A Redis [Instance] resource + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.Instance` + :type instance: Union[Dict, google.cloud.redis_v1.types.Instance] + :param input_config: Required. Specify data to be imported. 
+
+ If a dict is provided, it must be of the same form as the protobuf message
+ :class:`~google.cloud.redis_v1.types.InputConfig`
+ :type input_config: Union[Dict, google.cloud.redis_v1.types.InputConfig]
+ :param project_id: Project ID of the project that contains the instance. If set
+ to None or missing, the default project_id from the Google Cloud connection is used.
+ :type project_id: str
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :type retry: google.api_core.retry.Retry
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :type timeout: float
+ :param metadata: Additional metadata that is provided to the method.
+ :type metadata: Sequence[Tuple[str, str]]
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :type gcp_conn_id: str
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ :type impersonation_chain: Union[str, Sequence[str]]
+ """
+
+ template_fields = (
+ "location",
+ "instance_id",
+ "instance",
+ "input_config",
+ "project_id",
+ "retry",
+ "timeout",
+ "metadata",
+ "gcp_conn_id",
+ "impersonation_chain",
+ )
+
+ @apply_defaults
+ def __init__(
+ self,
+ *,
+ location: str,
+ instance_id: str,
+ instance: Union[Dict, Instance],
+ input_config: Union[Dict, InputConfig],
+ project_id: Optional[str] = None,
+ retry: Optional[Retry] = None,
+ timeout: Optional[float] = None,
+ metadata: Optional[Sequence[Tuple[str, str]]] = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.location = location
+ self.instance_id = instance_id
+ self.instance = instance
+ self.input_config = input_config
+ self.project_id = project_id
+ self.retry = retry
+ self.timeout = timeout
+ self.metadata = metadata
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: dict) -> None:
+ hook = CloudMemorystoreHook(
+ gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
+ )
+
+ hook.create_instance(
+ location=self.location,
+ instance_id=self.instance_id,
+ instance=self.instance,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+
+ hook.import_instance(
+ location=self.location,
+ instance=self.instance_id,
+ input_config=self.input_config,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+
+
+class CloudMemorystoreExportAndDeleteInstanceOperator(BaseOperator):
+ """
+ Exports Redis instance data into a Redis RDB format file in Cloud Storage. In the next step, it deletes
+ the instance.
+
+ Redis will continue serving during this operation.
+
+ .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreExportAndDeleteInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param output_config: Required. Specify data to be exported. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.OutputConfig` + :type output_config: Union[Dict, google.cloud.redis_v1.types.OutputConfig] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "instance", + "output_config", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + instance: str, + output_config: Union[Dict, OutputConfig], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.instance = instance + self.output_config = output_config + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudMemorystoreHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + + hook.export_instance( + location=self.location, + instance=self.instance, + output_config=self.output_config, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + hook.delete_instance( + location=self.location, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreMemcachedApplyParametersOperator(BaseOperator): + """ + Will update current set of Parameters to the set of specified nodes of the Memcached Instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreMemcachedApplyParametersOperator` + + :param node_ids: Nodes to which we should apply the instance-level parameter group. + :type node_ids: Sequence[str] + :param apply_all: Whether to apply instance-level parameter group to all nodes. If set to true, + will explicitly restrict users from specifying any nodes, and apply parameter group updates + to all nodes within the instance. + :type apply_all: bool + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Memcached instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ( + "node_ids", + "apply_all", + "location", + "instance_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + node_ids: Sequence[str], + apply_all: bool, + location: str, + instance_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.node_ids = node_ids + self.apply_all = apply_all + self.location = location + self.instance_id = instance_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = CloudMemorystoreMemcachedHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.apply_parameters( + node_ids=self.node_ids, + apply_all=self.apply_all, + location=self.location, + instance_id=self.instance_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreMemcachedCreateInstanceOperator(BaseOperator): + """ + Creates a Memcached instance based on the specified tier and memory size. + + By default, the instance is accessible from the project's `default network + `__. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreMemcachedCreateInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: Required. The logical name of the Memcached instance in the customer project with the + following restrictions: + + - Must contain only lowercase letters, numbers, and hyphens. + - Must start with a letter. + - Must be between 1-40 characters. + - Must end with a number or a letter. + - Must be unique within the customer project / location + :type instance_id: str + :param instance: Required. A Memcached [Instance] resource + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.Instance` + :type instance: Union[Dict, google.cloud.memcache_v1beta2.types.cloud_memcache.Instance] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. 
+ :type gcp_conn_id: str + """ + + template_fields = ( + "location", + "instance_id", + "instance", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location: str, + instance_id: str, + instance: Union[Dict, cloud_memcache.Instance], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.location = location + self.instance_id = instance_id + self.instance = instance + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreMemcachedHook(gcp_conn_id=self.gcp_conn_id) + result = hook.create_instance( + location=self.location, + instance_id=self.instance_id, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return cloud_memcache.Instance.to_dict(result) + + +class CloudMemorystoreMemcachedDeleteInstanceOperator(BaseOperator): + """ + Deletes a specific Memcached instance. Instance stops serving and data is deleted. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreMemcachedDeleteInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Memcached instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. 
+ :type gcp_conn_id: str + """ + + template_fields = ( + "location", + "instance", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location: str, + instance: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.location = location + self.instance = instance + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreMemcachedHook(gcp_conn_id=self.gcp_conn_id) + hook.delete_instance( + location=self.location, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreMemcachedGetInstanceOperator(BaseOperator): + """ + Gets the details of a specific Memcached instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreMemcachedGetInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Memcached instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "instance", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + instance: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.instance = instance + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = CloudMemorystoreMemcachedHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + result = hook.get_instance( + location=self.location, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return cloud_memcache.Instance.to_dict(result) + + +class CloudMemorystoreMemcachedListInstancesOperator(BaseOperator): + """ + Lists all Memcached instances owned by a project in either the specified location (region) or all + locations. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreMemcachedListInstancesOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + If it is specified as ``-`` (wildcard), then all regions available to the project are + queried, and the results are aggregated. + :type location: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]]
+ """
+
+ template_fields = (
+ "location",
+ "project_id",
+ "retry",
+ "timeout",
+ "metadata",
+ "gcp_conn_id",
+ "impersonation_chain",
+ )
+
+ @apply_defaults
+ def __init__(
+ self,
+ *,
+ location: str,
+ project_id: Optional[str] = None,
+ retry: Optional[Retry] = None,
+ timeout: Optional[float] = None,
+ metadata: Optional[Sequence[Tuple[str, str]]] = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.location = location
+ self.project_id = project_id
+ self.retry = retry
+ self.timeout = timeout
+ self.metadata = metadata
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Dict):
+ hook = CloudMemorystoreMemcachedHook(
+ gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
+ )
+ result = hook.list_instances(
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ instances = [cloud_memcache.Instance.to_dict(a) for a in result]
+ return instances
+
+
+class CloudMemorystoreMemcachedUpdateInstanceOperator(BaseOperator):
+ """
+ Updates the metadata and configuration of a specific Memcached instance.
+
+ :param update_mask: Required. Mask of fields to update. At least one path must be supplied in this field.
+ The elements of the repeated paths field may only include these fields from ``Instance``:
+
+ - ``displayName``
+
+ If a dict is provided, it must be of the same form as the protobuf message
+ :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.field_mask.FieldMask`
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:CloudMemorystoreMemcachedUpdateInstanceOperator`
+
+ :type update_mask: Union[Dict, google.cloud.memcache_v1beta2.types.cloud_memcache.field_mask.FieldMask]
+ :param instance: Required. Update description. Only fields specified in update_mask are updated.
+
+ If a dict is provided, it must be of the same form as the protobuf message
+ :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.Instance`
+ :type instance: Union[Dict, google.cloud.memcache_v1beta2.types.cloud_memcache.Instance]
+ :param location: The location of the Cloud Memorystore instance (for example europe-west1)
+ :type location: str
+ :param instance_id: The logical name of the Memcached instance in the customer project.
+ :type instance_id: str
+ :param project_id: Project ID of the project that contains the instance. If set
+ to None or missing, the default project_id from the Google Cloud connection is used.
+ :type project_id: str
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :type retry: google.api_core.retry.Retry
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :type timeout: float
+ :param metadata: Additional metadata that is provided to the method.
+ :type metadata: Sequence[Tuple[str, str]]
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "update_mask", + "instance", + "location", + "instance_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + update_mask: Union[Dict, cloud_memcache.field_mask.FieldMask], + instance: Union[Dict, cloud_memcache.Instance], + location: Optional[str] = None, + instance_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.update_mask = update_mask + self.instance = instance + self.location = location + self.instance_id = instance_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = CloudMemorystoreMemcachedHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.update_instance( + update_mask=self.update_mask, + instance=self.instance, + location=self.location, + instance_id=self.instance_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreMemcachedUpdateParametersOperator(BaseOperator): + """ + Updates the defined Memcached Parameters for an existing Instance. This method only stages the + parameters, it must be followed by apply_parameters to apply the parameters to nodes of + the Memcached Instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreMemcachedApplyParametersOperator` + + :param update_mask: Required. Mask of fields to update. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.field_mask.FieldMask` + :type update_mask: + Union[Dict, google.cloud.memcache_v1beta2.types.cloud_memcache.field_mask.FieldMask] + :param parameters: The parameters to apply to the instance. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.MemcacheParameters` + :type parameters: Union[Dict, google.cloud.memcache_v1beta2.types.cloud_memcache.MemcacheParameters] + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Memcached instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. 
If set + to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ( + "update_mask", + "parameters", + "location", + "instance_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + update_mask: Union[Dict, cloud_memcache.field_mask.FieldMask], + parameters: Union[Dict, cloud_memcache.MemcacheParameters], + location: str, + instance_id: str, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.update_mask = update_mask + self.parameters = parameters + self.location = location + self.instance_id = instance_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = CloudMemorystoreMemcachedHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.update_parameters( + update_mask=self.update_mask, + parameters=self.parameters, + location=self.location, + instance_id=self.instance_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) diff --git a/reference/providers/google/cloud/operators/cloud_sql.py b/reference/providers/google/cloud/operators/cloud_sql.py new file mode 100644 index 0000000..274c083 --- /dev/null +++ b/reference/providers/google/cloud/operators/cloud_sql.py @@ -0,0 +1,1155 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
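+# Example (illustrative only): a minimal sketch of how the operators defined in
+# this module might be chained in a DAG. The project, instance, database, bucket,
+# and body values are placeholders, and the default ``google_cloud_default``
+# connection is assumed; adapt them to a real environment before use.
+#
+#     from airflow import DAG
+#     from airflow.utils.dates import days_ago
+#
+#     with DAG("example_cloud_sql", start_date=days_ago(1), schedule_interval=None) as dag:
+#         create_instance = CloudSQLCreateInstanceOperator(
+#             task_id="create_instance",
+#             instance="my-instance",  # placeholder instance ID
+#             body={"name": "my-instance", "settings": {"tier": "db-n1-standard-1"}},
+#         )
+#         create_db = CloudSQLCreateInstanceDatabaseOperator(
+#             task_id="create_db",
+#             instance="my-instance",
+#             body={"instance": "my-instance", "name": "mydb", "project": "my-project"},
+#         )
+#         export_db = CloudSQLExportInstanceOperator(
+#             task_id="export_db",
+#             instance="my-instance",
+#             body={
+#                 "exportContext": {
+#                     "fileType": "sql",
+#                     "uri": "gs://my-bucket/dump.sql",  # placeholder GCS destination
+#                     "databases": ["mydb"],
+#                 }
+#             },
+#         )
+#         create_instance >> create_db >> export_db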
+"""This module contains Google Cloud SQL operators.""" +from typing import Dict, Iterable, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.cloud_sql import ( + CloudSQLDatabaseHook, + CloudSQLHook, +) +from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.providers.postgres.hooks.postgres import PostgresHook +from airflow.utils.decorators import apply_defaults +from googleapiclient.errors import HttpError + +SETTINGS = "settings" +SETTINGS_VERSION = "settingsVersion" + +CLOUD_SQL_CREATE_VALIDATION = [ + dict(name="name", allow_empty=False), + dict( + name="settings", + type="dict", + fields=[ + dict(name="tier", allow_empty=False), + dict( + name="backupConfiguration", + type="dict", + fields=[ + dict(name="binaryLogEnabled", optional=True), + dict(name="enabled", optional=True), + dict(name="replicationLogArchivingEnabled", optional=True), + dict(name="startTime", allow_empty=False, optional=True), + ], + optional=True, + ), + dict(name="activationPolicy", allow_empty=False, optional=True), + dict(name="authorizedGaeApplications", type="list", optional=True), + dict(name="crashSafeReplicationEnabled", optional=True), + dict(name="dataDiskSizeGb", optional=True), + dict(name="dataDiskType", allow_empty=False, optional=True), + dict(name="databaseFlags", type="list", optional=True), + dict( + name="ipConfiguration", + type="dict", + fields=[ + dict( + name="authorizedNetworks", + type="list", + fields=[ + dict(name="expirationTime", optional=True), + dict(name="name", allow_empty=False, optional=True), + dict(name="value", allow_empty=False, optional=True), + ], + optional=True, + ), + dict(name="ipv4Enabled", optional=True), + dict(name="privateNetwork", allow_empty=False, optional=True), + dict(name="requireSsl", optional=True), + ], + optional=True, + ), + dict( + name="locationPreference", + type="dict", + fields=[ + dict(name="followGaeApplication", allow_empty=False, optional=True), + dict(name="zone", allow_empty=False, optional=True), + ], + optional=True, + ), + dict( + name="maintenanceWindow", + type="dict", + fields=[ + dict(name="hour", optional=True), + dict(name="day", optional=True), + dict(name="updateTrack", allow_empty=False, optional=True), + ], + optional=True, + ), + dict(name="pricingPlan", allow_empty=False, optional=True), + dict(name="replicationType", allow_empty=False, optional=True), + dict(name="storageAutoResize", optional=True), + dict(name="storageAutoResizeLimit", optional=True), + dict(name="userLabels", type="dict", optional=True), + ], + ), + dict(name="databaseVersion", allow_empty=False, optional=True), + dict( + name="failoverReplica", + type="dict", + fields=[dict(name="name", allow_empty=False)], + optional=True, + ), + dict(name="masterInstanceName", allow_empty=False, optional=True), + dict(name="onPremisesConfiguration", type="dict", optional=True), + dict(name="region", allow_empty=False, optional=True), + dict( + name="replicaConfiguration", + type="dict", + fields=[ + dict(name="failoverTarget", optional=True), + dict( + name="mysqlReplicaConfiguration", + type="dict", + fields=[ + dict(name="caCertificate", allow_empty=False, optional=True), + dict(name="clientCertificate", allow_empty=False, optional=True), + dict(name="clientKey", allow_empty=False, optional=True), + 
dict(name="connectRetryInterval", optional=True), + dict(name="dumpFilePath", allow_empty=False, optional=True), + dict(name="masterHeartbeatPeriod", optional=True), + dict(name="password", allow_empty=False, optional=True), + dict(name="sslCipher", allow_empty=False, optional=True), + dict(name="username", allow_empty=False, optional=True), + dict(name="verifyServerCertificate", optional=True), + ], + optional=True, + ), + ], + optional=True, + ), +] +CLOUD_SQL_EXPORT_VALIDATION = [ + dict( + name="exportContext", + type="dict", + fields=[ + dict(name="fileType", allow_empty=False), + dict(name="uri", allow_empty=False), + dict(name="databases", optional=True, type="list"), + dict( + name="sqlExportOptions", + type="dict", + optional=True, + fields=[ + dict(name="tables", optional=True, type="list"), + dict(name="schemaOnly", optional=True), + ], + ), + dict( + name="csvExportOptions", + type="dict", + optional=True, + fields=[dict(name="selectQuery")], + ), + ], + ) +] +CLOUD_SQL_IMPORT_VALIDATION = [ + dict( + name="importContext", + type="dict", + fields=[ + dict(name="fileType", allow_empty=False), + dict(name="uri", allow_empty=False), + dict(name="database", optional=True, allow_empty=False), + dict(name="importUser", optional=True), + dict( + name="csvImportOptions", + type="dict", + optional=True, + fields=[ + dict(name="table"), + dict(name="columns", type="list", optional=True), + ], + ), + ], + ) +] +CLOUD_SQL_DATABASE_CREATE_VALIDATION = [ + dict(name="instance", allow_empty=False), + dict(name="name", allow_empty=False), + dict(name="project", allow_empty=False), +] +CLOUD_SQL_DATABASE_PATCH_VALIDATION = [ + dict(name="instance", optional=True), + dict(name="name", optional=True), + dict(name="project", optional=True), + dict(name="etag", optional=True), + dict(name="charset", optional=True), + dict(name="collation", optional=True), +] + + +class CloudSQLBaseOperator(BaseOperator): + """ + Abstract base operator for Google Cloud SQL operators to inherit from. + + :param instance: Cloud SQL instance ID. This does not include the project ID. + :type instance: str + :param project_id: Optional, Google Cloud Project ID. f set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1beta4). + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + @apply_defaults + def __init__( + self, + *, + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.project_id = project_id + self.instance = instance + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.impersonation_chain = impersonation_chain + self._validate_inputs() + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if self.project_id == "": + raise AirflowException("The required parameter 'project_id' is empty") + if not self.instance: + raise AirflowException("The required parameter 'instance' is empty or None") + + def _check_if_instance_exists( + self, instance, hook: CloudSQLHook + ) -> Union[dict, bool]: + try: + return hook.get_instance(project_id=self.project_id, instance=instance) + except HttpError as e: + status = e.resp.status + if status == 404: + return False + raise e + + def _check_if_db_exists(self, db_name, hook: CloudSQLHook) -> Union[dict, bool]: + try: + return hook.get_database( + project_id=self.project_id, instance=self.instance, database=db_name + ) + except HttpError as e: + status = e.resp.status + if status == 404: + return False + raise e + + def execute(self, context): + pass + + @staticmethod + def _get_settings_version(instance): + return instance.get(SETTINGS).get(SETTINGS_VERSION) + + +class CloudSQLCreateInstanceOperator(CloudSQLBaseOperator): + """ + Creates a new Cloud SQL instance. + If an instance with the same name exists, no action will be taken and + the operator will succeed. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudSQLCreateInstanceOperator` + + :param body: Body required by the Cloud SQL insert API, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances/insert + #request-body + :type body: dict + :param instance: Cloud SQL instance ID. This does not include the project ID. + :type instance: str + :param project_id: Optional, Google Cloud Project ID. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1beta4). + :type api_version: str + :param validate_body: True if body should be validated, False otherwise. + :type validate_body: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_sql_create_template_fields] + template_fields = ( + "project_id", + "instance", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gcp_sql_create_template_fields] + + @apply_defaults + def __init__( + self, + *, + body: dict, + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.body = body + self.validate_body = validate_body + super().__init__( + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.body: + raise AirflowException("The required parameter 'body' is empty") + + def _validate_body_fields(self) -> None: + if self.validate_body: + GcpBodyFieldValidator( + CLOUD_SQL_CREATE_VALIDATION, api_version=self.api_version + ).validate(self.body) + + def execute(self, context) -> None: + hook = CloudSQLHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self._validate_body_fields() + if not self._check_if_instance_exists(self.instance, hook): + hook.create_instance(project_id=self.project_id, body=self.body) + else: + self.log.info( + "Cloud SQL instance with ID %s already exists. Aborting create.", + self.instance, + ) + + instance_resource = hook.get_instance( + project_id=self.project_id, instance=self.instance + ) + service_account_email = instance_resource["serviceAccountEmailAddress"] + task_instance = context["task_instance"] + task_instance.xcom_push( + key="service_account_email", value=service_account_email + ) + + +class CloudSQLInstancePatchOperator(CloudSQLBaseOperator): + """ + Updates settings of a Cloud SQL instance. + + Caution: This is a partial update, so only included values for the settings will be + updated. + + In the request body, supply the relevant portions of an instance resource, according + to the rules of patch semantics. + https://cloud.google.com/sql/docs/mysql/admin-api/how-tos/performance#patch + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudSQLInstancePatchOperator` + + :param body: Body required by the Cloud SQL patch API, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances/patch#request-body + :type body: dict + :param instance: Cloud SQL instance ID. This does not include the project ID. + :type instance: str + :param project_id: Optional, Google Cloud Project ID. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1beta4). + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_sql_patch_template_fields] + template_fields = ( + "project_id", + "instance", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gcp_sql_patch_template_fields] + + @apply_defaults + def __init__( + self, + *, + body: dict, + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.body = body + super().__init__( + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.body: + raise AirflowException("The required parameter 'body' is empty") + + def execute(self, context): + hook = CloudSQLHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + if not self._check_if_instance_exists(self.instance, hook): + raise AirflowException( + "Cloud SQL instance with ID {} does not exist. " + "Please specify another instance to patch.".format(self.instance) + ) + else: + return hook.patch_instance( + project_id=self.project_id, body=self.body, instance=self.instance + ) + + +class CloudSQLDeleteInstanceOperator(CloudSQLBaseOperator): + """ + Deletes a Cloud SQL instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudSQLDeleteInstanceOperator` + + :param instance: Cloud SQL instance ID. This does not include the project ID. + :type instance: str + :param project_id: Optional, Google Cloud Project ID. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1beta4). + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_sql_delete_template_fields] + template_fields = ( + "project_id", + "instance", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gcp_sql_delete_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__( + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def execute(self, context) -> Optional[bool]: + hook = CloudSQLHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + if not self._check_if_instance_exists(self.instance, hook): + print( + f"Cloud SQL instance with ID {self.instance} does not exist. Aborting delete." + ) + return True + else: + return hook.delete_instance( + project_id=self.project_id, instance=self.instance + ) + + +class CloudSQLCreateInstanceDatabaseOperator(CloudSQLBaseOperator): + """ + Creates a new database inside a Cloud SQL instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudSQLCreateInstanceDatabaseOperator` + + :param instance: Database instance ID. This does not include the project ID. + :type instance: str + :param body: The request body, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/databases/insert#request-body + :type body: dict + :param project_id: Optional, Google Cloud Project ID. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1beta4). + :type api_version: str + :param validate_body: Whether the body should be validated. Defaults to True. + :type validate_body: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_sql_db_create_template_fields] + template_fields = ( + "project_id", + "instance", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gcp_sql_db_create_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.body = body + self.validate_body = validate_body + super().__init__( + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.body: + raise AirflowException("The required parameter 'body' is empty") + + def _validate_body_fields(self) -> None: + if self.validate_body: + GcpBodyFieldValidator( + CLOUD_SQL_DATABASE_CREATE_VALIDATION, api_version=self.api_version + ).validate(self.body) + + def execute(self, context) -> Optional[bool]: + self._validate_body_fields() + database = self.body.get("name") + if not database: + self.log.error( + "Body doesn't contain 'name'. Cannot check if the" + " database already exists in the instance %s.", + self.instance, + ) + return False + hook = CloudSQLHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + if self._check_if_db_exists(database, hook): + self.log.info( + "Cloud SQL instance with ID %s already contains database '%s'. Aborting database insert.", + self.instance, + database, + ) + return True + else: + return hook.create_database( + project_id=self.project_id, instance=self.instance, body=self.body + ) + + +class CloudSQLPatchInstanceDatabaseOperator(CloudSQLBaseOperator): + """ + Updates a resource containing information about a database inside a Cloud SQL + instance using patch semantics. + See: https://cloud.google.com/sql/docs/mysql/admin-api/how-tos/performance#patch + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudSQLPatchInstanceDatabaseOperator` + + :param instance: Database instance ID. This does not include the project ID. + :type instance: str + :param database: Name of the database to be updated in the instance. + :type database: str + :param body: The request body, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/databases/patch#request-body + :type body: dict + :param project_id: Optional, Google Cloud Project ID. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1beta4). + :type api_version: str + :param validate_body: Whether the body should be validated. Defaults to True. + :type validate_body: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_sql_db_patch_template_fields] + template_fields = ( + "project_id", + "instance", + "body", + "database", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gcp_sql_db_patch_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance: str, + database: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.database = database + self.body = body + self.validate_body = validate_body + super().__init__( + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.body: + raise AirflowException("The required parameter 'body' is empty") + if not self.database: + raise AirflowException("The required parameter 'database' is empty") + + def _validate_body_fields(self) -> None: + if self.validate_body: + GcpBodyFieldValidator( + CLOUD_SQL_DATABASE_PATCH_VALIDATION, api_version=self.api_version + ).validate(self.body) + + def execute(self, context) -> None: + self._validate_body_fields() + hook = CloudSQLHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + if not self._check_if_db_exists(self.database, hook): + raise AirflowException( + "Cloud SQL instance with ID {instance} does not contain " + "database '{database}'. " + "Please specify another database to patch.".format( + instance=self.instance, database=self.database + ) + ) + else: + return hook.patch_database( + project_id=self.project_id, + instance=self.instance, + database=self.database, + body=self.body, + ) + + +class CloudSQLDeleteInstanceDatabaseOperator(CloudSQLBaseOperator): + """ + Deletes a database from a Cloud SQL instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudSQLDeleteInstanceDatabaseOperator` + + :param instance: Database instance ID. This does not include the project ID. + :type instance: str + :param database: Name of the database to be deleted in the instance. + :type database: str + :param project_id: Optional, Google Cloud Project ID. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1beta4). + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_sql_db_delete_template_fields] + template_fields = ( + "project_id", + "instance", + "database", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gcp_sql_db_delete_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance: str, + database: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.database = database + super().__init__( + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.database: + raise AirflowException("The required parameter 'database' is empty") + + def execute(self, context) -> Optional[bool]: + hook = CloudSQLHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + if not self._check_if_db_exists(self.database, hook): + print( + "Cloud SQL instance with ID {} does not contain database '{}'. " + "Aborting database delete.".format(self.instance, self.database) + ) + return True + else: + return hook.delete_database( + project_id=self.project_id, + instance=self.instance, + database=self.database, + ) + + +class CloudSQLExportInstanceOperator(CloudSQLBaseOperator): + """ + Exports data from a Cloud SQL instance to a Cloud Storage bucket as a SQL dump + or CSV file. + + Note: This operator is idempotent. If executed multiple times with the same + export file URI, the export file in GCS will simply be overridden. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudSQLExportInstanceOperator` + + :param instance: Cloud SQL instance ID. This does not include the project ID. + :type instance: str + :param body: The request body, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances/export#request-body + :type body: dict + :param project_id: Optional, Google Cloud Project ID. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1beta4). + :type api_version: str + :param validate_body: Whether the body should be validated. Defaults to True. + :type validate_body: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_sql_export_template_fields] + template_fields = ( + "project_id", + "instance", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gcp_sql_export_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.body = body + self.validate_body = validate_body + super().__init__( + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.body: + raise AirflowException("The required parameter 'body' is empty") + + def _validate_body_fields(self) -> None: + if self.validate_body: + GcpBodyFieldValidator( + CLOUD_SQL_EXPORT_VALIDATION, api_version=self.api_version + ).validate(self.body) + + def execute(self, context) -> None: + self._validate_body_fields() + hook = CloudSQLHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + return hook.export_instance( + project_id=self.project_id, instance=self.instance, body=self.body + ) + + +class CloudSQLImportInstanceOperator(CloudSQLBaseOperator): + """ + Imports data into a Cloud SQL instance from a SQL dump or CSV file in Cloud Storage. + + CSV IMPORT: + + This operator is NOT idempotent for a CSV import. If the same file is imported + multiple times, the imported data will be duplicated in the database. + Moreover, if there are any unique constraints the duplicate import may result in an + error. + + SQL IMPORT: + + This operator is idempotent for a SQL import if it was also exported by Cloud SQL. + The exported SQL contains 'DROP TABLE IF EXISTS' statements for all tables + to be imported. + + If the import file was generated in a different way, idempotence is not guaranteed. + It has to be ensured on the SQL file level. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudSQLImportInstanceOperator` + + :param instance: Cloud SQL instance ID. This does not include the project ID. + :type instance: str + :param body: The request body, as described in + https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances/export#request-body + :type body: dict + :param project_id: Optional, Google Cloud Project ID. If set to None or missing, + the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1beta4). + :type api_version: str + :param validate_body: Whether the body should be validated. Defaults to True. + :type validate_body: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_sql_import_template_fields] + template_fields = ( + "project_id", + "instance", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gcp_sql_import_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.body = body + self.validate_body = validate_body + super().__init__( + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.body: + raise AirflowException("The required parameter 'body' is empty") + + def _validate_body_fields(self) -> None: + if self.validate_body: + GcpBodyFieldValidator( + CLOUD_SQL_IMPORT_VALIDATION, api_version=self.api_version + ).validate(self.body) + + def execute(self, context) -> None: + self._validate_body_fields() + hook = CloudSQLHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + return hook.import_instance( + project_id=self.project_id, instance=self.instance, body=self.body + ) + + +class CloudSQLExecuteQueryOperator(BaseOperator): + """ + Performs DML or DDL query on an existing Cloud Sql instance. It optionally uses + cloud-sql-proxy to establish secure connection with the database. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudSQLExecuteQueryOperator` + + :param sql: SQL query or list of queries to run (should be DML or DDL query - + this operator does not return any data from the database, + so it is useless to pass it DQL queries. Note that it is responsibility of the + author of the queries to make sure that the queries are idempotent. For example + you can use CREATE TABLE IF NOT EXISTS to create a table. + :type sql: str or list[str] + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: dict or iterable + :param autocommit: if True, each command is automatically committed. + (default value: False) + :type autocommit: bool + :param gcp_conn_id: The connection ID used to connect to Google Cloud for + cloud-sql-proxy authentication. + :type gcp_conn_id: str + :param gcp_cloudsql_conn_id: The connection ID used to connect to Google Cloud SQL + its schema should be gcpcloudsql://. + See :class:`~airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook` for + details on how to define ``gcpcloudsql://`` connection. 
+ :type gcp_cloudsql_conn_id: str + """ + + # [START gcp_sql_query_template_fields] + template_fields = ("sql", "gcp_cloudsql_conn_id", "gcp_conn_id") + template_ext = (".sql",) + # [END gcp_sql_query_template_fields] + + @apply_defaults + def __init__( + self, + *, + sql: Union[List[str], str], + autocommit: bool = False, + parameters: Optional[Union[Dict, Iterable]] = None, + gcp_conn_id: str = "google_cloud_default", + gcp_cloudsql_conn_id: str = "google_cloud_sql_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sql = sql + self.gcp_conn_id = gcp_conn_id + self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id + self.autocommit = autocommit + self.parameters = parameters + self.gcp_connection = None + + def _execute_query( + self, hook: CloudSQLDatabaseHook, database_hook: Union[PostgresHook, MySqlHook] + ) -> None: + cloud_sql_proxy_runner = None + try: + if hook.use_proxy: + cloud_sql_proxy_runner = hook.get_sqlproxy_runner() + hook.free_reserved_port() + # There is very, very slim chance that the socket will + # be taken over here by another bind(0). + # It's quite unlikely to happen though! + cloud_sql_proxy_runner.start_proxy() + self.log.info('Executing: "%s"', self.sql) + database_hook.run(self.sql, self.autocommit, parameters=self.parameters) + finally: + if cloud_sql_proxy_runner: + cloud_sql_proxy_runner.stop_proxy() + + def execute(self, context): + self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id) + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id, + gcp_conn_id=self.gcp_conn_id, + default_gcp_project_id=self.gcp_connection.extra_dejson.get( + "extra__google_cloud_platform__project" + ), + ) + hook.validate_ssl_certs() + connection = hook.create_connection() + hook.validate_socket_path_length() + database_hook = hook.get_database_hook(connection=connection) + try: + self._execute_query(hook, database_hook) + finally: + hook.cleanup_database_hook() diff --git a/reference/providers/google/cloud/operators/cloud_storage_transfer_service.py b/reference/providers/google/cloud/operators/cloud_storage_transfer_service.py new file mode 100644 index 0000000..cc3a7c4 --- /dev/null +++ b/reference/providers/google/cloud/operators/cloud_storage_transfer_service.py @@ -0,0 +1,1121 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
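+#
+# Usage sketch (hypothetical; the task_id, project ID and bucket names below are
+# placeholders, not values defined in this repository). It shows how the
+# CloudDataTransferServiceCreateJobOperator defined further down in this module
+# might be configured: the body field names follow the transferJobs REST resource,
+# and the native ``datetime.date`` in ``schedule`` is converted to the API format
+# by TransferJobPreprocessor before the request is sent.
+#
+#   from datetime import date
+#
+#   create_transfer_job = CloudDataTransferServiceCreateJobOperator(
+#       task_id="create_transfer_job",
+#       body={
+#           "description": "nightly bucket-to-bucket copy",
+#           "status": "ENABLED",
+#           "projectId": "my-gcp-project",
+#           "schedule": {"scheduleStartDate": date(2022, 8, 1)},
+#           "transferSpec": {
+#               "gcsDataSource": {"bucketName": "my-source-bucket"},
+#               "gcsDataSink": {"bucketName": "my-destination-bucket"},
+#           },
+#       },
+#   )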
+# +"""This module contains Google Cloud Transfer operators.""" +from copy import deepcopy +from datetime import date, time +from typing import Dict, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( + ACCESS_KEY_ID, + AWS_ACCESS_KEY, + AWS_S3_DATA_SOURCE, + BUCKET_NAME, + DAY, + DESCRIPTION, + GCS_DATA_SINK, + GCS_DATA_SOURCE, + HOURS, + HTTP_DATA_SOURCE, + MINUTES, + MONTH, + NAME, + OBJECT_CONDITIONS, + PROJECT_ID, + SCHEDULE, + SCHEDULE_END_DATE, + SCHEDULE_START_DATE, + SECONDS, + SECRET_ACCESS_KEY, + START_TIME_OF_DAY, + STATUS, + TRANSFER_OPTIONS, + TRANSFER_SPEC, + YEAR, + CloudDataTransferServiceHook, + GcpTransferJobsStatus, +) +from airflow.utils.decorators import apply_defaults + + +class TransferJobPreprocessor: + """Helper class for preprocess of transfer job body.""" + + def __init__( + self, + body: dict, + aws_conn_id: str = "aws_default", + default_schedule: bool = False, + ) -> None: + self.body = body + self.aws_conn_id = aws_conn_id + self.default_schedule = default_schedule + + def _inject_aws_credentials(self) -> None: + if ( + TRANSFER_SPEC in self.body + and AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC] + ): + aws_hook = AwsBaseHook(self.aws_conn_id, resource_type="s3") + aws_credentials = aws_hook.get_credentials() + aws_access_key_id = aws_credentials.access_key # type: ignore[attr-defined] + aws_secret_access_key = aws_credentials.secret_key # type: ignore[attr-defined] + self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = { + ACCESS_KEY_ID: aws_access_key_id, + SECRET_ACCESS_KEY: aws_secret_access_key, + } + + def _reformat_date(self, field_key: str) -> None: + schedule = self.body[SCHEDULE] + if field_key not in schedule: + return + if isinstance(schedule[field_key], date): + schedule[field_key] = self._convert_date_to_dict(schedule[field_key]) + + def _reformat_time(self, field_key: str) -> None: + schedule = self.body[SCHEDULE] + if field_key not in schedule: + return + if isinstance(schedule[field_key], time): + schedule[field_key] = self._convert_time_to_dict(schedule[field_key]) + + def _reformat_schedule(self) -> None: + if SCHEDULE not in self.body: + if self.default_schedule: + self.body[SCHEDULE] = { + SCHEDULE_START_DATE: date.today(), + SCHEDULE_END_DATE: date.today(), + } + else: + return + self._reformat_date(SCHEDULE_START_DATE) + self._reformat_date(SCHEDULE_END_DATE) + self._reformat_time(START_TIME_OF_DAY) + + def process_body(self) -> dict: + """ + Injects AWS credentials into body if needed and + reformats schedule information. 
+ + :return: Preprocessed body + :rtype: dict + """ + self._inject_aws_credentials() + self._reformat_schedule() + return self.body + + @staticmethod + def _convert_date_to_dict(field_date: date) -> dict: + """Convert native python ``datetime.date`` object to a format supported by the API""" + return {DAY: field_date.day, MONTH: field_date.month, YEAR: field_date.year} + + @staticmethod + def _convert_time_to_dict(time_object: time) -> dict: + """Convert native python ``datetime.time`` object to a format supported by the API""" + return { + HOURS: time_object.hour, + MINUTES: time_object.minute, + SECONDS: time_object.second, + } + + +class TransferJobValidator: + """Helper class for validating transfer job body.""" + + def __init__(self, body: dict) -> None: + if not body: + raise AirflowException("The required parameter 'body' is empty or None") + + self.body = body + + def _verify_data_source(self) -> None: + is_gcs = GCS_DATA_SOURCE in self.body[TRANSFER_SPEC] + is_aws_s3 = AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC] + is_http = HTTP_DATA_SOURCE in self.body[TRANSFER_SPEC] + + sources_count = sum([is_gcs, is_aws_s3, is_http]) + if sources_count > 1: + raise AirflowException( + "More than one data source detected. Please choose exactly one data source from: " + "gcsDataSource, awsS3DataSource and httpDataSource." + ) + + def _restrict_aws_credentials(self) -> None: + aws_transfer = AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC] + if ( + aws_transfer + and AWS_ACCESS_KEY in self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE] + ): + raise AirflowException( + "AWS credentials detected inside the body parameter (awsAccessKey). This is not allowed, " + "please use Airflow connections to store credentials." + ) + + def validate_body(self) -> None: + """ + Validates the body. Checks if body specifies `transferSpec` + if yes, then check if AWS credentials are passed correctly and + no more than 1 data source was selected. + + :raises: AirflowException + """ + if TRANSFER_SPEC in self.body: + self._restrict_aws_credentials() + self._verify_data_source() + + +class CloudDataTransferServiceCreateJobOperator(BaseOperator): + """ + Creates a transfer job that runs periodically. + + .. warning:: + + This operator is NOT idempotent in the following cases: + + * `name` is not passed in body param + * transfer job `name` has been soft deleted. In this case, + each new task will receive a unique suffix + + If you run it many times, many transfer jobs will be created in the Google Cloud. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataTransferServiceCreateJobOperator` + + :param body: (Required) The request body, as described in + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs#TransferJob + With three additional improvements: + + * dates can be given in the form :class:`datetime.date` + * times can be given in the form :class:`datetime.time` + * credentials to Amazon Web Service should be stored in the connection and indicated by the + aws_conn_id parameter + + :type body: dict + :param aws_conn_id: The connection ID used to retrieve credentials to + Amazon Web Service. + :type aws_conn_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1). 
+ :type api_version: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_transfer_job_create_template_fields] + template_fields = ( + "body", + "gcp_conn_id", + "aws_conn_id", + "google_impersonation_chain", + ) + # [END gcp_transfer_job_create_template_fields] + + @apply_defaults + def __init__( + self, + *, + body: dict, + aws_conn_id: str = "aws_default", + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.body = deepcopy(body) + self.aws_conn_id = aws_conn_id + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.google_impersonation_chain = google_impersonation_chain + self._validate_inputs() + + def _validate_inputs(self) -> None: + TransferJobValidator(body=self.body).validate_body() + + def execute(self, context) -> dict: + TransferJobPreprocessor( + body=self.body, aws_conn_id=self.aws_conn_id + ).process_body() + hook = CloudDataTransferServiceHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ) + return hook.create_transfer_job(body=self.body) + + +class CloudDataTransferServiceUpdateJobOperator(BaseOperator): + """ + Updates a transfer job that runs periodically. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataTransferServiceUpdateJobOperator` + + :param job_name: (Required) Name of the job to be updated + :type job_name: str + :param body: (Required) The request body, as described in + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/patch#request-body + With three additional improvements: + + * dates can be given in the form :class:`datetime.date` + * times can be given in the form :class:`datetime.time` + * credentials to Amazon Web Service should be stored in the connection and indicated by the + aws_conn_id parameter + + :type body: dict + :param aws_conn_id: The connection ID used to retrieve credentials to + Amazon Web Service. + :type aws_conn_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1). + :type api_version: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_transfer_job_update_template_fields] + template_fields = ( + "job_name", + "body", + "gcp_conn_id", + "aws_conn_id", + "google_impersonation_chain", + ) + # [END gcp_transfer_job_update_template_fields] + + @apply_defaults + def __init__( + self, + *, + job_name: str, + body: dict, + aws_conn_id: str = "aws_default", + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_name = job_name + self.body = body + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.aws_conn_id = aws_conn_id + self.google_impersonation_chain = google_impersonation_chain + self._validate_inputs() + + def _validate_inputs(self) -> None: + TransferJobValidator(body=self.body).validate_body() + if not self.job_name: + raise AirflowException("The required parameter 'job_name' is empty or None") + + def execute(self, context) -> dict: + TransferJobPreprocessor( + body=self.body, aws_conn_id=self.aws_conn_id + ).process_body() + hook = CloudDataTransferServiceHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ) + return hook.update_transfer_job(job_name=self.job_name, body=self.body) + + +class CloudDataTransferServiceDeleteJobOperator(BaseOperator): + """ + Delete a transfer job. This is a soft delete. After a transfer job is + deleted, the job and all the transfer executions are subject to garbage + collection. Transfer jobs become eligible for garbage collection + 30 days after soft delete. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataTransferServiceDeleteJobOperator` + + :param job_name: (Required) Name of the TRANSFER operation + :type job_name: str + :param project_id: (Optional) the ID of the project that owns the Transfer + Job. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1). + :type api_version: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type google_impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_transfer_job_delete_template_fields] + template_fields = ( + "job_name", + "project_id", + "gcp_conn_id", + "api_version", + "google_impersonation_chain", + ) + # [END gcp_transfer_job_delete_template_fields] + + @apply_defaults + def __init__( + self, + *, + job_name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + project_id: Optional[str] = None, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_name = job_name + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.google_impersonation_chain = google_impersonation_chain + self._validate_inputs() + + def _validate_inputs(self) -> None: + if not self.job_name: + raise AirflowException("The required parameter 'job_name' is empty or None") + + def execute(self, context) -> None: + self._validate_inputs() + hook = CloudDataTransferServiceHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ) + hook.delete_transfer_job(job_name=self.job_name, project_id=self.project_id) + + +class CloudDataTransferServiceGetOperationOperator(BaseOperator): + """ + Gets the latest state of a long-running operation in Google Storage Transfer + Service. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataTransferServiceGetOperationOperator` + + :param operation_name: (Required) Name of the transfer operation. + :type operation_name: str + :param gcp_conn_id: The connection ID used to connect to Google + Cloud Platform. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1). + :type api_version: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type google_impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_transfer_operation_get_template_fields] + template_fields = ( + "operation_name", + "gcp_conn_id", + "google_impersonation_chain", + ) + # [END gcp_transfer_operation_get_template_fields] + + @apply_defaults + def __init__( + self, + *, + operation_name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.operation_name = operation_name + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.google_impersonation_chain = google_impersonation_chain + self._validate_inputs() + + def _validate_inputs(self) -> None: + if not self.operation_name: + raise AirflowException( + "The required parameter 'operation_name' is empty or None" + ) + + def execute(self, context) -> dict: + hook = CloudDataTransferServiceHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ) + operation = hook.get_transfer_operation(operation_name=self.operation_name) + return operation + + +class CloudDataTransferServiceListOperationsOperator(BaseOperator): + """ + Lists long-running operations in Google Storage Transfer + Service that match the specified filter. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataTransferServiceListOperationsOperator` + + :param request_filter: (Required) A request filter, as described in + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/list#body.QUERY_PARAMETERS.filter + :type request_filter: dict + :param gcp_conn_id: The connection ID used to connect to Google + Cloud Platform. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1). + :type api_version: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_transfer_operations_list_template_fields] + template_fields = ( + "filter", + "gcp_conn_id", + "google_impersonation_chain", + ) + # [END gcp_transfer_operations_list_template_fields] + + def __init__( + self, + request_filter: Optional[Dict] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + # To preserve backward compatibility + # TODO: remove one day + if request_filter is None: + if "filter" in kwargs: + request_filter = kwargs["filter"] + DeprecationWarning( + "Use 'request_filter' instead 'filter' to pass the argument." 
+ ) + else: + TypeError( + "__init__() missing 1 required positional argument: 'request_filter'" + ) + + super().__init__(**kwargs) + self.filter = request_filter + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.google_impersonation_chain = google_impersonation_chain + self._validate_inputs() + + def _validate_inputs(self) -> None: + if not self.filter: + raise AirflowException("The required parameter 'filter' is empty or None") + + def execute(self, context) -> List[dict]: + hook = CloudDataTransferServiceHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ) + operations_list = hook.list_transfer_operations(request_filter=self.filter) + self.log.info(operations_list) + return operations_list + + +class CloudDataTransferServicePauseOperationOperator(BaseOperator): + """ + Pauses a transfer operation in Google Storage Transfer Service. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataTransferServicePauseOperationOperator` + + :param operation_name: (Required) Name of the transfer operation. + :type operation_name: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1). + :type api_version: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_transfer_operation_pause_template_fields] + template_fields = ( + "operation_name", + "gcp_conn_id", + "api_version", + "google_impersonation_chain", + ) + # [END gcp_transfer_operation_pause_template_fields] + + @apply_defaults + def __init__( + self, + *, + operation_name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.operation_name = operation_name + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.google_impersonation_chain = google_impersonation_chain + self._validate_inputs() + + def _validate_inputs(self) -> None: + if not self.operation_name: + raise AirflowException( + "The required parameter 'operation_name' is empty or None" + ) + + def execute(self, context) -> None: + hook = CloudDataTransferServiceHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ) + hook.pause_transfer_operation(operation_name=self.operation_name) + + +class CloudDataTransferServiceResumeOperationOperator(BaseOperator): + """ + Resumes a transfer operation in Google Storage Transfer Service. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataTransferServiceResumeOperationOperator` + + :param operation_name: (Required) Name of the transfer operation. + :type operation_name: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :param api_version: API version used (e.g. v1). + :type api_version: str + :type gcp_conn_id: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_transfer_operation_resume_template_fields] + template_fields = ( + "operation_name", + "gcp_conn_id", + "api_version", + "google_impersonation_chain", + ) + # [END gcp_transfer_operation_resume_template_fields] + + @apply_defaults + def __init__( + self, + *, + operation_name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.operation_name = operation_name + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.google_impersonation_chain = google_impersonation_chain + self._validate_inputs() + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if not self.operation_name: + raise AirflowException( + "The required parameter 'operation_name' is empty or None" + ) + + def execute(self, context) -> None: + hook = CloudDataTransferServiceHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ) + hook.resume_transfer_operation(operation_name=self.operation_name) + + +class CloudDataTransferServiceCancelOperationOperator(BaseOperator): + """ + Cancels a transfer operation in Google Storage Transfer Service. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataTransferServiceCancelOperationOperator` + + :param operation_name: (Required) Name of the transfer operation. + :type operation_name: str + :param api_version: API version used (e.g. v1). + :type api_version: str + :param gcp_conn_id: The connection ID used to connect to Google + Cloud Platform. + :type gcp_conn_id: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type google_impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_transfer_operation_cancel_template_fields] + template_fields = ( + "operation_name", + "gcp_conn_id", + "api_version", + "google_impersonation_chain", + ) + # [END gcp_transfer_operation_cancel_template_fields] + + @apply_defaults + def __init__( + self, + *, + operation_name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.operation_name = operation_name + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.google_impersonation_chain = google_impersonation_chain + self._validate_inputs() + + def _validate_inputs(self) -> None: + if not self.operation_name: + raise AirflowException( + "The required parameter 'operation_name' is empty or None" + ) + + def execute(self, context) -> None: + hook = CloudDataTransferServiceHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ) + hook.cancel_transfer_operation(operation_name=self.operation_name) + + +class CloudDataTransferServiceS3ToGCSOperator(BaseOperator): + """ + Synchronizes an S3 bucket with a Google Cloud Storage bucket using the + Google Cloud Storage Transfer Service. + + .. warning:: + + This operator is NOT idempotent. If you run it many times, many transfer + jobs will be created in the Google Cloud. + + **Example**: + + .. code-block:: python + + s3_to_gcs_transfer_op = S3ToGoogleCloudStorageTransferOperator( + task_id='s3_to_gcs_transfer_example', + s3_bucket='my-s3-bucket', + project_id='my-gcp-project', + gcs_bucket='my-gcs-bucket', + dag=my_dag) + + :param s3_bucket: The S3 bucket where to find the objects. (templated) + :type s3_bucket: str + :param gcs_bucket: The destination Google Cloud Storage bucket + where you want to store the files. (templated) + :type gcs_bucket: str + :param project_id: Optional ID of the Google Cloud Console project that + owns the job + :type project_id: str + :param aws_conn_id: The source S3 connection + :type aws_conn_id: str + :param gcp_conn_id: The destination connection ID to use + when connecting to Google Cloud Storage. + :type gcp_conn_id: str + :param delegate_to: Google account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param description: Optional transfer service job description + :type description: str + :param schedule: Optional transfer service schedule; + If not set, run transfer job once as soon as the operator runs + The format is described + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs. + With two additional improvements: + + * dates they can be passed as :class:`datetime.date` + * times they can be passed as :class:`datetime.time` + + :type schedule: dict + :param object_conditions: Optional transfer service object conditions; see + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec + :type object_conditions: dict + :param transfer_options: Optional transfer service transfer options; see + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec + :type transfer_options: dict + :param wait: Wait for transfer to finish. It must be set to True, if + 'delete_job_after_completion' is set to True. 
+ :type wait: bool + :param timeout: Time to wait for the operation to end in seconds. Defaults to 60 seconds if not specified. + :type timeout: Optional[Union[float, timedelta]] + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + :param delete_job_after_completion: If True, delete the job after complete. + If set to True, 'wait' must be set to True. + :type delete_job_after_completion: bool + """ + + template_fields = ( + "gcp_conn_id", + "s3_bucket", + "gcs_bucket", + "description", + "object_conditions", + "google_impersonation_chain", + ) + ui_color = "#e09411" + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + s3_bucket: str, + gcs_bucket: str, + project_id: Optional[str] = None, + aws_conn_id: str = "aws_default", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + description: Optional[str] = None, + schedule: Optional[Dict] = None, + object_conditions: Optional[Dict] = None, + transfer_options: Optional[Dict] = None, + wait: bool = True, + timeout: Optional[float] = None, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delete_job_after_completion: bool = False, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.s3_bucket = s3_bucket + self.gcs_bucket = gcs_bucket + self.project_id = project_id + self.aws_conn_id = aws_conn_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.description = description + self.schedule = schedule + self.object_conditions = object_conditions + self.transfer_options = transfer_options + self.wait = wait + self.timeout = timeout + self.google_impersonation_chain = google_impersonation_chain + self.delete_job_after_completion = delete_job_after_completion + self._validate_inputs() + + def _validate_inputs(self) -> None: + if self.delete_job_after_completion and not self.wait: + raise AirflowException( + "If 'delete_job_after_completion' is True, then 'wait' must also be True." 
+            )
+
+    def execute(self, context) -> None:
+        hook = CloudDataTransferServiceHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            impersonation_chain=self.google_impersonation_chain,
+        )
+        body = self._create_body()
+
+        TransferJobPreprocessor(
+            body=body, aws_conn_id=self.aws_conn_id, default_schedule=True
+        ).process_body()
+
+        job = hook.create_transfer_job(body=body)
+
+        if self.wait:
+            hook.wait_for_transfer_job(job, timeout=self.timeout)
+            if self.delete_job_after_completion:
+                hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)
+
+    def _create_body(self) -> dict:
+        body = {
+            DESCRIPTION: self.description,
+            STATUS: GcpTransferJobsStatus.ENABLED,
+            TRANSFER_SPEC: {
+                AWS_S3_DATA_SOURCE: {BUCKET_NAME: self.s3_bucket},
+                GCS_DATA_SINK: {BUCKET_NAME: self.gcs_bucket},
+            },
+        }
+
+        if self.project_id is not None:
+            body[PROJECT_ID] = self.project_id
+
+        if self.schedule is not None:
+            body[SCHEDULE] = self.schedule
+
+        if self.object_conditions is not None:
+            body[TRANSFER_SPEC][OBJECT_CONDITIONS] = self.object_conditions  # type: ignore[index]
+
+        if self.transfer_options is not None:
+            body[TRANSFER_SPEC][TRANSFER_OPTIONS] = self.transfer_options  # type: ignore[index]
+
+        return body
+
+
+class CloudDataTransferServiceGCSToGCSOperator(BaseOperator):
+    """
+    Copies objects from a bucket to another using the Google Cloud Storage Transfer Service.
+
+    .. warning::
+
+        This operator is NOT idempotent. If you run it many times, many transfer
+        jobs will be created in the Google Cloud.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:GCSToGCSOperator`
+
+    **Example**:
+
+    .. code-block:: python
+
+        gcs_to_gcs_transfer_op = GoogleCloudStorageToGoogleCloudStorageTransferOperator(
+            task_id='gcs_to_gcs_transfer_example',
+            source_bucket='my-source-bucket',
+            destination_bucket='my-destination-bucket',
+            project_id='my-gcp-project',
+            dag=my_dag)
+
+    :param source_bucket: The source Google Cloud Storage bucket where the
+        object is. (templated)
+    :type source_bucket: str
+    :param destination_bucket: The destination Google Cloud Storage bucket
+        where the object should be. (templated)
+    :type destination_bucket: str
+    :param project_id: The ID of the Google Cloud Console project that
+        owns the job
+    :type project_id: str
+    :param gcp_conn_id: Optional connection ID to use when connecting to Google Cloud
+        Storage.
+    :type gcp_conn_id: str
+    :param delegate_to: Google account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :param description: Optional transfer service job description
+    :type description: str
+    :param schedule: Optional transfer service schedule;
+        If not set, run transfer job once as soon as the operator runs
+        See:
+        https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs.
+ With two additional improvements: + + * dates they can be passed as :class:`datetime.date` + * times they can be passed as :class:`datetime.time` + + :type schedule: dict + :param object_conditions: Optional transfer service object conditions; see + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec#ObjectConditions + :type object_conditions: dict + :param transfer_options: Optional transfer service transfer options; see + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec#TransferOptions + :type transfer_options: dict + :param wait: Wait for transfer to finish. It must be set to True, if + 'delete_job_after_completion' is set to True. + :type wait: bool + :param timeout: Time to wait for the operation to end in seconds. Defaults to 60 seconds if not specified. + :type timeout: Optional[Union[float, timedelta]] + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + :param delete_job_after_completion: If True, delete the job after complete. + If set to True, 'wait' must be set to True. + :type delete_job_after_completion: bool + """ + + template_fields = ( + "gcp_conn_id", + "source_bucket", + "destination_bucket", + "description", + "object_conditions", + "google_impersonation_chain", + ) + ui_color = "#e09411" + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + source_bucket: str, + destination_bucket: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + description: Optional[str] = None, + schedule: Optional[Dict] = None, + object_conditions: Optional[Dict] = None, + transfer_options: Optional[Dict] = None, + wait: bool = True, + timeout: Optional[float] = None, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delete_job_after_completion: bool = False, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.source_bucket = source_bucket + self.destination_bucket = destination_bucket + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.description = description + self.schedule = schedule + self.object_conditions = object_conditions + self.transfer_options = transfer_options + self.wait = wait + self.timeout = timeout + self.google_impersonation_chain = google_impersonation_chain + self.delete_job_after_completion = delete_job_after_completion + self._validate_inputs() + + def _validate_inputs(self) -> None: + if self.delete_job_after_completion and not self.wait: + raise AirflowException( + "If 'delete_job_after_completion' is True, then 'wait' must also be True." 
+            )
+
+    def execute(self, context) -> None:
+        hook = CloudDataTransferServiceHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            impersonation_chain=self.google_impersonation_chain,
+        )
+
+        body = self._create_body()
+
+        TransferJobPreprocessor(body=body, default_schedule=True).process_body()
+
+        job = hook.create_transfer_job(body=body)
+
+        if self.wait:
+            hook.wait_for_transfer_job(job, timeout=self.timeout)
+            if self.delete_job_after_completion:
+                hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)
+
+    def _create_body(self) -> dict:
+        body = {
+            DESCRIPTION: self.description,
+            STATUS: GcpTransferJobsStatus.ENABLED,
+            TRANSFER_SPEC: {
+                GCS_DATA_SOURCE: {BUCKET_NAME: self.source_bucket},
+                GCS_DATA_SINK: {BUCKET_NAME: self.destination_bucket},
+            },
+        }
+
+        if self.project_id is not None:
+            body[PROJECT_ID] = self.project_id
+
+        if self.schedule is not None:
+            body[SCHEDULE] = self.schedule
+
+        if self.object_conditions is not None:
+            body[TRANSFER_SPEC][OBJECT_CONDITIONS] = self.object_conditions  # type: ignore[index]
+
+        if self.transfer_options is not None:
+            body[TRANSFER_SPEC][TRANSFER_OPTIONS] = self.transfer_options  # type: ignore[index]
+
+        return body
diff --git a/reference/providers/google/cloud/operators/compute.py b/reference/providers/google/cloud/operators/compute.py
new file mode 100644
index 0000000..59c4a8f
--- /dev/null
+++ b/reference/providers/google/cloud/operators/compute.py
@@ -0,0 +1,676 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
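+#
+# Usage sketch (hypothetical; zone, instance name, project ID and task ids are
+# placeholders): starting and then stopping a Compute Engine instance with the
+# operators defined below, chained so the stop task runs after the start task.
+#
+#   start_instance = ComputeEngineStartInstanceOperator(
+#       task_id="gce_start",
+#       zone="us-central1-a",
+#       resource_id="my-instance",
+#       project_id="my-gcp-project",
+#   )
+#   stop_instance = ComputeEngineStopInstanceOperator(
+#       task_id="gce_stop",
+#       zone="us-central1-a",
+#       resource_id="my-instance",
+#       project_id="my-gcp-project",
+#   )
+#   start_instance >> stop_instance
+#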
+"""This module contains Google Compute Engine operators.""" + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook +from airflow.providers.google.cloud.utils.field_sanitizer import GcpBodyFieldSanitizer +from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator +from airflow.utils.decorators import apply_defaults +from googleapiclient.errors import HttpError +from json_merge_patch import merge + + +class ComputeEngineBaseOperator(BaseOperator): + """Abstract base operator for Google Compute Engine operators to inherit from.""" + + @apply_defaults + def __init__( + self, + *, + zone: str, + resource_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.project_id = project_id + self.zone = zone + self.resource_id = resource_id + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.impersonation_chain = impersonation_chain + self._validate_inputs() + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if self.project_id == "": + raise AirflowException("The required parameter 'project_id' is missing") + if not self.zone: + raise AirflowException("The required parameter 'zone' is missing") + if not self.resource_id: + raise AirflowException("The required parameter 'resource_id' is missing") + + def execute(self, context): + pass + + +class ComputeEngineStartInstanceOperator(ComputeEngineBaseOperator): + """ + Starts an instance in Google Compute Engine. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineStartInstanceOperator` + + :param zone: Google Cloud zone where the instance exists. + :type zone: str + :param resource_id: Name of the Compute Engine instance resource. + :type resource_id: str + :param project_id: Optional, Google Cloud Project ID where the Compute + Engine Instance exists. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param api_version: Optional, API version used (for example v1 - or beta). Defaults + to v1. + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gce_instance_start_template_fields] + template_fields = ( + "project_id", + "zone", + "resource_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_instance_start_template_fields] + + @apply_defaults + def __init__( + self, + *, + zone: str, + resource_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__( + project_id=project_id, + zone=zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def execute(self, context) -> None: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + return hook.start_instance( + zone=self.zone, resource_id=self.resource_id, project_id=self.project_id + ) + + +class ComputeEngineStopInstanceOperator(ComputeEngineBaseOperator): + """ + Stops an instance in Google Compute Engine. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineStopInstanceOperator` + + :param zone: Google Cloud zone where the instance exists. + :type zone: str + :param resource_id: Name of the Compute Engine instance resource. + :type resource_id: str + :param project_id: Optional, Google Cloud Project ID where the Compute + Engine Instance exists. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param api_version: Optional, API version used (for example v1 - or beta). Defaults + to v1. + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gce_instance_stop_template_fields] + template_fields = ( + "project_id", + "zone", + "resource_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_instance_stop_template_fields] + + @apply_defaults + def __init__( + self, + *, + zone: str, + resource_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__( + project_id=project_id, + zone=zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def execute(self, context) -> None: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + hook.stop_instance( + zone=self.zone, resource_id=self.resource_id, project_id=self.project_id + ) + + +SET_MACHINE_TYPE_VALIDATION_SPECIFICATION = [ + dict(name="machineType", regexp="^.+$"), +] + + +class ComputeEngineSetMachineTypeOperator(ComputeEngineBaseOperator): + """ + Changes the machine type for a stopped instance to the machine type specified in + the request. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineSetMachineTypeOperator` + + :param zone: Google Cloud zone where the instance exists. + :type zone: str + :param resource_id: Name of the Compute Engine instance resource. + :type resource_id: str + :param body: Body required by the Compute Engine setMachineType API, as described in + https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMachineType#request-body + :type body: dict + :param project_id: Optional, Google Cloud Project ID where the Compute + Engine Instance exists. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param api_version: Optional, API version used (for example v1 - or beta). Defaults + to v1. + :type api_version: str + :param validate_body: Optional, If set to False, body validation is not performed. + Defaults to False. + :type validate_body: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
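For the set-machine-type operator above, body is the raw setMachineType request body. A short sketch with placeholder values (usable inside a DAG context like the earlier one); the zonal machine-type path is an assumption based on the Compute Engine API documentation linked in the docstring.

from airflow.providers.google.cloud.operators.compute import (
    ComputeEngineSetMachineTypeOperator,
)

gce_set_machine_type = ComputeEngineSetMachineTypeOperator(
    task_id="gce_set_machine_type",
    zone="europe-west1-b",
    resource_id="example-instance",  # the instance must already be stopped
    body={"machineType": "zones/europe-west1-b/machineTypes/n1-standard-2"},
    # validate_body is True by default; the validator only checks that
    # 'machineType' is present and non-empty.
)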
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gce_instance_set_machine_type_template_fields] + template_fields = ( + "project_id", + "zone", + "resource_id", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_instance_set_machine_type_template_fields] + + @apply_defaults + def __init__( + self, + *, + zone: str, + resource_id: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.body = body + self._field_validator = None # type: Optional[GcpBodyFieldValidator] + if validate_body: + self._field_validator = GcpBodyFieldValidator( + SET_MACHINE_TYPE_VALIDATION_SPECIFICATION, api_version=api_version + ) + super().__init__( + project_id=project_id, + zone=zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_all_body_fields(self) -> None: + if self._field_validator: + self._field_validator.validate(self.body) + + def execute(self, context) -> None: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self._validate_all_body_fields() + return hook.set_machine_type( + zone=self.zone, + resource_id=self.resource_id, + body=self.body, + project_id=self.project_id, + ) + + +GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION = [ + dict(name="name", regexp="^.+$"), + dict(name="description", optional=True), + dict( + name="properties", + type="dict", + optional=True, + fields=[ + dict(name="description", optional=True), + dict( + name="tags", optional=True, fields=[dict(name="items", optional=True)] + ), + dict(name="machineType", optional=True), + dict(name="canIpForward", optional=True), + dict(name="networkInterfaces", optional=True), # not validating deeper + dict(name="disks", optional=True), # not validating the array deeper + dict( + name="metadata", + optional=True, + fields=[ + dict(name="fingerprint", optional=True), + dict(name="items", optional=True), + dict(name="kind", optional=True), + ], + ), + dict(name="serviceAccounts", optional=True), # not validating deeper + dict( + name="scheduling", + optional=True, + fields=[ + dict(name="onHostMaintenance", optional=True), + dict(name="automaticRestart", optional=True), + dict(name="preemptible", optional=True), + dict(name="nodeAffinities", optional=True), # not validating deeper + ], + ), + dict(name="labels", optional=True), + dict(name="guestAccelerators", optional=True), # not validating deeper + dict(name="minCpuPlatform", optional=True), + ], + ), +] # type: List[Dict[str, Any]] + +GCE_INSTANCE_TEMPLATE_FIELDS_TO_SANITIZE = [ + "kind", + "id", + "name", + "creationTimestamp", + "properties.disks.sha256", + "properties.disks.kind", + "properties.disks.sourceImageEncryptionKey.sha256", + "properties.disks.index", + "properties.disks.licenses", + "properties.networkInterfaces.kind", + "properties.networkInterfaces.accessConfigs.kind", + "properties.networkInterfaces.name", + "properties.metadata.kind", + "selfLink", +] + + +class ComputeEngineCopyInstanceTemplateOperator(ComputeEngineBaseOperator): + """ + Copies the instance template, applying specified changes. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineCopyInstanceTemplateOperator` + + :param resource_id: Name of the Instance Template + :type resource_id: str + :param body_patch: Patch to the body of instanceTemplates object following rfc7386 + PATCH semantics. The body_patch content follows + https://cloud.google.com/compute/docs/reference/rest/v1/instanceTemplates + Name field is required as we need to rename the template, + all the other fields are optional. It is important to follow PATCH semantics + - arrays are replaced fully, so if you need to update an array you should + provide the whole target array as patch element. + :type body_patch: dict + :param project_id: Optional, Google Cloud Project ID where the Compute + Engine Instance exists. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :param request_id: Optional, unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again). + It should be in UUID format as defined in RFC 4122. + :type request_id: str + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param api_version: Optional, API version used (for example v1 - or beta). Defaults + to v1. + :type api_version: str + :param validate_body: Optional, If set to False, body validation is not performed. + Defaults to False. + :type validate_body: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
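A sketch of the copy-template operator described above, with placeholder names: body_patch must contain at least 'name', and under PATCH semantics arrays are replaced wholesale, not merged.

from airflow.providers.google.cloud.operators.compute import (
    ComputeEngineCopyInstanceTemplateOperator,
)

gce_copy_template = ComputeEngineCopyInstanceTemplateOperator(
    task_id="gce_copy_template",
    resource_id="example-template",        # existing template to copy
    body_patch={
        "name": "example-template-v2",     # required: name of the new template
        "properties": {"machineType": "n1-standard-2"},
    },
    # request_id may be set to a UUID to make retries of the insert call idempotent
)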
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gce_instance_template_copy_operator_template_fields] + template_fields = ( + "project_id", + "resource_id", + "request_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_instance_template_copy_operator_template_fields] + + @apply_defaults + def __init__( + self, + *, + resource_id: str, + body_patch: dict, + project_id: Optional[str] = None, + request_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.body_patch = body_patch + self.request_id = request_id + self._field_validator = None # Optional[GcpBodyFieldValidator] + if "name" not in self.body_patch: + raise AirflowException( + "The body '{}' should contain at least " + "name for the new operator in the 'name' field".format(body_patch) + ) + if validate_body: + self._field_validator = GcpBodyFieldValidator( + GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, + api_version=api_version, + ) + self._field_sanitizer = GcpBodyFieldSanitizer( + GCE_INSTANCE_TEMPLATE_FIELDS_TO_SANITIZE + ) + super().__init__( + project_id=project_id, + zone="global", + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_all_body_fields(self) -> None: + if self._field_validator: + self._field_validator.validate(self.body_patch) + + def execute(self, context) -> dict: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self._validate_all_body_fields() + try: + # Idempotence check (sort of) - we want to check if the new template + # is already created and if is, then we assume it was created by previous run + # of CopyTemplate operator - we do not check if content of the template + # is as expected. Templates are immutable so we cannot update it anyway + # and deleting/recreating is not worth the hassle especially + # that we cannot delete template if it is already used in some Instance + # Group Manager. We assume success if the template is simply present + existing_template = hook.get_instance_template( + resource_id=self.body_patch["name"], project_id=self.project_id + ) + self.log.info( + "The %s template already existed. It was likely created by previous run of the operator. " + "Assuming success.", + existing_template, + ) + return existing_template + except HttpError as e: + # We actually expect to get 404 / Not Found here as the template should + # not yet exist + if not e.resp.status == 404: + raise e + old_body = hook.get_instance_template( + resource_id=self.resource_id, project_id=self.project_id + ) + new_body = deepcopy(old_body) + self._field_sanitizer.sanitize(new_body) + new_body = merge(new_body, self.body_patch) + self.log.info( + "Calling insert instance template with updated body: %s", new_body + ) + hook.insert_instance_template( + body=new_body, request_id=self.request_id, project_id=self.project_id + ) + return hook.get_instance_template( + resource_id=self.body_patch["name"], project_id=self.project_id + ) + + +class ComputeEngineInstanceGroupUpdateManagerTemplateOperator( + ComputeEngineBaseOperator +): + """ + Patches the Instance Group Manager, replacing source template URL with the + destination one. 
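A sketch of the instance-group-manager patch operator introduced here, with hypothetical template self-links; as noted below, api_version='v1' is rejected and 'beta' is the default.

from airflow.providers.google.cloud.operators.compute import (
    ComputeEngineInstanceGroupUpdateManagerTemplateOperator,
)

gce_igm_update = ComputeEngineInstanceGroupUpdateManagerTemplateOperator(
    task_id="gce_igm_update_template",
    zone="europe-west1-b",
    resource_id="example-instance-group",
    source_template=(
        "https://www.googleapis.com/compute/beta/projects/example-project"
        "/global/instanceTemplates/example-template-v1"
    ),
    destination_template=(
        "https://www.googleapis.com/compute/beta/projects/example-project"
        "/global/instanceTemplates/example-template-v2"
    ),
    # api_version defaults to 'beta'; passing 'v1' raises AirflowException
)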
API V1 does not have update/patch operations for Instance + Group Manager, so you must use beta or newer API version. Beta is the default. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineInstanceGroupUpdateManagerTemplateOperator` + + :param resource_id: Name of the Instance Group Manager + :type resource_id: str + :param zone: Google Cloud zone where the Instance Group Manager exists. + :type zone: str + :param source_template: URL of the template to replace. + :type source_template: str + :param destination_template: URL of the target template. + :type destination_template: str + :param project_id: Optional, Google Cloud Project ID where the Compute + Engine Instance exists. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :param request_id: Optional, unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again). + It should be in UUID format as defined in RFC 4122. + :type request_id: str + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param api_version: Optional, API version used (for example v1 - or beta). Defaults + to v1. + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gce_igm_update_template_operator_template_fields] + template_fields = ( + "project_id", + "resource_id", + "zone", + "request_id", + "source_template", + "destination_template", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_igm_update_template_operator_template_fields] + + @apply_defaults + def __init__( + self, + *, + resource_id: str, + zone: str, + source_template: str, + destination_template: str, + project_id: Optional[str] = None, + update_policy: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version="beta", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.zone = zone + self.source_template = source_template + self.destination_template = destination_template + self.request_id = request_id + self.update_policy = update_policy + self._change_performed = False + if api_version == "v1": + raise AirflowException( + "Api version v1 does not have update/patch " + "operations for Instance Group Managers. 
Use beta" + " api version or above" + ) + super().__init__( + project_id=project_id, + zone=self.zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _possibly_replace_template(self, dictionary: dict) -> None: + if dictionary.get("instanceTemplate") == self.source_template: + dictionary["instanceTemplate"] = self.destination_template + self._change_performed = True + + def execute(self, context) -> Optional[bool]: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + old_instance_group_manager = hook.get_instance_group_manager( + zone=self.zone, resource_id=self.resource_id, project_id=self.project_id + ) + patch_body = {} + if "versions" in old_instance_group_manager: + patch_body["versions"] = old_instance_group_manager["versions"] + if "instanceTemplate" in old_instance_group_manager: + patch_body["instanceTemplate"] = old_instance_group_manager[ + "instanceTemplate" + ] + if self.update_policy: + patch_body["updatePolicy"] = self.update_policy + self._possibly_replace_template(patch_body) + if "versions" in patch_body: + for version in patch_body["versions"]: + self._possibly_replace_template(version) + if self._change_performed or self.update_policy: + self.log.info( + "Calling patch instance template with updated body: %s", patch_body + ) + return hook.patch_instance_group_manager( + zone=self.zone, + resource_id=self.resource_id, + body=patch_body, + request_id=self.request_id, + project_id=self.project_id, + ) + else: + # Idempotence achieved + return True diff --git a/reference/providers/google/cloud/operators/datacatalog.py b/reference/providers/google/cloud/operators/datacatalog.py new file mode 100644 index 0000000..b8eba28 --- /dev/null +++ b/reference/providers/google/cloud/operators/datacatalog.py @@ -0,0 +1,2274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, Optional, Sequence, Tuple, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook +from airflow.utils.decorators import apply_defaults +from google.api_core.exceptions import AlreadyExists, NotFound +from google.api_core.retry import Retry +from google.cloud.datacatalog_v1beta1 import DataCatalogClient, SearchCatalogResult +from google.cloud.datacatalog_v1beta1.types import ( + Entry, + EntryGroup, + SearchCatalogRequest, + Tag, + TagTemplate, + TagTemplateField, +) +from google.protobuf.field_mask_pb2 import FieldMask + + +class CloudDataCatalogCreateEntryOperator(BaseOperator): + """ + Creates an entry. + + Currently only entries of 'FILESET' type can be created. 
+ + The newly created entry ID are saved under the ``entry_id`` key in XCOM. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogCreateEntryOperator` + + :param location: Required. The location of the entry to create. + :type location: str + :param entry_group: Required. Entry group ID under which the entry is created. + :type entry_group: str + :param entry_id: Required. The id of the entry to create. + :type entry_id: str + :param entry: Required. The entry to create. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.Entry` + :type entry: Union[Dict, google.cloud.datacatalog_v1beta1.types.Entry] + :param project_id: The ID of the Google Cloud project that owns the entry. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If set to ``None`` or missing, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "entry_group", + "entry_id", + "entry", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + entry_group: str, + entry_id: str, + entry: Union[Dict, Entry], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.entry_group = entry_group + self.entry_id = entry_id + self.entry = entry + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + result = hook.create_entry( + location=self.location, + entry_group=self.entry_group, + entry_id=self.entry_id, + entry=self.entry, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + self.log.info("Entry already exists. Skipping create operation.") + result = hook.get_entry( + location=self.location, + entry_group=self.entry_group, + entry=self.entry_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + _, _, entry_id = result.name.rpartition("/") + self.log.info("Current entry_id ID: %s", entry_id) + context["task_instance"].xcom_push(key="entry_id", value=entry_id) + return Entry.to_dict(result) + + +class CloudDataCatalogCreateEntryGroupOperator(BaseOperator): + """ + Creates an EntryGroup. + + The newly created entry group ID are saved under the ``entry_group_id`` key in XCOM. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogCreateEntryGroupOperator` + + :param location: Required. The location of the entry group to create. + :type location: str + :param entry_group_id: Required. The id of the entry group to create. The id must begin with a letter + or underscore, contain only English letters, numbers and underscores, and be at most 64 + characters. + :type entry_group_id: str + :param entry_group: The entry group to create. Defaults to an empty entry group. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.EntryGroup` + :type entry_group: Union[Dict, google.cloud.datacatalog_v1beta1.types.EntryGroup] + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
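A minimal sketch of the entry-group creation operator described above, with placeholder ids; entry_group can be as small as a display name.

from airflow.providers.google.cloud.operators.datacatalog import (
    CloudDataCatalogCreateEntryGroupOperator,
)

create_entry_group = CloudDataCatalogCreateEntryGroupOperator(
    task_id="create_entry_group",
    location="us-central1",
    entry_group_id="example_entry_group",   # letters, digits, underscores; at most 64 chars
    entry_group={"display_name": "Example entry group"},
)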
+ :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "entry_group_id", + "entry_group", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + entry_group_id: str, + entry_group: Union[Dict, EntryGroup], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.entry_group_id = entry_group_id + self.entry_group = entry_group + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + result = hook.create_entry_group( + location=self.location, + entry_group_id=self.entry_group_id, + entry_group=self.entry_group, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + self.log.info("Entry already exists. Skipping create operation.") + result = hook.get_entry_group( + location=self.location, + entry_group=self.entry_group_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + _, _, entry_group_id = result.name.rpartition("/") + self.log.info("Current entry group ID: %s", entry_group_id) + context["task_instance"].xcom_push(key="entry_group_id", value=entry_group_id) + return EntryGroup.to_dict(result) + + +class CloudDataCatalogCreateTagOperator(BaseOperator): + """ + Creates a tag on an entry. + + The newly created tag ID are saved under the ``tag_id`` key in XCOM. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogCreateTagOperator` + + :param location: Required. The location of the tag to create. + :type location: str + :param entry_group: Required. Entry group ID under which the tag is created. + :type entry_group: str + :param entry: Required. Entry group ID under which the tag is created. + :type entry: str + :param tag: Required. The tag to create. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.Tag` + :type tag: Union[Dict, google.cloud.datacatalog_v1beta1.types.Tag] + :param template_id: Required. 
Template ID used to create tag + :type template_id: Optional[str] + :param project_id: The ID of the Google Cloud project that owns the tag. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "entry_group", + "entry", + "tag", + "template_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + location: str, + entry_group: str, + entry: str, + tag: Union[Dict, Tag], + template_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.entry_group = entry_group + self.entry = entry + self.tag = tag + self.template_id = template_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + tag = hook.create_tag( + location=self.location, + entry_group=self.entry_group, + entry=self.entry, + tag=self.tag, + template_id=self.template_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + self.log.info("Tag already exists. 
Skipping create operation.") + if self.template_id: + template_name = DataCatalogClient.tag_template_path( + self.project_id or hook.project_id, self.location, self.template_id + ) + else: + if isinstance(self.tag, Tag): + template_name = self.tag.template + else: + template_name = self.tag["template"] + + tag = hook.get_tag_for_template_name( + location=self.location, + entry_group=self.entry_group, + template_name=template_name, + entry=self.entry, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + _, _, tag_id = tag.name.rpartition("/") + self.log.info("Current Tag ID: %s", tag_id) + context["task_instance"].xcom_push(key="tag_id", value=tag_id) + return Tag.to_dict(tag) + + +class CloudDataCatalogCreateTagTemplateOperator(BaseOperator): + """ + Creates a tag template. + + The newly created tag template are saved under the ``tag_template_id`` key in XCOM. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogCreateTagTemplateOperator` + + :param location: Required. The location of the tag template to create. + :type location: str + :param tag_template_id: Required. The id of the tag template to create. + :type tag_template_id: str + :param tag_template: Required. The tag template to create. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.TagTemplate` + :type tag_template: Union[Dict, google.cloud.datacatalog_v1beta1.types.TagTemplate] + :param project_id: The ID of the Google Cloud project that owns the tag template. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
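For the CloudDataCatalogCreateTagOperator defined just above, the tag dict maps field ids of an existing tag template to values. A sketch with placeholder names; when template_id is given, the operator resolves it to the full template path.

from airflow.providers.google.cloud.operators.datacatalog import (
    CloudDataCatalogCreateTagOperator,
)

create_tag = CloudDataCatalogCreateTagOperator(
    task_id="create_tag",
    location="us-central1",
    entry_group="example_entry_group",
    entry="example_entry",
    template_id="example_tag_template",
    tag={"fields": {"example_field": {"string_value": "example-value"}}},
)

The created tag id is pushed to XCom under the 'tag_id' key, whether the tag was newly created or already existed.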
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "tag_template_id", + "tag_template", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + tag_template_id: str, + tag_template: Union[Dict, TagTemplate], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.tag_template_id = tag_template_id + self.tag_template = tag_template + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + result = hook.create_tag_template( + location=self.location, + tag_template_id=self.tag_template_id, + tag_template=self.tag_template, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + self.log.info("Tag Template already exists. Skipping create operation.") + result = hook.get_tag_template( + location=self.location, + tag_template=self.tag_template_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + _, _, tag_template = result.name.rpartition("/") + self.log.info("Current Tag ID: %s", tag_template) + context["task_instance"].xcom_push(key="tag_template_id", value=tag_template) + return TagTemplate.to_dict(result) + + +class CloudDataCatalogCreateTagTemplateFieldOperator(BaseOperator): + r""" + Creates a field in a tag template. + + The newly created tag template field are saved under the ``tag_template_field_id`` key in XCOM. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogCreateTagTemplateFieldOperator` + + :param location: Required. The location of the tag template field to create. + :type location: str + :param tag_template: Required. The id of the tag template to create. + :type tag_template: str + :param tag_template_field_id: Required. The ID of the tag template field to create. Field ids can + contain letters (both uppercase and lowercase), numbers (0-9), underscores (\_) and dashes (-). + Field IDs must be at least 1 character long and at most 128 characters long. Field IDs must also + be unique within their template. + :type tag_template_field_id: str + :param tag_template_field: Required. The tag template field to create. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.TagTemplateField` + :type tag_template_field: Union[Dict, google.cloud.datacatalog_v1beta1.types.TagTemplateField] + :param project_id: The ID of the Google Cloud project that owns the tag template field. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. 
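A sketch of the CloudDataCatalogCreateTagTemplateOperator above, with placeholder ids. The nested field definitions follow the TagTemplate and TagTemplateField messages; the type_ spelling of the field-type key assumes the proto-plus classes imported at the top of this module.

from airflow.providers.google.cloud.operators.datacatalog import (
    CloudDataCatalogCreateTagTemplateOperator,
)

create_tag_template = CloudDataCatalogCreateTagTemplateOperator(
    task_id="create_tag_template",
    location="us-central1",
    tag_template_id="example_tag_template",
    tag_template={
        "display_name": "Example tag template",
        "fields": {
            "example_field": {
                "display_name": "Example field",
                "type_": {"primitive_type": "STRING"},
            }
        },
    },
)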
+ :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "tag_template", + "tag_template_field_id", + "tag_template_field", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + tag_template: str, + tag_template_field_id: str, + tag_template_field: Union[Dict, TagTemplateField], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.tag_template = tag_template + self.tag_template_field_id = tag_template_field_id + self.tag_template_field = tag_template_field + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + result = hook.create_tag_template_field( + location=self.location, + tag_template=self.tag_template, + tag_template_field_id=self.tag_template_field_id, + tag_template_field=self.tag_template_field, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + self.log.info( + "Tag template field already exists. Skipping create operation." + ) + tag_template = hook.get_tag_template( + location=self.location, + tag_template=self.tag_template, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = tag_template.fields[self.tag_template_field_id] + + self.log.info("Current Tag ID: %s", self.tag_template_field_id) + context["task_instance"].xcom_push( + key="tag_template_field_id", value=self.tag_template_field_id + ) + return TagTemplateField.to_dict(result) + + +class CloudDataCatalogDeleteEntryOperator(BaseOperator): + """ + Deletes an existing entry. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogDeleteEntryOperator` + + :param location: Required. 
The location of the entry to delete. + :type location: str + :param entry_group: Required. Entry group ID for entries that is deleted. + :type entry_group: str + :param entry: Entry ID that is deleted. + :type entry: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "entry_group", + "entry", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + entry_group: str, + entry: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.entry_group = entry_group + self.entry = entry + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + hook.delete_entry( + location=self.location, + entry_group=self.entry_group, + entry=self.entry, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.info("Entry doesn't exists. Skipping.") + + +class CloudDataCatalogDeleteEntryGroupOperator(BaseOperator): + """ + Deletes an EntryGroup. + + Only entry groups that do not contain entries can be deleted. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogDeleteEntryGroupOperator` + + :param location: Required. The location of the entry group to delete. + :type location: str + :param entry_group: Entry group ID that is deleted. 
+ :type entry_group: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "entry_group", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + entry_group: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.entry_group = entry_group + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + hook.delete_entry_group( + location=self.location, + entry_group=self.entry_group, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.info("Entry doesn't exists. skipping") + + +class CloudDataCatalogDeleteTagOperator(BaseOperator): + """ + Deletes a tag. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogDeleteTagOperator` + + :param location: Required. The location of the tag to delete. + :type location: str + :param entry_group: Entry group ID for tag that is deleted. + :type entry_group: str + :param entry: Entry ID for tag that is deleted. + :type entry: str + :param tag: Identifier for TAG that is deleted. + :type tag: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. 
+ :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "entry_group", + "entry", + "tag", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + entry_group: str, + entry: str, + tag: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.entry_group = entry_group + self.entry = entry + self.tag = tag + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + hook.delete_tag( + location=self.location, + entry_group=self.entry_group, + entry=self.entry, + tag=self.tag, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.info("Entry doesn't exists. skipping") + + +class CloudDataCatalogDeleteTagTemplateOperator(BaseOperator): + """ + Deletes a tag template and all tags using the template. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogDeleteTagTemplateOperator` + + :param location: Required. The location of the tag template to delete. + :type location: str + :param tag_template: ID for tag template that is deleted. + :type tag_template: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param force: Required. Currently, this field must always be set to ``true``. This confirms the + deletion of any possible tags using this template. 
``force = false`` will be supported in the + future. + :type force: bool + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "tag_template", + "force", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + tag_template: str, + force: bool, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.tag_template = tag_template + self.force = force + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + hook.delete_tag_template( + location=self.location, + tag_template=self.tag_template, + force=self.force, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.info("Tag Template doesn't exists. skipping") + + +class CloudDataCatalogDeleteTagTemplateFieldOperator(BaseOperator): + """ + Deletes a field in a tag template and all uses of that field. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogDeleteTagTemplateFieldOperator` + + :param location: Required. The location of the tag template to delete. + :type location: str + :param tag_template: Tag Template ID for tag template field that is deleted. + :type tag_template: str + :param field: Name of field that is deleted. + :type field: str + :param force: Required. This confirms the deletion of this field from any tags using this field. + :type force: bool + :param project_id: The ID of the Google Cloud project that owns the entry group. 
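The delete operators in this module share one pattern: a NotFound response is logged and swallowed, so reruns succeed. A sketch of the CloudDataCatalogDeleteTagTemplateOperator defined above, with placeholder ids; force=True is currently mandatory.

from airflow.providers.google.cloud.operators.datacatalog import (
    CloudDataCatalogDeleteTagTemplateOperator,
)

delete_tag_template = CloudDataCatalogDeleteTagTemplateOperator(
    task_id="delete_tag_template",
    location="us-central1",
    tag_template="example_tag_template",
    force=True,   # required today; also removes any tags created from the template
)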
+ If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "tag_template", + "field", + "force", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + tag_template: str, + field: str, + force: bool, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.tag_template = tag_template + self.field = field + self.force = force + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + hook.delete_tag_template_field( + location=self.location, + tag_template=self.tag_template, + field=self.field, + force=self.force, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.info("Tag Template field doesn't exists. skipping") + + +class CloudDataCatalogGetEntryOperator(BaseOperator): + """ + Gets an entry. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogGetEntryOperator` + + :param location: Required. The location of the entry to get. + :type location: str + :param entry_group: Required. The entry group of the entry to get. + :type entry_group: str + :param entry: The ID of the entry to get. + :type entry: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. 
+ :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "entry_group", + "entry", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + entry_group: str, + entry: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.entry_group = entry_group + self.entry = entry + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> dict: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + result = hook.get_entry( + location=self.location, + entry_group=self.entry_group, + entry=self.entry, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Entry.to_dict(result) + + +class CloudDataCatalogGetEntryGroupOperator(BaseOperator): + """ + Gets an entry group. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogGetEntryGroupOperator` + + :param location: Required. The location of the entry group to get. + :type location: str + :param entry_group: The ID of the entry group to get. + :type entry_group: str + :param read_mask: The fields to return. If not set or empty, all fields are returned. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type read_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. 
+ :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "entry_group", + "read_mask", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + entry_group: str, + read_mask: Union[Dict, FieldMask], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.entry_group = entry_group + self.read_mask = read_mask + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> dict: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + result = hook.get_entry_group( + location=self.location, + entry_group=self.entry_group, + read_mask=self.read_mask, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return EntryGroup.to_dict(result) + + +class CloudDataCatalogGetTagTemplateOperator(BaseOperator): + """ + Gets a tag template. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogGetTagTemplateOperator` + + :param location: Required. The location of the tag template to get. + :type location: str + :param tag_template: Required. The ID of the tag template to get. + :type tag_template: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. 
+ :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "tag_template", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + tag_template: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.tag_template = tag_template + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> dict: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + result = hook.get_tag_template( + location=self.location, + tag_template=self.tag_template, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return TagTemplate.to_dict(result) + + +class CloudDataCatalogListTagsOperator(BaseOperator): + """ + Lists the tags on an Entry. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogListTagsOperator` + + :param location: Required. The location of the tags to get. + :type location: str + :param entry_group: Required. The entry group of the tags to get. + :type entry_group: str + :param entry: Required. The entry of the tags to get. + :type entry: str + :param page_size: The maximum number of resources contained in the underlying API response. If page + streaming is performed per- resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number of resources in a page. + (Default: 100) + :type page_size: int + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. 
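+
+    **Example usage** (a minimal sketch; the identifiers and page size below are assumed
+    placeholder values)::
+
+        list_tags = CloudDataCatalogListTagsOperator(
+            task_id="list_tags",
+            location="us-central1",        # assumed region
+            entry_group="my_entry_group",  # hypothetical entry group ID
+            entry="my_entry",              # hypothetical entry ID
+            page_size=50,                  # override the default of 100
+        )
+
+    The returned value is a list of tag dicts (``Tag.to_dict`` applied to each item), which
+    is pushed to XCom by default.
+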
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud.
+        Defaults to 'google_cloud_default'.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        "location",
+        "entry_group",
+        "entry",
+        "page_size",
+        "project_id",
+        "retry",
+        "timeout",
+        "metadata",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        location: str,
+        entry_group: str,
+        entry: str,
+        page_size: int = 100,
+        project_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.location = location
+        self.entry_group = entry_group
+        self.entry = entry
+        self.page_size = page_size
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: dict) -> list:
+        hook = CloudDataCatalogHook(
+            gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
+        )
+        result = hook.list_tags(
+            location=self.location,
+            entry_group=self.entry_group,
+            entry=self.entry,
+            page_size=self.page_size,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return [Tag.to_dict(item) for item in result]
+
+
+class CloudDataCatalogLookupEntryOperator(BaseOperator):
+    r"""
+    Get an entry by target resource name.
+
+    This method allows clients to use the resource name from the source Google Cloud service
+    to get the Data Catalog Entry.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudDataCatalogLookupEntryOperator`
+
+    :param linked_resource: The full name of the Google Cloud resource the Data Catalog entry
+        represents. See: https://cloud.google.com/apis/design/resource\_names#full\_resource\_name. Full
+        names are case-sensitive.
+    :type linked_resource: str
+    :param sql_resource: The SQL name of the entry. SQL names are case-sensitive.
+    :type sql_resource: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be
+        retried using a default configuration.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud.
+        Defaults to 'google_cloud_default'.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        "linked_resource",
+        "sql_resource",
+        "project_id",
+        "retry",
+        "timeout",
+        "metadata",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        linked_resource: Optional[str] = None,
+        sql_resource: Optional[str] = None,
+        project_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.linked_resource = linked_resource
+        self.sql_resource = sql_resource
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: dict) -> dict:
+        hook = CloudDataCatalogHook(
+            gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
+        )
+        result = hook.lookup_entry(
+            linked_resource=self.linked_resource,
+            sql_resource=self.sql_resource,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return Entry.to_dict(result)
+
+
+class CloudDataCatalogRenameTagTemplateFieldOperator(BaseOperator):
+    """
+    Renames a field in a tag template.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudDataCatalogRenameTagTemplateFieldOperator`
+
+    :param location: Required. The location of the tag template field to rename.
+    :type location: str
+    :param tag_template: The tag template ID for field that is renamed.
+    :type tag_template: str
+    :param field: Required. The old ID of this tag template field. For example,
+        ``my_old_field``.
+    :type field: str
+    :param new_tag_template_field_id: Required. The new ID of this tag template field. For example,
+        ``my_new_field``.
+    :type new_tag_template_field_id: str
+    :param project_id: The ID of the Google Cloud project that owns the entry group.
+        If set to ``None`` or missing, the default project_id from the Google Cloud connection is used.
+    :type project_id: Optional[str]
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be
+        retried using a default configuration.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete.
Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "tag_template", + "field", + "new_tag_template_field_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + tag_template: str, + field: str, + new_tag_template_field_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.tag_template = tag_template + self.field = field + self.new_tag_template_field_id = new_tag_template_field_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.rename_tag_template_field( + location=self.location, + tag_template=self.tag_template, + field=self.field, + new_tag_template_field_id=self.new_tag_template_field_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudDataCatalogSearchCatalogOperator(BaseOperator): + r""" + Searches Data Catalog for multiple resources like entries, tags that match a query. + + This does not return the complete resource, only the resource identifier and high level fields. + Clients can subsequently call ``Get`` methods. + + Note that searches do not have full recall. There may be results that match your query but are not + returned, even in subsequent pages of results. These missing results may vary across repeated calls to + search. Do not rely on this method if you need to guarantee full recall. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogSearchCatalogOperator` + + :param scope: Required. The scope of this search request. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.Scope` + :type scope: Union[Dict, google.cloud.datacatalog_v1beta1.types.SearchCatalogRequest.Scope] + :param query: Required. The query string in search query syntax. 
The query must be non-empty. + + Query strings can be simple as "x" or more qualified as: + + - name:x + - column:x + - description:y + + Note: Query tokens need to have a minimum of 3 characters for substring matching to work + correctly. See `Data Catalog Search Syntax `__ for more information. + :type query: str + :param page_size: The maximum number of resources contained in the underlying API response. If page + streaming is performed per-resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number of resources in a page. + :type page_size: int + :param order_by: Specifies the ordering of results, currently supported case-sensitive choices are: + + - ``relevance``, only supports descending + - ``last_access_timestamp [asc|desc]``, defaults to descending if not specified + - ``last_modified_timestamp [asc|desc]``, defaults to descending if not specified + + If not specified, defaults to ``relevance`` descending. + :type order_by: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
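+
+    **Example usage** (a minimal sketch; the project ID and query below are assumed
+    placeholder values, and the ``scope`` dict assumes the ``include_project_ids`` field of
+    the ``Scope`` message)::
+
+        search_catalog = CloudDataCatalogSearchCatalogOperator(
+            task_id="search_catalog",
+            scope={"include_project_ids": ["my-project"]},  # hypothetical project
+            query="name:orders",                            # uses the documented name:x form
+            order_by="relevance",
+        )
+
+    Each result is converted with ``SearchCatalogResult.to_dict`` before being returned.
+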
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "scope", + "query", + "page_size", + "order_by", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + scope: Union[Dict, SearchCatalogRequest.Scope], + query: str, + page_size: int = 100, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.scope = scope + self.query = query + self.page_size = page_size + self.order_by = order_by + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> list: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + result = hook.search_catalog( + scope=self.scope, + query=self.query, + page_size=self.page_size, + order_by=self.order_by, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return [SearchCatalogResult.to_dict(item) for item in result] + + +class CloudDataCatalogUpdateEntryOperator(BaseOperator): + """ + Updates an existing entry. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogUpdateEntryOperator` + + :param entry: Required. The updated entry. The "name" field must be set. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.Entry` + :type entry: Union[Dict, google.cloud.datacatalog_v1beta1.types.Entry] + :param update_mask: The fields to update on the entry. If absent or empty, all modifiable fields are + updated. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param location: Required. The location of the entry to update. + :type location: str + :param entry_group: The entry group ID for the entry that is being updated. + :type entry_group: str + :param entry_id: The entry ID that is being updated. + :type entry_id: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "entry", + "update_mask", + "location", + "entry_group", + "entry_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + entry: Union[Dict, Entry], + update_mask: Union[Dict, FieldMask], + location: Optional[str] = None, + entry_group: Optional[str] = None, + entry_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.entry = entry + self.update_mask = update_mask + self.location = location + self.entry_group = entry_group + self.entry_id = entry_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.update_entry( + entry=self.entry, + update_mask=self.update_mask, + location=self.location, + entry_group=self.entry_group, + entry_id=self.entry_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudDataCatalogUpdateTagOperator(BaseOperator): + """ + Updates an existing tag. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogUpdateTagOperator` + + :param tag: Required. The updated tag. The "name" field must be set. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.Tag` + :type tag: Union[Dict, google.cloud.datacatalog_v1beta1.types.Tag] + :param update_mask: The fields to update on the Tag. If absent or empty, all modifiable fields are + updated. Currently the only modifiable field is the field ``fields``. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param location: Required. The location of the tag to rename. + :type location: str + :param entry_group: The entry group ID for the tag that is being updated. + :type entry_group: str + :param entry: The entry ID for the tag that is being updated. + :type entry: str + :param tag_id: The tag ID that is being updated. + :type tag_id: str + :param project_id: The ID of the Google Cloud project that owns the entry group. 
+ If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "tag", + "update_mask", + "location", + "entry_group", + "entry", + "tag_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + tag: Union[Dict, Tag], + update_mask: Union[Dict, FieldMask], + location: Optional[str] = None, + entry_group: Optional[str] = None, + entry: Optional[str] = None, + tag_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.tag = tag + self.update_mask = update_mask + self.location = location + self.entry_group = entry_group + self.entry = entry + self.tag_id = tag_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.update_tag( + tag=self.tag, + update_mask=self.update_mask, + location=self.location, + entry_group=self.entry_group, + entry=self.entry, + tag_id=self.tag_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudDataCatalogUpdateTagTemplateOperator(BaseOperator): + """ + Updates a tag template. + + This method cannot be used to update the fields of a template. The tag + template fields are represented as separate resources and should be updated using their own + create/update/delete methods. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogUpdateTagTemplateOperator` + + :param tag_template: Required. The template to update. The "name" field must be set. 
+ + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.TagTemplate` + :type tag_template: Union[Dict, google.cloud.datacatalog_v1beta1.types.TagTemplate] + :param update_mask: The field mask specifies the parts of the template to overwrite. + + If absent or empty, all of the allowed fields above will be updated. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param location: Required. The location of the tag template to rename. + :type location: str + :param tag_template_id: Optional. The tag template ID for the entry that is being updated. + :type tag_template_id: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
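+
+    **Example usage** (a minimal sketch; the identifiers below are placeholders, and the dict
+    forms assume the ``display_name`` field of ``TagTemplate`` and the ``paths`` field of
+    ``FieldMask``)::
+
+        update_tag_template = CloudDataCatalogUpdateTagTemplateOperator(
+            task_id="update_tag_template",
+            tag_template={"display_name": "Sales metadata"},  # only the fields being changed
+            update_mask={"paths": ["display_name"]},          # restrict the overwrite to display_name
+            location="us-central1",                           # assumed region
+            tag_template_id="my_template",                    # hypothetical tag template ID
+        )
+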
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "tag_template", + "update_mask", + "location", + "tag_template_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + tag_template: Union[Dict, TagTemplate], + update_mask: Union[Dict, FieldMask], + location: Optional[str] = None, + tag_template_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.tag_template = tag_template + self.update_mask = update_mask + self.location = location + self.tag_template_id = tag_template_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.update_tag_template( + tag_template=self.tag_template, + update_mask=self.update_mask, + location=self.location, + tag_template_id=self.tag_template_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudDataCatalogUpdateTagTemplateFieldOperator(BaseOperator): + """ + Updates a field in a tag template. This method cannot be used to update the field type. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataCatalogUpdateTagTemplateFieldOperator` + + :param tag_template_field: Required. The template to update. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.datacatalog_v1beta1.types.TagTemplateField` + :type tag_template_field: Union[Dict, google.cloud.datacatalog_v1beta1.types.TagTemplateField] + :param update_mask: The field mask specifies the parts of the template to be updated. Allowed fields: + + - ``display_name`` + - ``type.enum_type`` + + If ``update_mask`` is not set or empty, all of the allowed fields above will be updated. + + When updating an enum type, the provided values will be merged with the existing values. + Therefore, enum values can only be added, existing enum values cannot be deleted nor renamed. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param tag_template_field_name: Optional. The name of the tag template field to rename. + :type tag_template_field_name: str + :param location: Optional. The location of the tag to rename. + :type location: str + :param tag_template: Optional. The tag template ID for tag template field to rename. + :type tag_template: str + :param tag_template_field_id: Optional. The ID of tag template field to rename. + :type tag_template_field_id: str + :param project_id: The ID of the Google Cloud project that owns the entry group. + If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. + :type project_id: Optional[str] + :param retry: A retry object used to retry requests. 
If ``None`` is specified, requests will be + retried using a default configuration. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "tag_template_field", + "update_mask", + "tag_template_field_name", + "location", + "tag_template", + "tag_template_field_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + tag_template_field: Union[Dict, TagTemplateField], + update_mask: Union[Dict, FieldMask], + tag_template_field_name: Optional[str] = None, + location: Optional[str] = None, + tag_template: Optional[str] = None, + tag_template_field_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.tag_template_field_name = tag_template_field_name + self.location = location + self.tag_template = tag_template + self.tag_template_field_id = tag_template_field_id + self.project_id = project_id + self.tag_template_field = tag_template_field + self.update_mask = update_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = CloudDataCatalogHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.update_tag_template_field( + tag_template_field=self.tag_template_field, + update_mask=self.update_mask, + tag_template_field_name=self.tag_template_field_name, + location=self.location, + tag_template=self.tag_template, + tag_template_field_id=self.tag_template_field_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) diff --git a/reference/providers/google/cloud/operators/dataflow.py b/reference/providers/google/cloud/operators/dataflow.py new file mode 100644 index 0000000..4cb0118 --- /dev/null +++ b/reference/providers/google/cloud/operators/dataflow.py @@ -0,0 +1,1179 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Dataflow operators.""" +import copy +import re +import warnings +from contextlib import ExitStack +from enum import Enum +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType +from airflow.providers.google.cloud.hooks.dataflow import ( + DEFAULT_DATAFLOW_LOCATION, + DataflowHook, + process_line_and_extract_dataflow_job_id_callback, +) +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults +from airflow.version import version + + +class CheckJobRunning(Enum): + """ + Helper enum for choosing what to do if job is already running + IgnoreJob - do not check if running + FinishIfRunning - finish current dag run with no action + WaitForRun - wait for job to finish and then continue with new job + """ + + IgnoreJob = 1 + FinishIfRunning = 2 + WaitForRun = 3 + + +class DataflowConfiguration: + """Dataflow configuration that can be passed to + :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator` and + :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`. + + :param job_name: The 'jobName' to use when executing the DataFlow job + (templated). This ends up being set in the pipeline options, so any entry + with key ``'jobName'`` or ``'job_name'``in ``options`` will be overwritten. + :type job_name: str + :param append_job_name: True if unique suffix has to be appended to job name. + :type append_job_name: bool + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param location: Job location. + :type location: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param poll_sleep: The time in seconds to sleep between polling Google + Cloud Platform for the dataflow job status while the job is in the + JOB_STATE_RUNNING state. + :type poll_sleep: int + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it + instead of canceling during killing task instance. See: + https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline + :type drain_pipeline: bool + :param cancel_timeout: How long (in seconds) operator should wait for the pipeline to be + successfully cancelled when task is being killed. + :type cancel_timeout: Optional[int] + :param wait_until_finished: (Optional) + If True, wait for the end of pipeline execution before exiting. + If False, only submits job. + If None, default behavior. + + The default behavior depends on the type of pipeline: + + * for the streaming pipeline, wait for jobs to start, + * for the batch pipeline, wait for the jobs to complete. + + .. warning:: + + You cannot call ``PipelineResult.wait_until_finish`` method in your pipeline code for the operator + to work properly. i. e. you must use asynchronous execution. Otherwise, your pipeline will + always wait until finished. For more information, look at: + `Asynchronous execution + `__ + + The process of starting the Dataflow job in Airflow consists of two steps: + + * running a subprocess and reading the stderr/stderr log for the job id. + * loop waiting for the end of the job ID from the previous step. + This loop checks the status of the job. + + Step two is started just after step one has finished, so if you have wait_until_finished in your + pipeline code, step two will not start until the process stops. When this process stops, + steps two will run, but it will only execute one iteration as the job will be in a terminal state. + + If you in your pipeline do not call the wait_for_pipeline method but pass wait_until_finish=True + to the operator, the second loop will wait for the job's terminal state. + + If you in your pipeline do not call the wait_for_pipeline method, and pass wait_until_finish=False + to the operator, the second loop will check once is job not in terminal state and exit the loop. + :type wait_until_finished: Optional[bool] + :param multiple_jobs: If pipeline creates multiple jobs then monitor all jobs. Supported only by + :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator` + :type multiple_jobs: boolean + :param check_if_running: Before running job, validate that a previous run is not in process. + IgnoreJob = do not check if running. + FinishIfRunning = if job is running finish with nothing. + WaitForRun = wait until job finished and the run job. 
+ Supported only by: + :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator` + :type check_if_running: CheckJobRunning + """ + + template_fields = ["job_name", "location"] + + def __init__( + self, + *, + job_name: Optional[str] = "{{task.task_id}}", + append_job_name: bool = True, + project_id: Optional[str] = None, + location: Optional[str] = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + poll_sleep: int = 10, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + drain_pipeline: bool = False, + cancel_timeout: Optional[int] = 5 * 60, + wait_until_finished: Optional[bool] = None, + multiple_jobs: Optional[bool] = None, + check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun, + ) -> None: + self.job_name = job_name + self.append_job_name = append_job_name + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.poll_sleep = poll_sleep + self.impersonation_chain = impersonation_chain + self.drain_pipeline = drain_pipeline + self.cancel_timeout = cancel_timeout + self.wait_until_finished = wait_until_finished + self.multiple_jobs = multiple_jobs + self.check_if_running = check_if_running + + +# pylint: disable=too-many-instance-attributes +class DataflowCreateJavaJobOperator(BaseOperator): + """ + Start a Java Cloud DataFlow batch job. The parameters of the operation + will be passed to the job. + + This class is deprecated. + Please use `providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`. + + **Example**: :: + + default_args = { + 'owner': 'airflow', + 'depends_on_past': False, + 'start_date': + (2016, 8, 1), + 'email': ['alex@vanboxel.be'], + 'email_on_failure': False, + 'email_on_retry': False, + 'retries': 1, + 'retry_delay': timedelta(minutes=30), + 'dataflow_default_options': { + 'project': 'my-gcp-project', + 'zone': 'us-central1-f', + 'stagingLocation': 'gs://bucket/tmp/dataflow/staging/', + } + } + + dag = DAG('test-dag', default_args=default_args) + + task = DataFlowJavaOperator( + gcp_conn_id='gcp_default', + task_id='normalize-cal', + jar='{{var.value.gcp_dataflow_base}}pipeline-ingress-cal-normalize-1.0.jar', + options={ + 'autoscalingAlgorithm': 'BASIC', + 'maxNumWorkers': '50', + 'start': '{{ds}}', + 'partitionType': 'DAY' + + }, + dag=dag) + + .. seealso:: + For more detail on job submission have a look at the reference: + https://cloud.google.com/dataflow/pipelines/specifying-exec-params + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowCreateJavaJobOperator` + + :param jar: The reference to a self executing DataFlow jar (templated). + :type jar: str + :param job_name: The 'jobName' to use when executing the DataFlow job + (templated). This ends up being set in the pipeline options, so any entry + with key ``'jobName'`` in ``options`` will be overwritten. + :type job_name: str + :param dataflow_default_options: Map of default job options. + :type dataflow_default_options: dict + :param options: Map of job specific options.The key must be a dictionary. + The value can contain different types: + + * If the value is None, the single option - ``--key`` (without value) will be added. + * If the value is False, this option will be skipped + * If the value is True, the single option - ``--key`` (without value) will be added. + * If the value is list, the many options will be added for each key. 
+ If the value is ``['A', 'B']`` and the key is ``key`` then the ``--key=A --key-B`` options + will be left + * Other value types will be replaced with the Python textual representation. + + When defining labels (``labels`` option), you can also provide a dictionary. + :type options: dict + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param location: Job location. + :type location: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param poll_sleep: The time in seconds to sleep between polling Google + Cloud Platform for the dataflow job status while the job is in the + JOB_STATE_RUNNING state. + :type poll_sleep: int + :param job_class: The name of the dataflow job class to be executed, it + is often not the main class configured in the dataflow jar file. + :type job_class: str + + :param multiple_jobs: If pipeline creates multiple jobs then monitor all jobs + :type multiple_jobs: boolean + :param check_if_running: before running job, validate that a previous run is not in process + :type check_if_running: CheckJobRunning(IgnoreJob = do not check if running, FinishIfRunning= + if job is running finish with nothing, WaitForRun= wait until job finished and the run job) + ``jar``, ``options``, and ``job_name`` are templated so you can use variables in them. + :param cancel_timeout: How long (in seconds) operator should wait for the pipeline to be + successfully cancelled when task is being killed. + :type cancel_timeout: Optional[int] + :param wait_until_finished: (Optional) + If True, wait for the end of pipeline execution before exiting. + If False, only submits job. + If None, default behavior. + + The default behavior depends on the type of pipeline: + + * for the streaming pipeline, wait for jobs to start, + * for the batch pipeline, wait for the jobs to complete. + + .. warning:: + + You cannot call ``PipelineResult.wait_until_finish`` method in your pipeline code for the operator + to work properly. i. e. you must use asynchronous execution. Otherwise, your pipeline will + always wait until finished. For more information, look at: + `Asynchronous execution + `__ + + The process of starting the Dataflow job in Airflow consists of two steps: + + * running a subprocess and reading the stderr/stderr log for the job id. + * loop waiting for the end of the job ID from the previous step. + This loop checks the status of the job. + + Step two is started just after step one has finished, so if you have wait_until_finished in your + pipeline code, step two will not start until the process stops. When this process stops, + steps two will run, but it will only execute one iteration as the job will be in a terminal state. + + If you in your pipeline do not call the wait_for_pipeline method but pass wait_until_finish=True + to the operator, the second loop will wait for the job's terminal state. + + If you in your pipeline do not call the wait_for_pipeline method, and pass wait_until_finish=False + to the operator, the second loop will check once is job not in terminal state and exit the loop. 
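+
+    **Example** of the recommended replacement (a minimal sketch that assumes the Beam
+    provider's ``BeamRunJavaPipelineOperator`` accepts ``runner``, ``pipeline_options`` and
+    ``dataflow_config`` arguments; the bucket, jar path and region are placeholders)::
+
+        from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator
+
+        run_java_pipeline = BeamRunJavaPipelineOperator(
+            task_id="run_java_pipeline",
+            runner="DataflowRunner",
+            jar="gs://my-bucket/pipeline.jar",  # hypothetical self-executing jar
+            pipeline_options={"tempLocation": "gs://my-bucket/tmp/"},
+            dataflow_config=DataflowConfiguration(
+                job_name="{{ task.task_id }}",
+                location="us-central1",
+                wait_until_finished=True,
+            ),
+        )
+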
+ :type wait_until_finished: Optional[bool] + + Note that both + ``dataflow_default_options`` and ``options`` will be merged to specify pipeline + execution parameter, and ``dataflow_default_options`` is expected to save + high-level options, for instances, project and zone information, which + apply to all dataflow operators in the DAG. + + It's a good practice to define dataflow_* parameters in the default_args of the dag + like the project, zone and staging location. + + .. code-block:: python + + default_args = { + 'dataflow_default_options': { + 'zone': 'europe-west1-d', + 'stagingLocation': 'gs://my-staging-bucket/staging/' + } + } + + You need to pass the path to your dataflow as a file reference with the ``jar`` + parameter, the jar needs to be a self executing jar (see documentation here: + https://beam.apache.org/documentation/runners/dataflow/#self-executing-jar). + Use ``options`` to pass on options to your job. + + .. code-block:: python + + t1 = DataFlowJavaOperator( + task_id='dataflow_example', + jar='{{var.value.gcp_dataflow_base}}pipeline/build/libs/pipeline-example-1.0.jar', + options={ + 'autoscalingAlgorithm': 'BASIC', + 'maxNumWorkers': '50', + 'start': '{{ds}}', + 'partitionType': 'DAY', + 'labels': {'foo' : 'bar'} + }, + gcp_conn_id='airflow-conn-id', + dag=my-dag) + + """ + + template_fields = ["options", "jar", "job_name"] + ui_color = "#0273d4" + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + jar: str, + job_name: str = "{{task.task_id}}", + dataflow_default_options: Optional[dict] = None, + options: Optional[dict] = None, + project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + poll_sleep: int = 10, + job_class: Optional[str] = None, + check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun, + multiple_jobs: Optional[bool] = None, + cancel_timeout: Optional[int] = 10 * 60, + wait_until_finished: Optional[bool] = None, + **kwargs, + ) -> None: + # TODO: Remove one day + warnings.warn( + "The `{cls}` operator is deprecated, please use " + "`providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator` instead." 
+ "".format(cls=self.__class__.__name__), + DeprecationWarning, + stacklevel=2, + ) + super().__init__(**kwargs) + + dataflow_default_options = dataflow_default_options or {} + options = options or {} + options.setdefault("labels", {}).update( + {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")} + ) + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.jar = jar + self.multiple_jobs = multiple_jobs + self.job_name = job_name + self.dataflow_default_options = dataflow_default_options + self.options = options + self.poll_sleep = poll_sleep + self.job_class = job_class + self.check_if_running = check_if_running + self.cancel_timeout = cancel_timeout + self.wait_until_finished = wait_until_finished + self.job_id = None + self.beam_hook: Optional[BeamHook] = None + self.dataflow_hook: Optional[DataflowHook] = None + + def execute(self, context): + """Execute the Apache Beam Pipeline.""" + self.beam_hook = BeamHook(runner=BeamRunnerType.DataflowRunner) + self.dataflow_hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + poll_sleep=self.poll_sleep, + cancel_timeout=self.cancel_timeout, + wait_until_finished=self.wait_until_finished, + ) + job_name = self.dataflow_hook.build_dataflow_job_name(job_name=self.job_name) + pipeline_options = copy.deepcopy(self.dataflow_default_options) + + pipeline_options["jobName"] = self.job_name + pipeline_options["project"] = self.project_id or self.dataflow_hook.project_id + pipeline_options["region"] = self.location + pipeline_options.update(self.options) + pipeline_options.setdefault("labels", {}).update( + {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")} + ) + pipeline_options.update(self.options) + + def set_current_job_id(job_id): + self.job_id = job_id + + process_line_callback = process_line_and_extract_dataflow_job_id_callback( + on_new_job_id_callback=set_current_job_id + ) + + with ExitStack() as exit_stack: + if self.jar.lower().startswith("gs://"): + gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to) + tmp_gcs_file = exit_stack.enter_context( # pylint: disable=no-member + gcs_hook.provide_file(object_url=self.jar) + ) + self.jar = tmp_gcs_file.name + + is_running = False + if self.check_if_running != CheckJobRunning.IgnoreJob: + is_running = self.dataflow_hook.is_job_dataflow_running( # pylint: disable=no-value-for-parameter + name=self.job_name, + variables=pipeline_options, + ) + while ( + is_running + and self.check_if_running == CheckJobRunning.WaitForRun + ): + # pylint: disable=no-value-for-parameter + is_running = self.dataflow_hook.is_job_dataflow_running( + name=self.job_name, + variables=pipeline_options, + ) + if not is_running: + pipeline_options["jobName"] = job_name + self.beam_hook.start_java_pipeline( + variables=pipeline_options, + jar=self.jar, + job_class=self.job_class, + process_line_callback=process_line_callback, + ) + self.dataflow_hook.wait_for_done( # pylint: disable=no-value-for-parameter + job_name=job_name, + location=self.location, + job_id=self.job_id, + multiple_jobs=self.multiple_jobs, + ) + + return {"job_id": self.job_id} + + def on_kill(self) -> None: + self.log.info("On kill.") + if self.job_id: + self.dataflow_hook.cancel_job( + job_id=self.job_id, + project_id=self.project_id or self.dataflow_hook.project_id, + ) + + +# pylint: disable=too-many-instance-attributes +class DataflowTemplatedJobStartOperator(BaseOperator): + """ + Start a Templated Cloud DataFlow 
job. The parameters of the operation + will be passed to the job. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowTemplatedJobStartOperator` + + :param template: The reference to the DataFlow template. + :type template: str + :param job_name: The 'jobName' to use when executing the DataFlow template + (templated). + :type job_name: Optional[str] + :param options: Map of job runtime environment options. + It will update environment argument if passed. + + .. seealso:: + For more information on possible configurations, look at the API documentation + `https://cloud.google.com/dataflow/pipelines/specifying-exec-params + `__ + + :type options: dict + :param dataflow_default_options: Map of default job environment options. + :type dataflow_default_options: dict + :param parameters: Map of job specific parameters for the template. + :type parameters: dict + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param location: Job location. + :type location: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param poll_sleep: The time in seconds to sleep between polling Google + Cloud Platform for the dataflow job status while the job is in the + JOB_STATE_RUNNING state. + :type poll_sleep: int + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + :param environment: Optional, Map of job runtime environment options. + + .. seealso:: + For more information on possible configurations, look at the API documentation + `https://cloud.google.com/dataflow/pipelines/specifying-exec-params + `__ + :type environment: Optional[dict] + :param cancel_timeout: How long (in seconds) operator should wait for the pipeline to be + successfully cancelled when task is being killed. + :type cancel_timeout: Optional[int] + :param wait_until_finished: (Optional) + If True, wait for the end of pipeline execution before exiting. + If False, only submits job. + If None, default behavior. + + The default behavior depends on the type of pipeline: + + * for the streaming pipeline, wait for jobs to start, + * for the batch pipeline, wait for the jobs to complete. + + .. warning:: + + You cannot call ``PipelineResult.wait_until_finish`` method in your pipeline code for the operator + to work properly. i. e. you must use asynchronous execution. Otherwise, your pipeline will + always wait until finished. 
For more information, look at:
+ `Asynchronous execution
+ `__
+
+ The process of starting the Dataflow job in Airflow consists of two steps:
+
+ * running a subprocess and reading the stdout/stderr logs for the job id,
+ * looping until the job from the previous step reaches a terminal state.
+ This loop checks the status of the job.
+
+ Step two starts just after step one has finished, so if you have wait_until_finished in your
+ pipeline code, step two will not start until the process stops. When this process stops,
+ step two will run, but it will only execute one iteration as the job will already be in a terminal state.
+
+ If you do not call the wait_for_pipeline method in your pipeline but pass wait_until_finished=True
+ to the operator, the second loop will wait for the job's terminal state.
+
+ If you do not call the wait_for_pipeline method in your pipeline and pass wait_until_finished=False
+ to the operator, the second loop will check once whether the job is in a terminal state and exit the loop.
+ :type wait_until_finished: Optional[bool]
+
+ It's good practice to define dataflow_* parameters in the default_args of the DAG,
+ such as the project, zone and staging location.
+
+ .. seealso::
+ https://cloud.google.com/dataflow/docs/reference/rest/v1b3/LaunchTemplateParameters
+ https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
+
+ .. code-block:: python
+
+ default_args = {
+ 'dataflow_default_options': {
+ 'zone': 'europe-west1-d',
+ 'tempLocation': 'gs://my-staging-bucket/staging/',
+ }
+ }
+
+ You need to pass the path to your dataflow template as a file reference with the
+ ``template`` parameter. Use ``parameters`` to pass on parameters to your job.
+ Use ``environment`` to pass on runtime environment variables to your job.
+
+ .. code-block:: python
+
+ t1 = DataflowTemplateOperator(
+ task_id='dataflow_example',
+ template='{{var.value.gcp_dataflow_base}}',
+ parameters={
+ 'inputFile': "gs://bucket/input/my_input.txt",
+ 'outputFile': "gs://bucket/output/my_output.txt"
+ },
+ gcp_conn_id='airflow-conn-id',
+ dag=my_dag)
+
+ ``template``, ``dataflow_default_options``, ``parameters``, and ``job_name`` are
+ templated so you can use variables in them.
+
+ Note that ``dataflow_default_options`` is expected to hold high-level options
+ for project information, which apply to all dataflow operators in the DAG.
+
+ ..
seealso:: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3 + /LaunchTemplateParameters + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment + For more detail on job template execution have a look at the reference: + https://cloud.google.com/dataflow/docs/templates/executing-templates + """ + + template_fields = [ + "template", + "job_name", + "options", + "parameters", + "project_id", + "location", + "gcp_conn_id", + "impersonation_chain", + "environment", + ] + ui_color = "#0273d4" + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + template: str, + job_name: str = "{{task.task_id}}", + options: Optional[Dict[str, Any]] = None, + dataflow_default_options: Optional[Dict[str, Any]] = None, + parameters: Optional[Dict[str, str]] = None, + project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + poll_sleep: int = 10, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + environment: Optional[Dict] = None, + cancel_timeout: Optional[int] = 10 * 60, + wait_until_finished: Optional[bool] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.template = template + self.job_name = job_name + self.options = options or {} + self.dataflow_default_options = dataflow_default_options or {} + self.parameters = parameters or {} + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.poll_sleep = poll_sleep + self.job_id = None + self.hook: Optional[DataflowHook] = None + self.impersonation_chain = impersonation_chain + self.environment = environment + self.cancel_timeout = cancel_timeout + self.wait_until_finished = wait_until_finished + + def execute(self, context) -> dict: + self.hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + cancel_timeout=self.cancel_timeout, + wait_until_finished=self.wait_until_finished, + ) + + def set_current_job_id(job_id): + self.job_id = job_id + + options = self.dataflow_default_options + options.update(self.options) + job = self.hook.start_template_dataflow( + job_name=self.job_name, + variables=options, + parameters=self.parameters, + dataflow_template=self.template, + on_new_job_id_callback=set_current_job_id, + project_id=self.project_id, + location=self.location, + environment=self.environment, + ) + + return job + + def on_kill(self) -> None: + self.log.info("On kill.") + if self.job_id: + self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id) + + +class DataflowStartFlexTemplateOperator(BaseOperator): + """ + Starts flex templates with the Dataflow pipeline. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowStartFlexTemplateOperator` + + :param body: The request body. See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body + :param location: The location of the Dataflow job (for example europe-west1) + :type location: str + :param project_id: The ID of the GCP project that owns the job. + If set to ``None`` or missing, the default project_id from the GCP connection is used. + :type project_id: Optional[str] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud + Platform. 
+ :type gcp_conn_id: str
+ :param delegate_to: The account to impersonate, if any.
+ For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :type delegate_to: str
+ :param drain_pipeline: Optional, set to True if you want to stop a streaming job by draining it
+ instead of canceling it when the task instance is killed. See:
+ https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+ :type drain_pipeline: bool
+ :param cancel_timeout: How long (in seconds) the operator should wait for the pipeline to be
+ successfully cancelled when the task is being killed.
+ :type cancel_timeout: Optional[int]
+ :param wait_until_finished: (Optional)
+ If True, wait for the end of pipeline execution before exiting.
+ If False, only submit the job.
+ If None, default behavior.
+
+ The default behavior depends on the type of pipeline:
+
+ * for a streaming pipeline, wait for the jobs to start,
+ * for a batch pipeline, wait for the jobs to complete.
+
+ .. warning::
+
+ You cannot call the ``PipelineResult.wait_until_finish`` method in your pipeline code for the operator
+ to work properly, i.e. you must use asynchronous execution. Otherwise, your pipeline will
+ always wait until finished. For more information, look at:
+ `Asynchronous execution
+ `__
+
+ The process of starting the Dataflow job in Airflow consists of two steps:
+
+ * running a subprocess and reading the stdout/stderr logs for the job id,
+ * looping until the job from the previous step reaches a terminal state.
+ This loop checks the status of the job.
+
+ Step two starts just after step one has finished, so if you have wait_until_finished in your
+ pipeline code, step two will not start until the process stops. When this process stops,
+ step two will run, but it will only execute one iteration as the job will already be in a terminal state.
+
+ If you do not call the wait_for_pipeline method in your pipeline but pass wait_until_finished=True
+ to the operator, the second loop will wait for the job's terminal state.
+
+ If you do not call the wait_for_pipeline method in your pipeline and pass wait_until_finished=False
+ to the operator, the second loop will check once whether the job is in a terminal state and exit the loop.
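+
+ An illustrative invocation (bucket paths, job and parameter names are placeholders; the exact
+ ``body`` schema is described in the ``flexTemplates.launch`` reference linked above):
+
+ .. code-block:: python
+
+ start_flex_template = DataflowStartFlexTemplateOperator(
+ task_id='start_flex_template',
+ location='europe-west1',
+ body={
+ 'launchParameter': {
+ # Placeholder GCS paths; point these at your own template spec and data.
+ 'jobName': 'example-flex-job',
+ 'containerSpecGcsPath': 'gs://example-bucket/templates/spec.json',
+ 'parameters': {'inputFile': 'gs://example-bucket/input.csv'},
+ }
+ },
+ wait_until_finished=False)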
+ :type wait_until_finished: Optional[bool] + """ + + template_fields = ["body", "location", "project_id", "gcp_conn_id"] + + @apply_defaults + def __init__( + self, + body: Dict, + location: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + drain_pipeline: bool = False, + cancel_timeout: Optional[int] = 10 * 60, + wait_until_finished: Optional[bool] = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.body = body + self.location = location + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.drain_pipeline = drain_pipeline + self.cancel_timeout = cancel_timeout + self.wait_until_finished = wait_until_finished + self.job_id = None + self.hook: Optional[DataflowHook] = None + + def execute(self, context): + self.hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + drain_pipeline=self.drain_pipeline, + cancel_timeout=self.cancel_timeout, + wait_until_finished=self.wait_until_finished, + ) + + def set_current_job_id(job_id): + self.job_id = job_id + + job = self.hook.start_flex_template( + body=self.body, + location=self.location, + project_id=self.project_id, + on_new_job_id_callback=set_current_job_id, + ) + + return job + + def on_kill(self) -> None: + self.log.info("On kill.") + if self.job_id: + self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id) + + +class DataflowStartSqlJobOperator(BaseOperator): + """ + Starts Dataflow SQL query. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowStartSqlJobOperator` + + .. warning:: + This operator requires ``gcloud`` command (Google Cloud SDK) must be installed on the Airflow worker + `__ + + :param job_name: The unique name to assign to the Cloud Dataflow job. + :type job_name: str + :param query: The SQL query to execute. + :type query: str + :param options: Job parameters to be executed. It can be a dictionary with the following keys. + + For more information, look at: + `https://cloud.google.com/sdk/gcloud/reference/beta/dataflow/sql/query + `__ + command reference + + :type options: dict + :param location: The location of the Dataflow job (for example europe-west1) + :type location: str + :param project_id: The ID of the GCP project that owns the job. + If set to ``None`` or missing, the default project_id from the GCP connection is used. + :type project_id: Optional[str] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud + Platform. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it + instead of canceling during killing task instance. 
See: + https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline + :type drain_pipeline: bool + """ + + template_fields = [ + "job_name", + "query", + "options", + "location", + "project_id", + "gcp_conn_id", + ] + + @apply_defaults + def __init__( + self, + job_name: str, + query: str, + options: Dict[str, Any], + location: str = DEFAULT_DATAFLOW_LOCATION, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + drain_pipeline: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.job_name = job_name + self.query = query + self.options = options + self.location = location + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.drain_pipeline = drain_pipeline + self.job_id = None + self.hook: Optional[DataflowHook] = None + + def execute(self, context): + self.hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + drain_pipeline=self.drain_pipeline, + ) + + def set_current_job_id(job_id): + self.job_id = job_id + + job = self.hook.start_sql_job( + job_name=self.job_name, + query=self.query, + options=self.options, + location=self.location, + project_id=self.project_id, + on_new_job_id_callback=set_current_job_id, + ) + + return job + + def on_kill(self) -> None: + self.log.info("On kill.") + if self.job_id: + self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id) + + +# pylint: disable=too-many-instance-attributes +class DataflowCreatePythonJobOperator(BaseOperator): + """ + Launching Cloud Dataflow jobs written in python. Note that both + dataflow_default_options and options will be merged to specify pipeline + execution parameter, and dataflow_default_options is expected to save + high-level options, for instances, project and zone information, which + apply to all dataflow operators in the DAG. + + This class is deprecated. + Please use `providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`. + + .. seealso:: + For more detail on job submission have a look at the reference: + https://cloud.google.com/dataflow/pipelines/specifying-exec-params + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowCreatePythonJobOperator` + + :param py_file: Reference to the python dataflow pipeline file.py, e.g., + /some/local/file/path/to/your/python/pipeline/file. (templated) + :type py_file: str + :param job_name: The 'job_name' to use when executing the DataFlow job + (templated). This ends up being set in the pipeline options, so any entry + with key ``'jobName'`` or ``'job_name'`` in ``options`` will be overwritten. + :type job_name: str + :param py_options: Additional python options, e.g., ["-m", "-v"]. + :type py_options: list[str] + :param dataflow_default_options: Map of default job options. + :type dataflow_default_options: dict + :param options: Map of job specific options.The key must be a dictionary. + The value can contain different types: + + * If the value is None, the single option - ``--key`` (without value) will be added. + * If the value is False, this option will be skipped + * If the value is True, the single option - ``--key`` (without value) will be added. + * If the value is list, the many options will be added for each key. 
+ If the value is ``['A', 'B']`` and the key is ``key`` then the ``--key=A --key-B`` options + will be left + * Other value types will be replaced with the Python textual representation. + + When defining labels (``labels`` option), you can also provide a dictionary. + :type options: dict + :param py_interpreter: Python version of the beam pipeline. + If None, this defaults to the python3. + To track python versions supported by beam and related + issues check: https://issues.apache.org/jira/browse/BEAM-1251 + :type py_interpreter: str + :param py_requirements: Additional python package(s) to install. + If a value is passed to this parameter, a new virtual environment has been created with + additional packages installed. + + You could also install the apache_beam package if it is not installed on your system or you want + to use a different version. + :type py_requirements: List[str] + :param py_system_site_packages: Whether to include system_site_packages in your virtualenv. + See virtualenv documentation for more information. + + This option is only relevant if the ``py_requirements`` parameter is not None. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param location: Job location. + :type location: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param poll_sleep: The time in seconds to sleep between polling Google + Cloud Platform for the dataflow job status while the job is in the + JOB_STATE_RUNNING state. + :type poll_sleep: int + :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it + instead of canceling during killing task instance. See: + https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline + :type drain_pipeline: bool + :param cancel_timeout: How long (in seconds) operator should wait for the pipeline to be + successfully cancelled when task is being killed. + :type cancel_timeout: Optional[int] + :param wait_until_finished: (Optional) + If True, wait for the end of pipeline execution before exiting. + If False, only submits job. + If None, default behavior. + + The default behavior depends on the type of pipeline: + + * for the streaming pipeline, wait for jobs to start, + * for the batch pipeline, wait for the jobs to complete. + + .. warning:: + + You cannot call ``PipelineResult.wait_until_finish`` method in your pipeline code for the operator + to work properly. i. e. you must use asynchronous execution. Otherwise, your pipeline will + always wait until finished. For more information, look at: + `Asynchronous execution + `__ + + The process of starting the Dataflow job in Airflow consists of two steps: + + * running a subprocess and reading the stderr/stderr log for the job id. + * loop waiting for the end of the job ID from the previous step. + This loop checks the status of the job. + + Step two is started just after step one has finished, so if you have wait_until_finished in your + pipeline code, step two will not start until the process stops. When this process stops, + steps two will run, but it will only execute one iteration as the job will be in a terminal state. 
+ + If you in your pipeline do not call the wait_for_pipeline method but pass wait_until_finish=True + to the operator, the second loop will wait for the job's terminal state. + + If you in your pipeline do not call the wait_for_pipeline method, and pass wait_until_finish=False + to the operator, the second loop will check once is job not in terminal state and exit the loop. + :type wait_until_finished: Optional[bool] + """ + + template_fields = ["options", "dataflow_default_options", "job_name", "py_file"] + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + py_file: str, + job_name: str = "{{task.task_id}}", + dataflow_default_options: Optional[dict] = None, + options: Optional[dict] = None, + py_interpreter: str = "python3", + py_options: Optional[List[str]] = None, + py_requirements: Optional[List[str]] = None, + py_system_site_packages: bool = False, + project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + poll_sleep: int = 10, + drain_pipeline: bool = False, + cancel_timeout: Optional[int] = 10 * 60, + wait_until_finished: Optional[bool] = None, + **kwargs, + ) -> None: + # TODO: Remove one day + warnings.warn( + "The `{cls}` operator is deprecated, please use " + "`providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator` instead." + "".format(cls=self.__class__.__name__), + DeprecationWarning, + stacklevel=2, + ) + super().__init__(**kwargs) + + self.py_file = py_file + self.job_name = job_name + self.py_options = py_options or [] + self.dataflow_default_options = dataflow_default_options or {} + self.options = options or {} + self.options.setdefault("labels", {}).update( + {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")} + ) + self.py_interpreter = py_interpreter + self.py_requirements = py_requirements + self.py_system_site_packages = py_system_site_packages + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.poll_sleep = poll_sleep + self.drain_pipeline = drain_pipeline + self.cancel_timeout = cancel_timeout + self.wait_until_finished = wait_until_finished + self.job_id = None + self.beam_hook: Optional[BeamHook] = None + self.dataflow_hook: Optional[DataflowHook] = None + + def execute(self, context): + """Execute the python dataflow job.""" + self.beam_hook = BeamHook(runner=BeamRunnerType.DataflowRunner) + self.dataflow_hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + poll_sleep=self.poll_sleep, + impersonation_chain=None, + drain_pipeline=self.drain_pipeline, + cancel_timeout=self.cancel_timeout, + wait_until_finished=self.wait_until_finished, + ) + + job_name = self.dataflow_hook.build_dataflow_job_name(job_name=self.job_name) + pipeline_options = self.dataflow_default_options.copy() + pipeline_options["job_name"] = job_name + pipeline_options["project"] = self.project_id or self.dataflow_hook.project_id + pipeline_options["region"] = self.location + pipeline_options.update(self.options) + + # Convert argument names from lowerCamelCase to snake case. 
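+ # For example (illustrative values), {"maxNumWorkers": 5, "tempLocation": "gs://bucket/tmp"}
+ # becomes {"max_num_workers": 5, "temp_location": "gs://bucket/tmp"}, matching the snake_case
+ # option names expected on the Beam Python side.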
+ camel_to_snake = lambda name: re.sub( + r"[A-Z]", lambda x: "_" + x.group(0).lower(), name + ) + formatted_pipeline_options = { + camel_to_snake(key): pipeline_options[key] for key in pipeline_options + } + + def set_current_job_id(job_id): + self.job_id = job_id + + process_line_callback = process_line_and_extract_dataflow_job_id_callback( + on_new_job_id_callback=set_current_job_id + ) + + with ExitStack() as exit_stack: + if self.py_file.lower().startswith("gs://"): + gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to) + tmp_gcs_file = exit_stack.enter_context( # pylint: disable=no-member + gcs_hook.provide_file(object_url=self.py_file) + ) + self.py_file = tmp_gcs_file.name + + self.beam_hook.start_python_pipeline( + variables=formatted_pipeline_options, + py_file=self.py_file, + py_options=self.py_options, + py_interpreter=self.py_interpreter, + py_requirements=self.py_requirements, + py_system_site_packages=self.py_system_site_packages, + process_line_callback=process_line_callback, + ) + + self.dataflow_hook.wait_for_done( # pylint: disable=no-value-for-parameter + job_name=job_name, + location=self.location, + job_id=self.job_id, + multiple_jobs=False, + ) + + return {"job_id": self.job_id} + + def on_kill(self) -> None: + self.log.info("On kill.") + if self.job_id: + self.dataflow_hook.cancel_job( + job_id=self.job_id, + project_id=self.project_id or self.dataflow_hook.project_id, + ) diff --git a/reference/providers/google/cloud/operators/datafusion.py b/reference/providers/google/cloud/operators/datafusion.py new file mode 100644 index 0000000..336e063 --- /dev/null +++ b/reference/providers/google/cloud/operators/datafusion.py @@ -0,0 +1,957 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains Google DataFusion operators.""" +from time import sleep +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook +from airflow.utils.decorators import apply_defaults +from google.api_core.retry import exponential_sleep_generator +from googleapiclient.errors import HttpError + + +class CloudDataFusionRestartInstanceOperator(BaseOperator): + """ + Restart a single Data Fusion instance. + At the end of an operation instance is fully restarted. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataFusionRestartInstanceOperator` + + :param instance_name: The name of the instance to restart. + :type instance_name: str + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param project_id: The ID of the Google Cloud project that the instance belongs to. 
+ :type project_id: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "instance_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + instance_name: str, + location: str, + project_id: Optional[str] = None, + api_version: str = "v1beta1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.instance_name = instance_name + self.location = location + self.project_id = project_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = DataFusionHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Restarting Data Fusion instance: %s", self.instance_name) + operation = hook.restart_instance( + instance_name=self.instance_name, + location=self.location, + project_id=self.project_id, + ) + hook.wait_for_operation(operation) + self.log.info("Instance %s restarted successfully", self.instance_name) + + +class CloudDataFusionDeleteInstanceOperator(BaseOperator): + """ + Deletes a single Date Fusion instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataFusionDeleteInstanceOperator` + + :param instance_name: The name of the instance to restart. + :type instance_name: str + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param project_id: The ID of the Google Cloud project that the instance belongs to. + :type project_id: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. 
+ :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "instance_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + instance_name: str, + location: str, + project_id: Optional[str] = None, + api_version: str = "v1beta1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.instance_name = instance_name + self.location = location + self.project_id = project_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = DataFusionHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Deleting Data Fusion instance: %s", self.instance_name) + operation = hook.delete_instance( + instance_name=self.instance_name, + location=self.location, + project_id=self.project_id, + ) + hook.wait_for_operation(operation) + self.log.info("Instance %s deleted successfully", self.instance_name) + + +class CloudDataFusionCreateInstanceOperator(BaseOperator): + """ + Creates a new Data Fusion instance in the specified project and location. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataFusionCreateInstanceOperator` + + :param instance_name: The name of the instance to create. + :type instance_name: str + :param instance: An instance of Instance. + https://cloud.google.com/data-fusion/docs/reference/rest/v1beta1/projects.locations.instances#Instance + :type instance: Dict[str, Any] + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param project_id: The ID of the Google Cloud project that the instance belongs to. + :type project_id: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "instance_name", + "instance", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + instance_name: str, + instance: Dict[str, Any], + location: str, + project_id: Optional[str] = None, + api_version: str = "v1beta1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.instance_name = instance_name + self.instance = instance + self.location = location + self.project_id = project_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> dict: + hook = DataFusionHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Creating Data Fusion instance: %s", self.instance_name) + try: + operation = hook.create_instance( + instance_name=self.instance_name, + instance=self.instance, + location=self.location, + project_id=self.project_id, + ) + instance = hook.wait_for_operation(operation) + self.log.info("Instance %s created successfully", self.instance_name) + except HttpError as err: + if err.resp.status not in (409, "409"): + raise + self.log.info("Instance %s already exists", self.instance_name) + instance = hook.get_instance( + instance_name=self.instance_name, + location=self.location, + project_id=self.project_id, + ) + # Wait for instance to be ready + for time_to_wait in exponential_sleep_generator(initial=10, maximum=120): + if instance["state"] != "CREATING": + break + sleep(time_to_wait) + instance = hook.get_instance( + instance_name=self.instance_name, + location=self.location, + project_id=self.project_id, + ) + return instance + + +class CloudDataFusionUpdateInstanceOperator(BaseOperator): + """ + Updates a single Data Fusion instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataFusionUpdateInstanceOperator` + + :param instance_name: The name of the instance to create. + :type instance_name: str + :param instance: An instance of Instance. + https://cloud.google.com/data-fusion/docs/reference/rest/v1beta1/projects.locations.instances#Instance + :type instance: Dict[str, Any] + :param update_mask: Field mask is used to specify the fields that the update will overwrite + in an instance resource. The fields specified in the updateMask are relative to the resource, + not the full request. A field will be overwritten if it is in the mask. If the user does not + provide a mask, all the supported fields (labels and options currently) will be overwritten. + A comma-separated list of fully qualified names of fields. Example: "user.displayName,photo". + https://developers.google.com/protocol-buffers/docs/reference/google.protobuf?_ga=2.205612571.-968688242.1573564810#google.protobuf.FieldMask + :type update_mask: str + :param location: The Cloud Data Fusion location in which to handle the request. 
+ :type location: str + :param project_id: The ID of the Google Cloud project that the instance belongs to. + :type project_id: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "instance_name", + "instance", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + instance_name: str, + instance: Dict[str, Any], + update_mask: str, + location: str, + project_id: Optional[str] = None, + api_version: str = "v1beta1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.update_mask = update_mask + self.instance_name = instance_name + self.instance = instance + self.location = location + self.project_id = project_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = DataFusionHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Updating Data Fusion instance: %s", self.instance_name) + operation = hook.patch_instance( + instance_name=self.instance_name, + instance=self.instance, + update_mask=self.update_mask, + location=self.location, + project_id=self.project_id, + ) + hook.wait_for_operation(operation) + self.log.info("Instance %s updated successfully", self.instance_name) + + +class CloudDataFusionGetInstanceOperator(BaseOperator): + """ + Gets details of a single Data Fusion instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataFusionGetInstanceOperator` + + :param instance_name: The name of the instance. + :type instance_name: str + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param project_id: The ID of the Google Cloud project that the instance belongs to. + :type project_id: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. 
For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "instance_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + instance_name: str, + location: str, + project_id: Optional[str] = None, + api_version: str = "v1beta1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.instance_name = instance_name + self.location = location + self.project_id = project_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> dict: + hook = DataFusionHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Retrieving Data Fusion instance: %s", self.instance_name) + instance = hook.get_instance( + instance_name=self.instance_name, + location=self.location, + project_id=self.project_id, + ) + return instance + + +class CloudDataFusionCreatePipelineOperator(BaseOperator): + """ + Creates a Cloud Data Fusion pipeline. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataFusionCreatePipelineOperator` + + :param pipeline_name: Your pipeline name. + :type pipeline_name: str + :param pipeline: The pipeline definition. For more information check: + https://docs.cdap.io/cdap/current/en/developer-manual/pipelines/developing-pipelines.html#pipeline-configuration-file-format + :type pipeline: Dict[str, Any] + :param instance_name: The name of the instance. + :type instance_name: str + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param namespace: If your pipeline belongs to a Basic edition instance, the namespace ID + is always default. If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. + :type namespace: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "instance_name", + "pipeline_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + pipeline_name: str, + pipeline: Dict[str, Any], + instance_name: str, + location: str, + namespace: str = "default", + project_id: Optional[str] = None, + api_version: str = "v1beta1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.pipeline_name = pipeline_name + self.pipeline = pipeline + self.namespace = namespace + self.instance_name = instance_name + self.location = location + self.project_id = project_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = DataFusionHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Creating Data Fusion pipeline: %s", self.pipeline_name) + instance = hook.get_instance( + instance_name=self.instance_name, + location=self.location, + project_id=self.project_id, + ) + api_url = instance["apiEndpoint"] + hook.create_pipeline( + pipeline_name=self.pipeline_name, + pipeline=self.pipeline, + instance_url=api_url, + namespace=self.namespace, + ) + self.log.info("Pipeline created") + + +class CloudDataFusionDeletePipelineOperator(BaseOperator): + """ + Deletes a Cloud Data Fusion pipeline. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataFusionDeletePipelineOperator` + + :param pipeline_name: Your pipeline name. + :type pipeline_name: str + :param version_id: Version of pipeline to delete + :type version_id: Optional[str] + :param instance_name: The name of the instance. + :type instance_name: str + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param namespace: If your pipeline belongs to a Basic edition instance, the namespace ID + is always default. If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. + :type namespace: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "instance_name", + "version_id", + "pipeline_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + pipeline_name: str, + instance_name: str, + location: str, + version_id: Optional[str] = None, + namespace: str = "default", + project_id: Optional[str] = None, + api_version: str = "v1beta1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.pipeline_name = pipeline_name + self.version_id = version_id + self.namespace = namespace + self.instance_name = instance_name + self.location = location + self.project_id = project_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = DataFusionHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Deleting Data Fusion pipeline: %s", self.pipeline_name) + instance = hook.get_instance( + instance_name=self.instance_name, + location=self.location, + project_id=self.project_id, + ) + api_url = instance["apiEndpoint"] + hook.delete_pipeline( + pipeline_name=self.pipeline_name, + version_id=self.version_id, + instance_url=api_url, + namespace=self.namespace, + ) + self.log.info("Pipeline deleted") + + +class CloudDataFusionListPipelinesOperator(BaseOperator): + """ + Lists Cloud Data Fusion pipelines. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataFusionListPipelinesOperator` + + + :param instance_name: The name of the instance. + :type instance_name: str + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param artifact_version: Artifact version to filter instances + :type artifact_version: Optional[str] + :param artifact_name: Artifact name to filter instances + :type artifact_name: Optional[str] + :param namespace: If your pipeline belongs to a Basic edition instance, the namespace ID + is always default. If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. + :type namespace: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "instance_name", + "artifact_name", + "artifact_version", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + instance_name: str, + location: str, + artifact_name: Optional[str] = None, + artifact_version: Optional[str] = None, + namespace: str = "default", + project_id: Optional[str] = None, + api_version: str = "v1beta1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.artifact_version = artifact_version + self.artifact_name = artifact_name + self.namespace = namespace + self.instance_name = instance_name + self.location = location + self.project_id = project_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> dict: + hook = DataFusionHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Listing Data Fusion pipelines") + instance = hook.get_instance( + instance_name=self.instance_name, + location=self.location, + project_id=self.project_id, + ) + api_url = instance["apiEndpoint"] + pipelines = hook.list_pipelines( + instance_url=api_url, + namespace=self.namespace, + artifact_version=self.artifact_version, + artifact_name=self.artifact_name, + ) + self.log.info("%s", pipelines) + return pipelines + + +class CloudDataFusionStartPipelineOperator(BaseOperator): + """ + Starts a Cloud Data Fusion pipeline. Works for both batch and stream pipelines. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataFusionStartPipelineOperator` + + :param pipeline_name: Your pipeline name. + :type pipeline_name: str + :param instance_name: The name of the instance. + :type instance_name: str + :param success_states: If provided the operator will wait for pipeline to be in one of + the provided states. + :type success_states: List[str] + :param pipeline_timeout: How long (in seconds) operator should wait for the pipeline to be in one of + ``success_states``. Works only if ``success_states`` are provided. + :type pipeline_timeout: int + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param runtime_args: Optional runtime args to be passed to the pipeline + :type runtime_args: dict + :param namespace: If your pipeline belongs to a Basic edition instance, the namespace ID + is always default. If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. + :type namespace: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. 
+ :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "instance_name", + "pipeline_name", + "runtime_args", + "impersonation_chain", + ) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + pipeline_name: str, + instance_name: str, + location: str, + runtime_args: Optional[Dict[str, Any]] = None, + success_states: Optional[List[str]] = None, + namespace: str = "default", + pipeline_timeout: int = 10 * 60, + project_id: Optional[str] = None, + api_version: str = "v1beta1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.pipeline_name = pipeline_name + self.success_states = success_states + self.runtime_args = runtime_args + self.pipeline_timeout = pipeline_timeout + self.namespace = namespace + self.instance_name = instance_name + self.location = location + self.project_id = project_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = DataFusionHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Starting Data Fusion pipeline: %s", self.pipeline_name) + instance = hook.get_instance( + instance_name=self.instance_name, + location=self.location, + project_id=self.project_id, + ) + api_url = instance["apiEndpoint"] + pipeline_id = hook.start_pipeline( + pipeline_name=self.pipeline_name, + instance_url=api_url, + namespace=self.namespace, + runtime_args=self.runtime_args, + ) + + self.log.info("Pipeline started") + if self.success_states: + hook.wait_for_pipeline_state( + success_states=self.success_states, + pipeline_id=pipeline_id, + pipeline_name=self.pipeline_name, + namespace=self.namespace, + instance_url=api_url, + timeout=self.pipeline_timeout, + ) + + +class CloudDataFusionStopPipelineOperator(BaseOperator): + """ + Stops a Cloud Data Fusion pipeline. Works for both batch and stream pipelines. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataFusionStopPipelineOperator` + + :param pipeline_name: Your pipeline name. + :type pipeline_name: str + :param instance_name: The name of the instance. + :type instance_name: str + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param namespace: If your pipeline belongs to a Basic edition instance, the namespace ID + is always default. If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. 
+    :type namespace: str
+    :param api_version: The version of the api that will be requested, for example 'v3'.
+    :type api_version: str
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        "instance_name",
+        "pipeline_name",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        pipeline_name: str,
+        instance_name: str,
+        location: str,
+        namespace: str = "default",
+        project_id: Optional[str] = None,
+        api_version: str = "v1beta1",
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.pipeline_name = pipeline_name
+        self.namespace = namespace
+        self.instance_name = instance_name
+        self.location = location
+        self.project_id = project_id
+        self.api_version = api_version
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: dict) -> None:
+        hook = DataFusionHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            api_version=self.api_version,
+            impersonation_chain=self.impersonation_chain,
+        )
+        self.log.info("Stopping Data Fusion pipeline: %s", self.pipeline_name)
+        instance = hook.get_instance(
+            instance_name=self.instance_name,
+            location=self.location,
+            project_id=self.project_id,
+        )
+        api_url = instance["apiEndpoint"]
+        hook.stop_pipeline(
+            pipeline_name=self.pipeline_name,
+            instance_url=api_url,
+            namespace=self.namespace,
+        )
+        self.log.info("Pipeline stopped")
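For orientation, here is a minimal DAG sketch showing how the Data Fusion operators above might be wired together. It assumes the same classes are available from the upstream apache-airflow-providers-google package; the DAG id, pipeline, instance, and location values are placeholders, and the success_states/timeout settings are illustrative only.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.datafusion import (
    CloudDataFusionStartPipelineOperator,
    CloudDataFusionStopPipelineOperator,
)

with DAG(
    dag_id="datafusion_pipeline_example",  # placeholder DAG id
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    start_pipeline = CloudDataFusionStartPipelineOperator(
        task_id="start_pipeline",
        pipeline_name="my_pipeline",       # placeholder pipeline name
        instance_name="my-instance",       # placeholder instance name
        location="us-central1",
        success_states=["COMPLETED"],      # illustrative: block until the run completes
        pipeline_timeout=15 * 60,
    )

    stop_pipeline = CloudDataFusionStopPipelineOperator(
        task_id="stop_pipeline",
        pipeline_name="my_pipeline",
        instance_name="my-instance",
        location="us-central1",
    )

    # Ordering shown only to illustrate chaining; a real DAG would stop a
    # stream pipeline on some other trigger, not immediately after starting it.
    start_pipeline >> stop_pipeline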
diff --git a/reference/providers/google/cloud/operators/dataprep.py b/reference/providers/google/cloud/operators/dataprep.py
new file mode 100644
index 0000000..d9be0b6
--- /dev/null
+++ b/reference/providers/google/cloud/operators/dataprep.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Google Dataprep operator."""
+
+from airflow.models import BaseOperator
+from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook
+from airflow.utils.decorators import apply_defaults
+
+
+class DataprepGetJobsForJobGroupOperator(BaseOperator):
+    """
+    Get information about the batch jobs within a Cloud Dataprep job.
+    API documentation https://clouddataprep.com/documentation/api#section/Overview
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DataprepGetJobsForJobGroupOperator`
+
+    :param job_id: The ID of the job that will be requested
+    :type job_id: int
+    """
+
+    template_fields = ("job_id",)
+
+    @apply_defaults
+    def __init__(
+        self, *, dataprep_conn_id: str = "dataprep_default", job_id: int, **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.dataprep_conn_id = dataprep_conn_id
+        self.job_id = job_id
+
+    def execute(self, context: dict) -> dict:
+        self.log.info("Fetching data for job with id: %d ...", self.job_id)
+        hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
+        response = hook.get_jobs_for_job_group(job_id=self.job_id)
+        return response
+
+
+class DataprepGetJobGroupOperator(BaseOperator):
+    """
+    Get the specified job group.
+    A job group is a job that is executed from a specific node in a flow.
+    API documentation https://clouddataprep.com/documentation/api#section/Overview
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DataprepGetJobGroupOperator`
+
+    :param job_group_id: The ID of the job group that will be requested
+    :type job_group_id: int
+    :param embed: Comma-separated list of objects to pull in as part of the response
+    :type embed: string
+    :param include_deleted: if set to "true", will include deleted objects
+    :type include_deleted: bool
+    """
+
+    template_fields = ("job_group_id", "embed")
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        dataprep_conn_id: str = "dataprep_default",
+        job_group_id: int,
+        embed: str,
+        include_deleted: bool,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.dataprep_conn_id: str = dataprep_conn_id
+        self.job_group_id = job_group_id
+        self.embed = embed
+        self.include_deleted = include_deleted
+
+    def execute(self, context: dict) -> dict:
+        self.log.info("Fetching data for job with id: %d ...", self.job_group_id)
+        hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
+        response = hook.get_job_group(
+            job_group_id=self.job_group_id,
+            embed=self.embed,
+            include_deleted=self.include_deleted,
+        )
+        return response
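A similarly hedged sketch of how the two Dataprep "get" operators above might appear in a DAG, assuming a dataprep_default connection holding a valid Dataprep API token; the DAG id and the job/job-group ids are placeholders.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataprep import (
    DataprepGetJobGroupOperator,
    DataprepGetJobsForJobGroupOperator,
)

with DAG(
    dag_id="dataprep_job_group_example",  # placeholder DAG id
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    get_jobs = DataprepGetJobsForJobGroupOperator(
        task_id="get_jobs_for_job_group",
        job_id=1234,                      # placeholder job group id
    )

    get_job_group = DataprepGetJobGroupOperator(
        task_id="get_job_group",
        job_group_id=1234,                # placeholder job group id
        embed="",
        include_deleted=False,
    )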
+
+
+class DataprepRunJobGroupOperator(BaseOperator):
+    """
+    Create a ``jobGroup``, which launches the specified job as the authenticated user.
+    This performs the same action as clicking on the Run Job button in the application.
+    To get recipe_id please follow the Dataprep API documentation
+    https://clouddataprep.com/documentation/api#operation/runJobGroup
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DataprepRunJobGroupOperator`
+
+    :param dataprep_conn_id: The Dataprep connection ID
+    :type dataprep_conn_id: str
+    :param body_request: Passed as the body to GoogleDataprepHook's ``run_job_group``,
+        where it identifies the recipe to run
+    :type body_request: dict
+    """
+
+    template_fields = ("body_request",)
+
+    def __init__(
+        self,
+        *,
+        dataprep_conn_id: str = "dataprep_default",
+        body_request: dict,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.body_request = body_request
+        self.dataprep_conn_id = dataprep_conn_id
+
+    def execute(self, context: dict) -> dict:
+        self.log.info("Creating a job...")
+        hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
+        response = hook.run_job_group(body_request=self.body_request)
+        return response
diff --git a/reference/providers/google/cloud/operators/dataproc.py b/reference/providers/google/cloud/operators/dataproc.py
new file mode 100644
index 0000000..eb4f6c2
--- /dev/null
+++ b/reference/providers/google/cloud/operators/dataproc.py
@@ -0,0 +1,2090 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""This module contains Google Dataproc operators."""
+
+import inspect
+import ntpath
+import os
+import re
+import time
+import uuid
+import warnings
+from datetime import datetime, timedelta
+from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.google.cloud.hooks.dataproc import (
+    DataprocHook,
+    DataProcJobBuilder,
+)
+from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.utils import timezone
+from airflow.utils.decorators import apply_defaults
+from google.api_core.exceptions import AlreadyExists, NotFound
+from google.api_core.retry import Retry, exponential_sleep_generator
+from google.cloud.dataproc_v1beta2 import Cluster  # pylint: disable=no-name-in-module
+from google.protobuf.duration_pb2 import Duration
+from google.protobuf.field_mask_pb2 import FieldMask
+
+
+# pylint: disable=too-many-instance-attributes
+class ClusterGenerator:
+    """
+    Create a new Dataproc Cluster.
+
+    :param cluster_name: The name of the DataProc cluster to create. (templated)
+    :type cluster_name: str
+    :param project_id: The ID of the google cloud project in which
+        to create the cluster. (templated)
+    :type project_id: str
+    :param num_workers: The # of workers to spin up.
If set to zero will + spin up cluster in a single node mode + :type num_workers: int + :param storage_bucket: The storage bucket to use, setting to None lets dataproc + generate a custom one for you + :type storage_bucket: str + :param init_actions_uris: List of GCS uri's containing + dataproc initialization scripts + :type init_actions_uris: list[str] + :param init_action_timeout: Amount of time executable scripts in + init_actions_uris has to complete + :type init_action_timeout: str + :param metadata: dict of key-value google compute engine metadata entries + to add to all instances + :type metadata: dict + :param image_version: the version of software inside the Dataproc cluster + :type image_version: str + :param custom_image: custom Dataproc image for more info see + https://cloud.google.com/dataproc/docs/guides/dataproc-images + :type custom_image: str + :param custom_image_project_id: project id for the custom Dataproc image, for more info see + https://cloud.google.com/dataproc/docs/guides/dataproc-images + :type custom_image_project_id: str + :param autoscaling_policy: The autoscaling policy used by the cluster. Only resource names + including projectid and location (region) are valid. Example: + ``projects/[projectId]/locations/[dataproc_region]/autoscalingPolicies/[policy_id]`` + :type autoscaling_policy: str + :param properties: dict of properties to set on + config files (e.g. spark-defaults.conf), see + https://cloud.google.com/dataproc/docs/reference/rest/v1/projects.regions.clusters#SoftwareConfig + :type properties: dict + :param optional_components: List of optional cluster components, for more info see + https://cloud.google.com/dataproc/docs/reference/rest/v1/ClusterConfig#Component + :type optional_components: list[str] + :param num_masters: The # of master nodes to spin up + :type num_masters: int + :param master_machine_type: Compute engine machine type to use for the master node + :type master_machine_type: str + :param master_disk_type: Type of the boot disk for the master node + (default is ``pd-standard``). + Valid values: ``pd-ssd`` (Persistent Disk Solid State Drive) or + ``pd-standard`` (Persistent Disk Hard Disk Drive). + :type master_disk_type: str + :param master_disk_size: Disk size for the master node + :type master_disk_size: int + :param worker_machine_type: Compute engine machine type to use for the worker nodes + :type worker_machine_type: str + :param worker_disk_type: Type of the boot disk for the worker node + (default is ``pd-standard``). + Valid values: ``pd-ssd`` (Persistent Disk Solid State Drive) or + ``pd-standard`` (Persistent Disk Hard Disk Drive). + :type worker_disk_type: str + :param worker_disk_size: Disk size for the worker nodes + :type worker_disk_size: int + :param num_preemptible_workers: The # of preemptible worker nodes to spin up + :type num_preemptible_workers: int + :param labels: dict of labels to add to the cluster + :type labels: dict + :param zone: The zone where the cluster will be located. Set to None to auto-zone. (templated) + :type zone: str + :param network_uri: The network uri to be used for machine communication, cannot be + specified with subnetwork_uri + :type network_uri: str + :param subnetwork_uri: The subnetwork uri to be used for machine communication, + cannot be specified with network_uri + :type subnetwork_uri: str + :param internal_ip_only: If true, all instances in the cluster will only + have internal IP addresses. 
This can only be enabled for subnetwork + enabled networks + :type internal_ip_only: bool + :param tags: The GCE tags to add to all instances + :type tags: list[str] + :param region: The specified region where the dataproc cluster is created. + :type region: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param service_account: The service account of the dataproc instances. + :type service_account: str + :param service_account_scopes: The URIs of service account scopes to be included. + :type service_account_scopes: list[str] + :param idle_delete_ttl: The longest duration that cluster would keep alive while + staying idle. Passing this threshold will cause cluster to be auto-deleted. + A duration in seconds. + :type idle_delete_ttl: int + :param auto_delete_time: The time when cluster will be auto-deleted. + :type auto_delete_time: datetime.datetime + :param auto_delete_ttl: The life duration of cluster, the cluster will be + auto-deleted at the end of this duration. + A duration in seconds. (If auto_delete_time is set this parameter will be ignored) + :type auto_delete_ttl: int + :param customer_managed_key: The customer-managed key used for disk encryption + ``projects/[PROJECT_STORING_KEYS]/locations/[LOCATION]/keyRings/[KEY_RING_NAME]/cryptoKeys/[KEY_NAME]`` # noqa # pylint: disable=line-too-long + :type customer_managed_key: str + """ + + # pylint: disable=too-many-arguments,too-many-locals + def __init__( + self, + project_id: str, + num_workers: Optional[int] = None, + zone: Optional[str] = None, + network_uri: Optional[str] = None, + subnetwork_uri: Optional[str] = None, + internal_ip_only: Optional[bool] = None, + tags: Optional[List[str]] = None, + storage_bucket: Optional[str] = None, + init_actions_uris: Optional[List[str]] = None, + init_action_timeout: str = "10m", + metadata: Optional[Dict] = None, + custom_image: Optional[str] = None, + custom_image_project_id: Optional[str] = None, + image_version: Optional[str] = None, + autoscaling_policy: Optional[str] = None, + properties: Optional[Dict] = None, + optional_components: Optional[List[str]] = None, + num_masters: int = 1, + master_machine_type: str = "n1-standard-4", + master_disk_type: str = "pd-standard", + master_disk_size: int = 1024, + worker_machine_type: str = "n1-standard-4", + worker_disk_type: str = "pd-standard", + worker_disk_size: int = 1024, + num_preemptible_workers: int = 0, + service_account: Optional[str] = None, + service_account_scopes: Optional[List[str]] = None, + idle_delete_ttl: Optional[int] = None, + auto_delete_time: Optional[datetime] = None, + auto_delete_ttl: Optional[int] = None, + customer_managed_key: Optional[str] = None, + **kwargs, + ) -> None: + + self.project_id = project_id + self.num_masters = num_masters + self.num_workers = num_workers + self.num_preemptible_workers = num_preemptible_workers + self.storage_bucket = storage_bucket + self.init_actions_uris = init_actions_uris + self.init_action_timeout = init_action_timeout + self.metadata = metadata + self.custom_image = custom_image + self.custom_image_project_id = custom_image_project_id + self.image_version = image_version + self.properties = properties or {} + self.optional_components = optional_components + self.master_machine_type = master_machine_type + self.master_disk_type = master_disk_type + self.master_disk_size = master_disk_size + self.autoscaling_policy = autoscaling_policy + self.worker_machine_type = worker_machine_type + self.worker_disk_type = 
worker_disk_type + self.worker_disk_size = worker_disk_size + self.zone = zone + self.network_uri = network_uri + self.subnetwork_uri = subnetwork_uri + self.internal_ip_only = internal_ip_only + self.tags = tags + self.service_account = service_account + self.service_account_scopes = service_account_scopes + self.idle_delete_ttl = idle_delete_ttl + self.auto_delete_time = auto_delete_time + self.auto_delete_ttl = auto_delete_ttl + self.customer_managed_key = customer_managed_key + self.single_node = num_workers == 0 + + if self.custom_image and self.image_version: + raise ValueError("The custom_image and image_version can't be both set") + + if self.single_node and self.num_preemptible_workers > 0: + raise ValueError("Single node cannot have preemptible workers.") + + def _get_init_action_timeout(self) -> dict: + match = re.match(r"^(\d+)([sm])$", self.init_action_timeout) + if match: + val = float(match.group(1)) + if match.group(2) == "s": + return {"seconds": int(val)} + elif match.group(2) == "m": + return {"seconds": int(timedelta(minutes=val).total_seconds())} + + raise AirflowException( + "DataprocClusterCreateOperator init_action_timeout" + " should be expressed in minutes or seconds. i.e. 10m, 30s" + ) + + def _build_gce_cluster_config(self, cluster_data): + if self.zone: + zone_uri = ( + "https://www.googleapis.com/compute/v1/projects/{}/zones/{}".format( + self.project_id, self.zone + ) + ) + cluster_data["gce_cluster_config"]["zone_uri"] = zone_uri + + if self.metadata: + cluster_data["gce_cluster_config"]["metadata"] = self.metadata + + if self.network_uri: + cluster_data["gce_cluster_config"]["network_uri"] = self.network_uri + + if self.subnetwork_uri: + cluster_data["gce_cluster_config"]["subnetwork_uri"] = self.subnetwork_uri + + if self.internal_ip_only: + if not self.subnetwork_uri: + raise AirflowException( + "Set internal_ip_only to true only when you pass a subnetwork_uri." 
+ ) + cluster_data["gce_cluster_config"]["internal_ip_only"] = True + + if self.tags: + cluster_data["gce_cluster_config"]["tags"] = self.tags + + if self.service_account: + cluster_data["gce_cluster_config"]["service_account"] = self.service_account + + if self.service_account_scopes: + cluster_data["gce_cluster_config"][ + "service_account_scopes" + ] = self.service_account_scopes + + return cluster_data + + def _build_lifecycle_config(self, cluster_data): + if self.idle_delete_ttl: + cluster_data["lifecycle_config"]["idle_delete_ttl"] = { + "seconds": self.idle_delete_ttl + } + + if self.auto_delete_time: + utc_auto_delete_time = timezone.convert_to_utc(self.auto_delete_time) + cluster_data["lifecycle_config"][ + "auto_delete_time" + ] = utc_auto_delete_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + elif self.auto_delete_ttl: + cluster_data["lifecycle_config"]["auto_delete_ttl"] = { + "seconds": int(self.auto_delete_ttl) + } + + return cluster_data + + def _build_cluster_data(self): + if self.zone: + master_type_uri = f"projects/{self.project_id}/zones/{self.zone}/machineTypes/{self.master_machine_type}" + worker_type_uri = f"projects/{self.project_id}/zones/{self.zone}/machineTypes/{self.worker_machine_type}" + else: + master_type_uri = self.master_machine_type + worker_type_uri = self.worker_machine_type + + cluster_data = { + "gce_cluster_config": {}, + "master_config": { + "num_instances": self.num_masters, + "machine_type_uri": master_type_uri, + "disk_config": { + "boot_disk_type": self.master_disk_type, + "boot_disk_size_gb": self.master_disk_size, + }, + }, + "worker_config": { + "num_instances": self.num_workers, + "machine_type_uri": worker_type_uri, + "disk_config": { + "boot_disk_type": self.worker_disk_type, + "boot_disk_size_gb": self.worker_disk_size, + }, + }, + "secondary_worker_config": {}, + "software_config": {}, + "lifecycle_config": {}, + "encryption_config": {}, + "autoscaling_config": {}, + } + if self.num_preemptible_workers > 0: + cluster_data["secondary_worker_config"] = { + "num_instances": self.num_preemptible_workers, + "machine_type_uri": worker_type_uri, + "disk_config": { + "boot_disk_type": self.worker_disk_type, + "boot_disk_size_gb": self.worker_disk_size, + }, + "is_preemptible": True, + } + + if self.storage_bucket: + cluster_data["config_bucket"] = self.storage_bucket + + if self.image_version: + cluster_data["software_config"]["image_version"] = self.image_version + + elif self.custom_image: + project_id = self.custom_image_project_id or self.project_id + custom_image_url = ( + "https://www.googleapis.com/compute/beta/projects/" + "{}/global/images/{}".format(project_id, self.custom_image) + ) + cluster_data["master_config"]["image_uri"] = custom_image_url + if not self.single_node: + cluster_data["worker_config"]["image_uri"] = custom_image_url + + cluster_data = self._build_gce_cluster_config(cluster_data) + + if self.single_node: + self.properties["dataproc:dataproc.allow.zero.workers"] = "true" + + if self.properties: + cluster_data["software_config"]["properties"] = self.properties + + if self.optional_components: + cluster_data["software_config"][ + "optional_components" + ] = self.optional_components + + cluster_data = self._build_lifecycle_config(cluster_data) + + if self.init_actions_uris: + init_actions_dict = [ + { + "executable_file": uri, + "execution_timeout": self._get_init_action_timeout(), + } + for uri in self.init_actions_uris + ] + cluster_data["initialization_actions"] = init_actions_dict + + if self.customer_managed_key: + 
            cluster_data["encryption_config"] = {
+                "gce_pd_kms_key_name": self.customer_managed_key
+            }
+        if self.autoscaling_policy:
+            cluster_data["autoscaling_config"] = {"policy_uri": self.autoscaling_policy}
+
+        return cluster_data
+
+    def make(self):
+        """
+        Helper method for easier migration.
+        :return: Dict representing Dataproc cluster.
+        """
+        return self._build_cluster_data()
+
+
+# pylint: disable=too-many-instance-attributes
+class DataprocCreateClusterOperator(BaseOperator):
+    """
+    Create a new cluster on Google Cloud Dataproc. The operator will wait until the
+    creation is successful or an error occurs in the creation process. If the cluster
+    already exists and ``use_if_exists`` is True then the operator will:
+
+    - if cluster state is ERROR then delete it if specified and raise error
+    - if cluster state is CREATING wait for it and then check for ERROR state
+    - if cluster state is DELETING wait for it and then create new cluster
+
+    Please refer to
+
+    https://cloud.google.com/dataproc/docs/reference/rest/v1/projects.regions.clusters
+
+    for a detailed explanation on the different parameters. Most of the configuration
+    parameters detailed in the link are available as a parameter to this operator.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DataprocCreateClusterOperator`
+
+    :param project_id: The ID of the google cloud project in which
+        to create the cluster. (templated)
+    :type project_id: str
+    :param cluster_name: Name of the cluster to create
+    :type cluster_name: str
+    :param labels: Labels that will be assigned to created cluster
+    :type labels: Dict[str, str]
+    :param cluster_config: Required. The cluster config to create.
+        If a dict is provided, it must be of the same form as the protobuf message
+        :class:`~google.cloud.dataproc_v1.types.ClusterConfig`
+    :type cluster_config: Union[Dict, google.cloud.dataproc_v1.types.ClusterConfig]
+    :param region: The specified region where the dataproc cluster is created.
+    :type region: str
+    :param delete_on_error: If true the cluster will be deleted if created with ERROR state. Default
+        value is true.
+    :type delete_on_error: bool
+    :param use_if_exists: If true use existing cluster
+    :type use_if_exists: bool
+    :param request_id: Optional. A unique id used to identify the request. If the server receives two
+        ``CreateClusterRequest`` requests with the same id, then the second request will be ignored and the
+        first ``google.longrunning.Operation`` created and stored in the backend is returned.
+    :type request_id: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "project_id", + "region", + "cluster_config", + "cluster_name", + "labels", + "impersonation_chain", + ) + template_fields_renderers = {"cluster_config": "json"} + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + cluster_name: str, + region: Optional[str] = None, + project_id: Optional[str] = None, + cluster_config: Optional[Dict] = None, + labels: Optional[Dict] = None, + request_id: Optional[str] = None, + delete_on_error: bool = True, + use_if_exists: bool = True, + retry: Optional[Retry] = None, + timeout: float = 1 * 60 * 60, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + if region is None: + warnings.warn( + "Default region value `global` will be deprecated. Please, provide region value.", + DeprecationWarning, + stacklevel=2, + ) + region = "global" + + # TODO: remove one day + if cluster_config is None: + warnings.warn( + "Passing cluster parameters by keywords to `{}` " + "will be deprecated. Please provide cluster_config object using `cluster_config` parameter. " + "You can use `airflow.dataproc.ClusterGenerator.generate_cluster` method to " + "obtain cluster object.".format(type(self).__name__), + DeprecationWarning, + stacklevel=1, + ) + # Remove result of apply defaults + if "params" in kwargs: + del kwargs["params"] + + # Create cluster object from kwargs + if project_id is None: + raise AirflowException( + "project_id argument is required when building cluster from keywords parameters" + ) + kwargs["project_id"] = project_id + cluster_config = ClusterGenerator(**kwargs).make() + + # Remove from kwargs cluster params passed for backward compatibility + cluster_params = inspect.signature(ClusterGenerator.__init__).parameters + for arg in cluster_params: + if arg in kwargs: + del kwargs[arg] + + super().__init__(**kwargs) + + self.cluster_config = cluster_config + self.cluster_name = cluster_name + self.labels = labels + self.project_id = project_id + self.region = region + self.request_id = request_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delete_on_error = delete_on_error + self.use_if_exists = use_if_exists + self.impersonation_chain = impersonation_chain + + def _create_cluster(self, hook: DataprocHook): + operation = hook.create_cluster( + project_id=self.project_id, + region=self.region, + cluster_name=self.cluster_name, + labels=self.labels, + cluster_config=self.cluster_config, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + cluster = operation.result() + self.log.info("Cluster created.") + return cluster + + def _delete_cluster(self, hook): + self.log.info("Deleting the cluster") + hook.delete_cluster( + region=self.region, + cluster_name=self.cluster_name, + project_id=self.project_id, + ) + + def _get_cluster(self, hook: DataprocHook) -> Cluster: + return hook.get_cluster( + project_id=self.project_id, + region=self.region, + cluster_name=self.cluster_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + def 
_handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None: + if cluster.status.state != cluster.status.State.ERROR: + return + self.log.info("Cluster is in ERROR state") + gcs_uri = hook.diagnose_cluster( + region=self.region, + cluster_name=self.cluster_name, + project_id=self.project_id, + ) + self.log.info( + "Diagnostic information for cluster %s available at: %s", + self.cluster_name, + gcs_uri, + ) + if self.delete_on_error: + self._delete_cluster(hook) + raise AirflowException("Cluster was created but was in ERROR state.") + raise AirflowException("Cluster was created but is in ERROR state") + + def _wait_for_cluster_in_deleting_state(self, hook: DataprocHook) -> None: + time_left = self.timeout + for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120): + if time_left < 0: + raise AirflowException( + f"Cluster {self.cluster_name} is still DELETING state, aborting" + ) + time.sleep(time_to_sleep) + time_left = time_left - time_to_sleep + try: + self._get_cluster(hook) + except NotFound: + break + + def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster: + time_left = self.timeout + cluster = self._get_cluster(hook) + for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120): + if cluster.status.state != cluster.status.State.CREATING: + break + if time_left < 0: + raise AirflowException( + f"Cluster {self.cluster_name} is still CREATING state, aborting" + ) + time.sleep(time_to_sleep) + time_left = time_left - time_to_sleep + cluster = self._get_cluster(hook) + return cluster + + def execute(self, context) -> dict: + self.log.info("Creating cluster: %s", self.cluster_name) + hook = DataprocHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + try: + # First try to create a new cluster + cluster = self._create_cluster(hook) + except AlreadyExists: + if not self.use_if_exists: + raise + self.log.info("Cluster already exists.") + cluster = self._get_cluster(hook) + + # Check if cluster is not in ERROR state + self._handle_error_state(hook, cluster) + if cluster.status.state == cluster.status.State.CREATING: + # Wait for cluster to be created + cluster = self._wait_for_cluster_in_creating_state(hook) + self._handle_error_state(hook, cluster) + elif cluster.status.state == cluster.status.State.DELETING: + # Wait for cluster to be deleted + self._wait_for_cluster_in_deleting_state(hook) + # Create new cluster + cluster = self._create_cluster(hook) + self._handle_error_state(hook, cluster) + + return Cluster.to_dict(cluster) + + +class DataprocScaleClusterOperator(BaseOperator): + """ + Scale, up or down, a cluster on Google Cloud Dataproc. + The operator will wait until the cluster is re-scaled. + + **Example**: :: + + t1 = DataprocClusterScaleOperator( + task_id='dataproc_scale', + project_id='my-project', + cluster_name='cluster-1', + num_workers=10, + num_preemptible_workers=10, + graceful_decommission_timeout='1h', + dag=dag) + + .. seealso:: + For more detail on about scaling clusters have a look at the reference: + https://cloud.google.com/dataproc/docs/concepts/configuring-clusters/scaling-clusters + + :param cluster_name: The name of the cluster to scale. (templated) + :type cluster_name: str + :param project_id: The ID of the google cloud project in which + the cluster runs. (templated) + :type project_id: str + :param region: The region for the dataproc cluster. 
(templated) + :type region: str + :param num_workers: The new number of workers + :type num_workers: int + :param num_preemptible_workers: The new number of preemptible workers + :type num_preemptible_workers: int + :param graceful_decommission_timeout: Timeout for graceful YARN decommissioning. + Maximum value is 1d + :type graceful_decommission_timeout: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ["cluster_name", "project_id", "region", "impersonation_chain"] + + @apply_defaults + def __init__( + self, + *, + cluster_name: str, + project_id: Optional[str] = None, + region: str = "global", + num_workers: int = 2, + num_preemptible_workers: int = 0, + graceful_decommission_timeout: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.region = region + self.cluster_name = cluster_name + self.num_workers = num_workers + self.num_preemptible_workers = num_preemptible_workers + self.graceful_decommission_timeout = graceful_decommission_timeout + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + # TODO: Remove one day + warnings.warn( + "The `{cls}` operator is deprecated, please use `DataprocUpdateClusterOperator` instead.".format( + cls=type(self).__name__ + ), + DeprecationWarning, + stacklevel=1, + ) + + def _build_scale_cluster_data(self) -> dict: + scale_data = { + "config": { + "worker_config": {"num_instances": self.num_workers}, + "secondary_worker_config": { + "num_instances": self.num_preemptible_workers + }, + } + } + return scale_data + + @property + def _graceful_decommission_timeout_object(self) -> Optional[Dict[str, int]]: + if not self.graceful_decommission_timeout: + return None + + timeout = None + match = re.match(r"^(\d+)([smdh])$", self.graceful_decommission_timeout) + if match: + if match.group(2) == "s": + timeout = int(match.group(1)) + elif match.group(2) == "m": + val = float(match.group(1)) + timeout = int(timedelta(minutes=val).total_seconds()) + elif match.group(2) == "h": + val = float(match.group(1)) + timeout = int(timedelta(hours=val).total_seconds()) + elif match.group(2) == "d": + val = float(match.group(1)) + timeout = int(timedelta(days=val).total_seconds()) + + if not timeout: + raise AirflowException( + "DataprocClusterScaleOperator " + " should be expressed in day, hours, minutes or seconds. " + " i.e. 
1d, 4h, 10m, 30s" + ) + + return {"seconds": timeout} + + def execute(self, context) -> None: + """Scale, up or down, a cluster on Google Cloud Dataproc.""" + self.log.info("Scaling cluster: %s", self.cluster_name) + + scaling_cluster_data = self._build_scale_cluster_data() + update_mask = [ + "config.worker_config.num_instances", + "config.secondary_worker_config.num_instances", + ] + + hook = DataprocHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + operation = hook.update_cluster( + project_id=self.project_id, + location=self.region, + cluster_name=self.cluster_name, + cluster=scaling_cluster_data, + graceful_decommission_timeout=self._graceful_decommission_timeout_object, + update_mask={"paths": update_mask}, + ) + operation.result() + self.log.info("Cluster scaling finished") + + +class DataprocDeleteClusterOperator(BaseOperator): + """ + Deletes a cluster in a project. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated). + :type project_id: str + :param region: Required. The Cloud Dataproc region in which to handle the request (templated). + :type region: str + :param cluster_name: Required. The cluster name (templated). + :type cluster_name: str + :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail + if cluster with specified UUID does not exist. + :type cluster_uuid: str + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``DeleteClusterRequest`` requests with the same id, then the second request will be ignored and the + first ``google.longrunning.Operation`` created and stored in the backend is returned. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("project_id", "region", "cluster_name", "impersonation_chain") + + @apply_defaults + def __init__( + self, + *, + project_id: str, + region: str, + cluster_name: str, + cluster_uuid: Optional[str] = None, + request_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.project_id = project_id + self.region = region + self.cluster_name = cluster_name + self.cluster_uuid = cluster_uuid + self.request_id = request_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = DataprocHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Deleting cluster: %s", self.cluster_name) + operation = hook.delete_cluster( + project_id=self.project_id, + region=self.region, + cluster_name=self.cluster_name, + cluster_uuid=self.cluster_uuid, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + operation.result() + self.log.info("Cluster deleted.") + + +class DataprocJobBaseOperator(BaseOperator): + """ + The base class for operators that launch job on DataProc. + + :param job_name: The job name used in the DataProc cluster. This name by default + is the task_id appended with the execution data, but can be templated. The + name will always be appended with a random number to avoid name clashes. + :type job_name: str + :param cluster_name: The name of the DataProc cluster. + :type cluster_name: str + :param dataproc_properties: Map for the Hive properties. Ideal to put in + default arguments (templated) + :type dataproc_properties: dict + :param dataproc_jars: HCFS URIs of jar files to add to the CLASSPATH of the Hive server and Hadoop + MapReduce (MR) tasks. Can contain Hive SerDes and UDFs. (templated) + :type dataproc_jars: list + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param labels: The labels to associate with this job. Label keys must contain 1 to 63 characters, + and must conform to RFC 1035. Label values may be empty, but, if present, must contain 1 to 63 + characters, and must conform to RFC 1035. No more than 32 labels can be associated with a job. + :type labels: dict + :param region: The specified region where the dataproc cluster is created. + :type region: str + :param job_error_states: Job states that should be considered error states. + Any states in this set will result in an error being raised and failure of the + task. Eg, if the ``CANCELLED`` state should also be considered a task failure, + pass in ``{'ERROR', 'CANCELLED'}``. Possible values are currently only + ``'ERROR'`` and ``'CANCELLED'``, but could change in the future. Defaults to + ``{'ERROR'}``. 
+ :type job_error_states: set + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + :param asynchronous: Flag to return after submitting the job to the Dataproc API. + This is useful for submitting long running jobs and + waiting on them asynchronously using the DataprocJobSensor + :type asynchronous: bool + + :var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API. + This is useful for identifying or linking to the job in the Google Cloud Console + Dataproc UI, as the actual "jobId" submitted to the Dataproc API is appended with + an 8 character random string. + :vartype dataproc_job_id: str + """ + + job_type = "" + + @apply_defaults + def __init__( + self, + *, + job_name: str = "{{task.task_id}}_{{ds_nodash}}", + cluster_name: str = "cluster-1", + dataproc_properties: Optional[Dict] = None, + dataproc_jars: Optional[List[str]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + labels: Optional[Dict] = None, + region: Optional[str] = None, + job_error_states: Optional[Set[str]] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + asynchronous: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.labels = labels + self.job_name = job_name + self.cluster_name = cluster_name + self.dataproc_properties = dataproc_properties + self.dataproc_jars = dataproc_jars + + if region is None: + warnings.warn( + "Default region value `global` will be deprecated. 
Please, provide region value.", + DeprecationWarning, + stacklevel=2, + ) + region = "global" + self.region = region + + self.job_error_states = ( + job_error_states if job_error_states is not None else {"ERROR"} + ) + self.impersonation_chain = impersonation_chain + + self.hook = DataprocHook( + gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain + ) + self.project_id = self.hook.project_id + self.job_template = None + self.job = None + self.dataproc_job_id = None + self.asynchronous = asynchronous + + def create_job_template(self): + """Initialize `self.job_template` with default values""" + self.job_template = DataProcJobBuilder( + project_id=self.project_id, + task_id=self.task_id, + cluster_name=self.cluster_name, + job_type=self.job_type, + properties=self.dataproc_properties, + ) + self.job_template.set_job_name(self.job_name) + self.job_template.add_jar_file_uris(self.dataproc_jars) + self.job_template.add_labels(self.labels) + + def _generate_job_template(self) -> str: + if self.job_template: + job = self.job_template.build() + return job["job"] + raise Exception("Create a job template before") + + def execute(self, context): + if self.job_template: + self.job = self.job_template.build() + self.dataproc_job_id = self.job["job"]["reference"]["job_id"] + self.log.info("Submitting %s job %s", self.job_type, self.dataproc_job_id) + job_object = self.hook.submit_job( + project_id=self.project_id, job=self.job["job"], location=self.region + ) + job_id = job_object.reference.job_id + self.log.info("Job %s submitted successfully.", job_id) + + if not self.asynchronous: + self.log.info("Waiting for job %s to complete", job_id) + self.hook.wait_for_job( + job_id=job_id, location=self.region, project_id=self.project_id + ) + self.log.info("Job %s completed successfully.", job_id) + return job_id + else: + raise AirflowException("Create a job template before") + + def on_kill(self) -> None: + """ + Callback called when the operator is killed. + Cancel any running job. + """ + if self.dataproc_job_id: + self.hook.cancel_job( + project_id=self.project_id, + job_id=self.dataproc_job_id, + location=self.region, + ) + + +class DataprocSubmitPigJobOperator(DataprocJobBaseOperator): + """ + Start a Pig query Job on a Cloud DataProc cluster. The parameters of the operation + will be passed to the cluster. + + It's a good practice to define dataproc_* parameters in the default_args of the dag + like the cluster name and UDFs. + + .. code-block:: python + + default_args = { + 'cluster_name': 'cluster-1', + 'dataproc_pig_jars': [ + 'gs://example/udf/jar/datafu/1.2.0/datafu.jar', + 'gs://example/udf/jar/gpig/1.2/gpig.jar' + ] + } + + You can pass a pig script as string or file reference. Use variables to pass on + variables for the pig script to be resolved on the cluster or use the parameters to + be resolved in the script as template parameters. + + **Example**: :: + + t1 = DataProcPigOperator( + task_id='dataproc_pig', + query='a_pig_script.pig', + variables={'out': 'gs://example/output/{{ds}}'}, + dag=dag) + + .. seealso:: + For more detail on about job submission have a look at the reference: + https://cloud.google.com/dataproc/reference/rest/v1/projects.regions.jobs + + :param query: The query or reference to the query + file (pg or pig extension). (templated) + :type query: str + :param query_uri: The HCFS URI of the script that contains the Pig queries. + :type query_uri: str + :param variables: Map of named parameters for the query. 
(templated) + :type variables: dict + """ + + template_fields = [ + "query", + "variables", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", + ] + template_ext = (".pg", ".pig") + ui_color = "#0273d4" + job_type = "pig_job" + + @apply_defaults + def __init__( + self, + *, + query: Optional[str] = None, + query_uri: Optional[str] = None, + variables: Optional[Dict] = None, + **kwargs, + ) -> None: + # TODO: Remove one day + warnings.warn( + "The `{cls}` operator is deprecated, please use `DataprocSubmitJobOperator` instead. You can use" + " `generate_job` method of `{cls}` to generate dictionary representing your job" + " and use it with the new operator.".format(cls=type(self).__name__), + DeprecationWarning, + stacklevel=1, + ) + + super().__init__(**kwargs) + self.query = query + self.query_uri = query_uri + self.variables = variables + + def generate_job(self): + """ + Helper method for easier migration to `DataprocSubmitJobOperator`. + :return: Dict representing Dataproc job + """ + self.create_job_template() + + if self.query is None: + self.job_template.add_query_uri(self.query_uri) + else: + self.job_template.add_query(self.query) + self.job_template.add_variables(self.variables) + return self._generate_job_template() + + def execute(self, context): + self.create_job_template() + + if self.query is None: + self.job_template.add_query_uri(self.query_uri) + else: + self.job_template.add_query(self.query) + self.job_template.add_variables(self.variables) + + super().execute(context) + + +class DataprocSubmitHiveJobOperator(DataprocJobBaseOperator): + """ + Start a Hive query Job on a Cloud DataProc cluster. + + :param query: The query or reference to the query file (q extension). + :type query: str + :param query_uri: The HCFS URI of the script that contains the Hive queries. + :type query_uri: str + :param variables: Map of named parameters for the query. + :type variables: dict + """ + + template_fields = [ + "query", + "variables", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", + ] + template_ext = (".q", ".hql") + ui_color = "#0273d4" + job_type = "hive_job" + + @apply_defaults + def __init__( + self, + *, + query: Optional[str] = None, + query_uri: Optional[str] = None, + variables: Optional[Dict] = None, + **kwargs, + ) -> None: + # TODO: Remove one day + warnings.warn( + "The `{cls}` operator is deprecated, please use `DataprocSubmitJobOperator` instead. You can use" + " `generate_job` method of `{cls}` to generate dictionary representing your job" + " and use it with the new operator.".format(cls=type(self).__name__), + DeprecationWarning, + stacklevel=1, + ) + + super().__init__(**kwargs) + self.query = query + self.query_uri = query_uri + self.variables = variables + if self.query is not None and self.query_uri is not None: + raise AirflowException("Only one of `query` and `query_uri` can be passed.") + + def generate_job(self): + """ + Helper method for easier migration to `DataprocSubmitJobOperator`. 
+ :return: Dict representing Dataproc job + """ + self.create_job_template() + if self.query is None: + self.job_template.add_query_uri(self.query_uri) + else: + self.job_template.add_query(self.query) + self.job_template.add_variables(self.variables) + return self._generate_job_template() + + def execute(self, context): + self.create_job_template() + if self.query is None: + self.job_template.add_query_uri(self.query_uri) + else: + self.job_template.add_query(self.query) + self.job_template.add_variables(self.variables) + + super().execute(context) + + +class DataprocSubmitSparkSqlJobOperator(DataprocJobBaseOperator): + """ + Start a Spark SQL query Job on a Cloud DataProc cluster. + + :param query: The query or reference to the query file (q extension). (templated) + :type query: str + :param query_uri: The HCFS URI of the script that contains the SQL queries. + :type query_uri: str + :param variables: Map of named parameters for the query. (templated) + :type variables: dict + """ + + template_fields = [ + "query", + "variables", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", + ] + template_ext = (".q",) + ui_color = "#0273d4" + job_type = "spark_sql_job" + + @apply_defaults + def __init__( + self, + *, + query: Optional[str] = None, + query_uri: Optional[str] = None, + variables: Optional[Dict] = None, + **kwargs, + ) -> None: + # TODO: Remove one day + warnings.warn( + "The `{cls}` operator is deprecated, please use `DataprocSubmitJobOperator` instead. You can use" + " `generate_job` method of `{cls}` to generate dictionary representing your job" + " and use it with the new operator.".format(cls=type(self).__name__), + DeprecationWarning, + stacklevel=1, + ) + + super().__init__(**kwargs) + self.query = query + self.query_uri = query_uri + self.variables = variables + if self.query is not None and self.query_uri is not None: + raise AirflowException("Only one of `query` and `query_uri` can be passed.") + + def generate_job(self): + """ + Helper method for easier migration to `DataprocSubmitJobOperator`. + :return: Dict representing Dataproc job + """ + self.create_job_template() + if self.query is None: + self.job_template.add_query_uri(self.query_uri) + else: + self.job_template.add_query(self.query) + self.job_template.add_variables(self.variables) + return self._generate_job_template() + + def execute(self, context): + self.create_job_template() + if self.query is None: + self.job_template.add_query_uri(self.query_uri) + else: + self.job_template.add_query(self.query) + self.job_template.add_variables(self.variables) + + super().execute(context) + + +class DataprocSubmitSparkJobOperator(DataprocJobBaseOperator): + """ + Start a Spark Job on a Cloud DataProc cluster. + + :param main_jar: The HCFS URI of the jar file that contains the main class + (use this or the main_class, not both together). + :type main_jar: str + :param main_class: Name of the job class. (use this or the main_jar, not both + together). + :type main_class: str + :param arguments: Arguments for the job. (templated) + :type arguments: list + :param archives: List of archived files that will be unpacked in the work + directory. Should be stored in Cloud Storage. 
+ :type archives: list + :param files: List of files to be copied to the working directory + :type files: list + """ + + template_fields = [ + "arguments", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", + ] + ui_color = "#0273d4" + job_type = "spark_job" + + @apply_defaults + def __init__( + self, + *, + main_jar: Optional[str] = None, + main_class: Optional[str] = None, + arguments: Optional[List] = None, + archives: Optional[List] = None, + files: Optional[List] = None, + **kwargs, + ) -> None: + # TODO: Remove one day + warnings.warn( + "The `{cls}` operator is deprecated, please use `DataprocSubmitJobOperator` instead. You can use" + " `generate_job` method of `{cls}` to generate dictionary representing your job" + " and use it with the new operator.".format(cls=type(self).__name__), + DeprecationWarning, + stacklevel=1, + ) + + super().__init__(**kwargs) + self.main_jar = main_jar + self.main_class = main_class + self.arguments = arguments + self.archives = archives + self.files = files + + def generate_job(self): + """ + Helper method for easier migration to `DataprocSubmitJobOperator`. + :return: Dict representing Dataproc job + """ + self.create_job_template() + self.job_template.set_main(self.main_jar, self.main_class) + self.job_template.add_args(self.arguments) + self.job_template.add_archive_uris(self.archives) + self.job_template.add_file_uris(self.files) + return self._generate_job_template() + + def execute(self, context): + self.create_job_template() + self.job_template.set_main(self.main_jar, self.main_class) + self.job_template.add_args(self.arguments) + self.job_template.add_archive_uris(self.archives) + self.job_template.add_file_uris(self.files) + + super().execute(context) + + +class DataprocSubmitHadoopJobOperator(DataprocJobBaseOperator): + """ + Start a Hadoop Job on a Cloud DataProc cluster. + + :param main_jar: The HCFS URI of the jar file containing the main class + (use this or the main_class, not both together). + :type main_jar: str + :param main_class: Name of the job class. (use this or the main_jar, not both + together). + :type main_class: str + :param arguments: Arguments for the job. (templated) + :type arguments: list + :param archives: List of archived files that will be unpacked in the work + directory. Should be stored in Cloud Storage. + :type archives: list + :param files: List of files to be copied to the working directory + :type files: list + """ + + template_fields = [ + "arguments", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", + ] + ui_color = "#0273d4" + job_type = "hadoop_job" + + @apply_defaults + def __init__( + self, + *, + main_jar: Optional[str] = None, + main_class: Optional[str] = None, + arguments: Optional[List] = None, + archives: Optional[List] = None, + files: Optional[List] = None, + **kwargs, + ) -> None: + # TODO: Remove one day + warnings.warn( + "The `{cls}` operator is deprecated, please use `DataprocSubmitJobOperator` instead. 
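# Illustrative only (editor-added): the classic SparkPi submission through the
# deprecated DataprocSubmitSparkJobOperator defined above, assuming the same
# hypothetical cluster and DAG context as the previous sketch.
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitSparkJobOperator

spark_job = DataprocSubmitSparkJobOperator(
    task_id="spark_job",
    main_class="org.apache.spark.examples.SparkPi",
    dataproc_jars=["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
    arguments=["1000"],
    cluster_name="example-cluster",
    region="us-central1",
)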
You can use" + " `generate_job` method of `{cls}` to generate dictionary representing your job" + " and use it with the new operator.".format(cls=type(self).__name__), + DeprecationWarning, + stacklevel=1, + ) + + super().__init__(**kwargs) + self.main_jar = main_jar + self.main_class = main_class + self.arguments = arguments + self.archives = archives + self.files = files + + def generate_job(self): + """ + Helper method for easier migration to `DataprocSubmitJobOperator`. + :return: Dict representing Dataproc job + """ + self.create_job_template() + self.job_template.set_main(self.main_jar, self.main_class) + self.job_template.add_args(self.arguments) + self.job_template.add_archive_uris(self.archives) + self.job_template.add_file_uris(self.files) + return self._generate_job_template() + + def execute(self, context): + self.create_job_template() + self.job_template.set_main(self.main_jar, self.main_class) + self.job_template.add_args(self.arguments) + self.job_template.add_archive_uris(self.archives) + self.job_template.add_file_uris(self.files) + + super().execute(context) + + +class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator): + """ + Start a PySpark Job on a Cloud DataProc cluster. + + :param main: [Required] The Hadoop Compatible Filesystem (HCFS) URI of the main + Python file to use as the driver. Must be a .py file. (templated) + :type main: str + :param arguments: Arguments for the job. (templated) + :type arguments: list + :param archives: List of archived files that will be unpacked in the work + directory. Should be stored in Cloud Storage. + :type archives: list + :param files: List of files to be copied to the working directory + :type files: list + :param pyfiles: List of Python files to pass to the PySpark framework. + Supported file types: .py, .egg, and .zip + :type pyfiles: list + """ + + template_fields = [ + "main", + "arguments", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", + ] + ui_color = "#0273d4" + job_type = "pyspark_job" + + @staticmethod + def _generate_temp_filename(filename): + date = time.strftime("%Y%m%d%H%M%S") + return f"{date}_{str(uuid.uuid4())[:8]}_{ntpath.basename(filename)}" + + def _upload_file_temp(self, bucket, local_file): + """Upload a local file to a Google Cloud Storage bucket.""" + temp_filename = self._generate_temp_filename(local_file) + if not bucket: + raise AirflowException( + "If you want Airflow to upload the local file to a temporary bucket, set " + "the 'temp_bucket' key in the connection string" + ) + + self.log.info("Uploading %s to %s", local_file, temp_filename) + + GCSHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ).upload( + bucket_name=bucket, + object_name=temp_filename, + mime_type="application/x-python", + filename=local_file, + ) + return f"gs://{bucket}/{temp_filename}" + + @apply_defaults + def __init__( + self, + *, + main: str, + arguments: Optional[List] = None, + archives: Optional[List] = None, + pyfiles: Optional[List] = None, + files: Optional[List] = None, + **kwargs, + ) -> None: + # TODO: Remove one day + warnings.warn( + "The `{cls}` operator is deprecated, please use `DataprocSubmitJobOperator` instead. 
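# Illustrative only (editor-added): a Hadoop MapReduce wordcount submitted with
# the deprecated DataprocSubmitHadoopJobOperator defined above; the bucket
# paths and cluster name are hypothetical, same DAG context as the earlier
# sketches.
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitHadoopJobOperator

hadoop_job = DataprocSubmitHadoopJobOperator(
    task_id="hadoop_job",
    main_jar="file:///usr/lib/hadoop-mapreduce/hadoop-mapreduce-examples.jar",
    arguments=["wordcount", "gs://example-bucket/input/", "gs://example-bucket/output/"],
    cluster_name="example-cluster",
    region="us-central1",
)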
You can use" + " `generate_job` method of `{cls}` to generate dictionary representing your job" + " and use it with the new operator.".format(cls=type(self).__name__), + DeprecationWarning, + stacklevel=1, + ) + + super().__init__(**kwargs) + self.main = main + self.arguments = arguments + self.archives = archives + self.files = files + self.pyfiles = pyfiles + + def generate_job(self): + """ + Helper method for easier migration to `DataprocSubmitJobOperator`. + :return: Dict representing Dataproc job + """ + self.create_job_template() + # Check if the file is local, if that is the case, upload it to a bucket + if os.path.isfile(self.main): + cluster_info = self.hook.get_cluster( + project_id=self.hook.project_id, + region=self.region, + cluster_name=self.cluster_name, + ) + bucket = cluster_info["config"]["config_bucket"] + self.main = f"gs://{bucket}/{self.main}" + self.job_template.set_python_main(self.main) + self.job_template.add_args(self.arguments) + self.job_template.add_archive_uris(self.archives) + self.job_template.add_file_uris(self.files) + self.job_template.add_python_file_uris(self.pyfiles) + + return self._generate_job_template() + + def execute(self, context): + self.create_job_template() + # Check if the file is local, if that is the case, upload it to a bucket + if os.path.isfile(self.main): + cluster_info = self.hook.get_cluster( + project_id=self.hook.project_id, + region=self.region, + cluster_name=self.cluster_name, + ) + bucket = cluster_info["config"]["config_bucket"] + self.main = self._upload_file_temp(bucket, self.main) + + self.job_template.set_python_main(self.main) + self.job_template.add_args(self.arguments) + self.job_template.add_archive_uris(self.archives) + self.job_template.add_file_uris(self.files) + self.job_template.add_python_file_uris(self.pyfiles) + + super().execute(context) + + +class DataprocCreateWorkflowTemplateOperator(BaseOperator): + """ + Creates new workflow template. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. + :type location: str + :param template: The Dataproc workflow template to create. If a dict is provided, + it must be of the same form as the protobuf message WorkflowTemplate. + :type template: Union[dict, WorkflowTemplate] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
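# Illustrative only (editor-added): submitting a PySpark driver with the
# deprecated DataprocSubmitPySparkJobOperator defined above. If `main` points
# at a local file, execute() uploads it to the cluster's config bucket first.
# All URIs are hypothetical; same DAG context as the earlier sketches.
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitPySparkJobOperator

pyspark_job = DataprocSubmitPySparkJobOperator(
    task_id="pyspark_job",
    main="gs://example-bucket/jobs/hello_world.py",
    arguments=["--run-date", "{{ ds }}"],                 # `arguments` is templated
    pyfiles=["gs://example-bucket/jobs/helpers.zip"],
    cluster_name="example-cluster",
    region="us-central1",
)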
+ :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("location", "template") + template_fields_renderers = {"template": "json"} + + def __init__( + self, + *, + location: str, + template: Dict, + project_id: str, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.location = location + self.template = template + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = DataprocHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Creating template") + try: + workflow = hook.create_workflow_template( + location=self.location, + template=self.template, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Workflow %s created", workflow.name) + except AlreadyExists: + self.log.info("Workflow with given id already exists") + + +class DataprocInstantiateWorkflowTemplateOperator(BaseOperator): + """ + Instantiate a WorkflowTemplate on Google Cloud Dataproc. The operator will wait + until the WorkflowTemplate is finished executing. + + .. seealso:: + Please refer to: + https://cloud.google.com/dataproc/docs/reference/rest/v1beta2/projects.regions.workflowTemplates/instantiate + + :param template_id: The id of the template. (templated) + :type template_id: str + :param project_id: The ID of the google cloud project in which + the template runs + :type project_id: str + :param region: The specified region where the dataproc cluster is created. + :type region: str + :param parameters: a map of parameters for Dataproc Template in key-value format: + map (key: string, value: string) + Example: { "date_from": "2019-08-01", "date_to": "2019-08-02"}. + Values may not exceed 100 characters. Please refer to: + https://cloud.google.com/dataproc/docs/concepts/workflows/workflow-parameters + :type parameters: Dict[str, str] + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``SubmitJobRequest`` requests with the same id, then the second request will be ignored and the first + ``Job`` created and stored in the backend is returned. + It is recommended to always set this value to a UUID. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
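# Illustrative only (editor-added): creating a workflow template with the
# DataprocCreateWorkflowTemplateOperator defined above. The template dict is a
# hypothetical, minimal WorkflowTemplate body (managed cluster plus one Hive
# step); same DAG context as the earlier sketches.
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateWorkflowTemplateOperator

WORKFLOW_TEMPLATE = {
    "id": "example-workflow",
    "placement": {"managed_cluster": {"cluster_name": "example-workflow-cluster", "config": {}}},
    "jobs": [{"step_id": "show_databases", "hive_job": {"query_list": {"queries": ["SHOW DATABASES;"]}}}],
}

create_workflow_template = DataprocCreateWorkflowTemplateOperator(
    task_id="create_workflow_template",
    project_id="example-project",
    location="us-central1",
    template=WORKFLOW_TEMPLATE,
)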
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ["template_id", "impersonation_chain", "request_id", "parameters"] + template_fields_renderers = {"parameters": "json"} + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + template_id: str, + region: str, + project_id: Optional[str] = None, + version: Optional[int] = None, + request_id: Optional[str] = None, + parameters: Optional[Dict[str, str]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.template_id = template_id + self.parameters = parameters + self.version = version + self.project_id = project_id + self.region = region + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.request_id = request_id + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = DataprocHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Instantiating template %s", self.template_id) + operation = hook.instantiate_workflow_template( + project_id=self.project_id, + location=self.region, + template_name=self.template_id, + version=self.version, + request_id=self.request_id, + parameters=self.parameters, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + operation.result() + self.log.info("Template instantiated.") + + +class DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator): + """ + Instantiate a WorkflowTemplate Inline on Google Cloud Dataproc. The operator will + wait until the WorkflowTemplate is finished executing. + + .. seealso:: + Please refer to: + https://cloud.google.com/dataproc/docs/reference/rest/v1beta2/projects.regions.workflowTemplates/instantiateInline + + :param template: The template contents. (templated) + :type template: dict + :param project_id: The ID of the google cloud project in which + the template runs + :type project_id: str + :param region: The specified region where the dataproc cluster is created. + :type region: str + :param parameters: a map of parameters for Dataproc Template in key-value format: + map (key: string, value: string) + Example: { "date_from": "2019-08-01", "date_to": "2019-08-02"}. + Values may not exceed 100 characters. Please refer to: + https://cloud.google.com/dataproc/docs/concepts/workflows/workflow-parameters + :type parameters: Dict[str, str] + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``SubmitJobRequest`` requests with the same id, then the second request will be ignored and the first + ``Job`` created and stored in the backend is returned. + It is recommended to always set this value to a UUID. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
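# Illustrative only (editor-added): running a stored template such as the one
# created in the previous sketch with DataprocInstantiateWorkflowTemplateOperator.
# The `parameters` mapping only applies if the template declares those
# parameters; all IDs below are hypothetical.
from airflow.providers.google.cloud.operators.dataproc import DataprocInstantiateWorkflowTemplateOperator

instantiate_workflow = DataprocInstantiateWorkflowTemplateOperator(
    task_id="instantiate_workflow",
    template_id="example-workflow",
    project_id="example-project",
    region="us-central1",
    parameters={"date_from": "2019-08-01", "date_to": "2019-08-02"},
)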
Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ["template", "impersonation_chain"] + template_fields_renderers = {"template": "json"} + + @apply_defaults + def __init__( + self, + *, + template: Dict, + region: str, + project_id: Optional[str] = None, + request_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.template = template + self.project_id = project_id + self.location = region + self.template = template + self.request_id = request_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + self.log.info("Instantiating Inline Template") + hook = DataprocHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + operation = hook.instantiate_inline_workflow_template( + template=self.template, + project_id=self.project_id, + location=self.location, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + operation.result() + self.log.info("Template instantiated.") + + +class DataprocSubmitJobOperator(BaseOperator): + """ + Submits a job to a cluster. + + :param project_id: Required. The ID of the Google Cloud project that the job belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. + :type location: str + :param job: Required. The job resource. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.dataproc_v1beta2.types.Job` + :type job: Dict + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``SubmitJobRequest`` requests with the same id, then the second request will be ignored and the first + ``Job`` created and stored in the backend is returned. + It is recommended to always set this value to a UUID. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. 
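# Illustrative only (editor-added): the inline variant skips the stored
# template and takes a full WorkflowTemplate dict directly; WORKFLOW_TEMPLATE
# here refers to the hypothetical dict from the create-template sketch above.
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocInstantiateInlineWorkflowTemplateOperator,
)

instantiate_inline_workflow = DataprocInstantiateInlineWorkflowTemplateOperator(
    task_id="instantiate_inline_workflow",
    template=WORKFLOW_TEMPLATE,   # same dict shape as in the create-template sketch
    project_id="example-project",
    region="us-central1",
)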
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    :param asynchronous: Flag to return after submitting the job to the Dataproc API.
+        This is useful for submitting long running jobs and
+        waiting on them asynchronously using the DataprocJobSensor
+    :type asynchronous: bool
+    :param cancel_on_kill: Flag which indicates whether to cancel the hook's job when ``on_kill`` is called
+    :type cancel_on_kill: bool
+    :param wait_timeout: How many seconds to wait for the job to be ready. Used only if ``asynchronous`` is False
+    :type wait_timeout: int
+    """
+
+    template_fields = (
+        "project_id",
+        "location",
+        "job",
+        "impersonation_chain",
+        "request_id",
+    )
+    template_fields_renderers = {"job": "json"}
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        job: Dict,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        asynchronous: bool = False,
+        cancel_on_kill: bool = True,
+        wait_timeout: Optional[int] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.job = job
+        self.request_id = request_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.asynchronous = asynchronous
+        self.cancel_on_kill = cancel_on_kill
+        self.hook: Optional[DataprocHook] = None
+        self.job_id: Optional[str] = None
+        self.wait_timeout = wait_timeout
+
+    def execute(self, context: Dict):
+        self.log.info("Submitting job")
+        self.hook = DataprocHook(
+            gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
+        )
+        job_object = self.hook.submit_job(
+            project_id=self.project_id,
+            location=self.location,
+            job=self.job,
+            request_id=self.request_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        job_id = job_object.reference.job_id
+        self.log.info("Job %s submitted successfully.", job_id)
+
+        if not self.asynchronous:
+            self.log.info("Waiting for job %s to complete", job_id)
+            self.hook.wait_for_job(
+                job_id=job_id,
+                location=self.location,
+                project_id=self.project_id,
+                timeout=self.wait_timeout,
+            )
+            self.log.info("Job %s completed successfully.", job_id)
+
+        self.job_id = job_id
+        return self.job_id
+
+    def on_kill(self):
+        if self.job_id and self.cancel_on_kill:
+            self.hook.cancel_job(
+                job_id=self.job_id, project_id=self.project_id, location=self.location
+            )
+
+
+class DataprocUpdateClusterOperator(BaseOperator):
+    """
+    Updates a cluster in a project.
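# ---------------------------------------------------------------------------
# Illustrative usage sketch (editor-added): the non-deprecated
# DataprocSubmitJobOperator defined above takes a Dataproc `Job` dict instead
# of per-job-type constructor arguments. Project, cluster and bucket names are
# hypothetical placeholders.
# ---------------------------------------------------------------------------
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator

PYSPARK_JOB = {
    "reference": {"project_id": "example-project"},
    "placement": {"cluster_name": "example-cluster"},
    "pyspark_job": {"main_python_file_uri": "gs://example-bucket/jobs/hello_world.py"},
}

with DAG(
    dag_id="example_dataproc_submit_job",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    submit_pyspark = DataprocSubmitJobOperator(
        task_id="submit_pyspark",
        project_id="example-project",
        location="us-central1",
        job=PYSPARK_JOB,
        # asynchronous=True would return right after submission; the job can
        # then be tracked separately (e.g. with DataprocJobSensor).
    )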
+ + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. + :type location: str + :param cluster_name: Required. The cluster name. + :type cluster_name: str + :param cluster: Required. The changes to the cluster. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.dataproc_v1beta2.types.Cluster` + :type cluster: Union[Dict, google.cloud.dataproc_v1beta2.types.Cluster] + :param update_mask: Required. Specifies the path, relative to ``Cluster``, of the field to update. For + example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter would be + specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body would specify the + new value. If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + :param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. Graceful + decommissioning allows removing nodes from the cluster without interrupting jobs in progress. Timeout + specifies how long to wait for jobs in progress to finish before forcefully removing nodes (and + potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the maximum + allowed timeout is 1 day. + :type graceful_decommission_timeout: Union[Dict, google.protobuf.duration_pb2.Duration] + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and the + first ``google.longrunning.Operation`` created and stored in the backend is returned. + :type request_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("impersonation_chain", "cluster_name") + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + location: str, + cluster_name: str, + cluster: Union[Dict, Cluster], + update_mask: Union[Dict, FieldMask], + graceful_decommission_timeout: Union[Dict, Duration], + request_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Retry = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.cluster_name = cluster_name + self.cluster = cluster + self.update_mask = update_mask + self.graceful_decommission_timeout = graceful_decommission_timeout + self.request_id = request_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = DataprocHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Updating %s cluster.", self.cluster_name) + operation = hook.update_cluster( + project_id=self.project_id, + location=self.location, + cluster_name=self.cluster_name, + cluster=self.cluster, + update_mask=self.update_mask, + graceful_decommission_timeout=self.graceful_decommission_timeout, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + operation.result() + self.log.info("Updated %s cluster.", self.cluster_name) diff --git a/reference/providers/google/cloud/operators/datastore.py b/reference/providers/google/cloud/operators/datastore.py new file mode 100644 index 0000000..ec7af06 --- /dev/null +++ b/reference/providers/google/cloud/operators/datastore.py @@ -0,0 +1,709 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains Google Datastore operators.""" +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.datastore import DatastoreHook +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults + + +class CloudDatastoreExportEntitiesOperator(BaseOperator): + """ + Export entities from Google Cloud Datastore to Cloud Storage + + .. 
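# Illustrative only (editor-added): scaling a cluster to five workers with the
# DataprocUpdateClusterOperator defined above in dataproc.py, mirroring the
# `update_mask` example from its docstring. Names are hypothetical; same DAG
# context as the earlier Dataproc sketches.
from airflow.providers.google.cloud.operators.dataproc import DataprocUpdateClusterOperator

scale_cluster = DataprocUpdateClusterOperator(
    task_id="scale_cluster",
    project_id="example-project",
    location="us-central1",
    cluster_name="example-cluster",
    cluster={"config": {"worker_config": {"num_instances": 5}}},
    update_mask={"paths": ["config.worker_config.num_instances"]},
    graceful_decommission_timeout={"seconds": 600},
)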
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDatastoreExportEntitiesOperator` + + :param bucket: name of the cloud storage bucket to backup data + :type bucket: str + :param namespace: optional namespace path in the specified Cloud Storage bucket + to backup data. If this namespace does not exist in GCS, it will be created. + :type namespace: str + :param datastore_conn_id: the name of the Datastore connection id to use + :type datastore_conn_id: str + :param cloud_storage_conn_id: the name of the cloud storage connection id to + force-write backup + :type cloud_storage_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param entity_filter: description of what data from the project is included in the + export, refer to + https://cloud.google.com/datastore/docs/reference/rest/Shared.Types/EntityFilter + :type entity_filter: dict + :param labels: client-assigned labels for cloud storage + :type labels: dict + :param polling_interval_in_seconds: number of seconds to wait before polling for + execution status again + :type polling_interval_in_seconds: int + :param overwrite_existing: if the storage bucket + namespace is not empty, it will be + emptied prior to exports. This enables overwriting existing backups. + :type overwrite_existing: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "bucket", + "namespace", + "entity_filter", + "labels", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, # pylint: disable=too-many-arguments + *, + bucket: str, + namespace: Optional[str] = None, + datastore_conn_id: str = "google_cloud_default", + cloud_storage_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + entity_filter: Optional[dict] = None, + labels: Optional[dict] = None, + polling_interval_in_seconds: int = 10, + overwrite_existing: bool = False, + project_id: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.datastore_conn_id = datastore_conn_id + self.cloud_storage_conn_id = cloud_storage_conn_id + self.delegate_to = delegate_to + self.bucket = bucket + self.namespace = namespace + self.entity_filter = entity_filter + self.labels = labels + self.polling_interval_in_seconds = polling_interval_in_seconds + self.overwrite_existing = overwrite_existing + self.project_id = project_id + self.impersonation_chain = impersonation_chain + if kwargs.get("xcom_push") is not None: + raise AirflowException( + "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead" + ) + + def execute(self, context) -> dict: + self.log.info("Exporting data to Cloud Storage bucket %s", self.bucket) + + if self.overwrite_existing and self.namespace: + gcs_hook = GCSHook( + self.cloud_storage_conn_id, impersonation_chain=self.impersonation_chain + ) + objects = gcs_hook.list(self.bucket, prefix=self.namespace) + for obj in objects: + gcs_hook.delete(self.bucket, obj) + + ds_hook = DatastoreHook( + self.datastore_conn_id, + self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + result = ds_hook.export_to_storage_bucket( + bucket=self.bucket, + namespace=self.namespace, + entity_filter=self.entity_filter, + labels=self.labels, + project_id=self.project_id, + ) + operation_name = result["name"] + result = ds_hook.poll_operation_until_done( + operation_name, self.polling_interval_in_seconds + ) + + state = result["metadata"]["common"]["state"] + if state != "SUCCESSFUL": + raise AirflowException(f"Operation failed: result={result}") + return result + + +class CloudDatastoreImportEntitiesOperator(BaseOperator): + """ + Import entities from Cloud Storage to Google Cloud Datastore + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDatastoreImportEntitiesOperator` + + :param bucket: container in Cloud Storage to store data + :type bucket: str + :param file: path of the backup metadata file in the specified Cloud Storage bucket. + It should have the extension .overall_export_metadata + :type file: str + :param namespace: optional namespace of the backup metadata file in + the specified Cloud Storage bucket. + :type namespace: str + :param entity_filter: description of what data from the project is included in + the export, refer to + https://cloud.google.com/datastore/docs/reference/rest/Shared.Types/EntityFilter + :type entity_filter: dict + :param labels: client-assigned labels for cloud storage + :type labels: dict + :param datastore_conn_id: the name of the connection id to use + :type datastore_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. 
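# ---------------------------------------------------------------------------
# Illustrative usage sketch (editor-added): exporting Datastore entities to a
# Cloud Storage bucket with the CloudDatastoreExportEntitiesOperator defined
# above. Project and bucket names are hypothetical placeholders; the import
# path assumes the upstream google provider package.
# ---------------------------------------------------------------------------
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.datastore import CloudDatastoreExportEntitiesOperator

with DAG(
    dag_id="example_datastore_backup",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    export_entities = CloudDatastoreExportEntitiesOperator(
        task_id="export_entities",
        project_id="example-project",
        bucket="example-backup-bucket",
        namespace="nightly/{{ ds }}",   # `namespace` is templated, so each run gets its own prefix
        overwrite_existing=True,
    )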
For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param polling_interval_in_seconds: number of seconds to wait before polling for + execution status again + :type polling_interval_in_seconds: float + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "bucket", + "file", + "namespace", + "entity_filter", + "labels", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + bucket: str, + file: str, + namespace: Optional[str] = None, + entity_filter: Optional[dict] = None, + labels: Optional[dict] = None, + datastore_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + polling_interval_in_seconds: float = 10, + project_id: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.datastore_conn_id = datastore_conn_id + self.delegate_to = delegate_to + self.bucket = bucket + self.file = file + self.namespace = namespace + self.entity_filter = entity_filter + self.labels = labels + self.polling_interval_in_seconds = polling_interval_in_seconds + self.project_id = project_id + self.impersonation_chain = impersonation_chain + if kwargs.get("xcom_push") is not None: + raise AirflowException( + "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead" + ) + + def execute(self, context): + self.log.info("Importing data from Cloud Storage bucket %s", self.bucket) + ds_hook = DatastoreHook( + self.datastore_conn_id, + self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + result = ds_hook.import_from_storage_bucket( + bucket=self.bucket, + file=self.file, + namespace=self.namespace, + entity_filter=self.entity_filter, + labels=self.labels, + project_id=self.project_id, + ) + operation_name = result["name"] + result = ds_hook.poll_operation_until_done( + operation_name, self.polling_interval_in_seconds + ) + + state = result["metadata"]["common"]["state"] + if state != "SUCCESSFUL": + raise AirflowException(f"Operation failed: result={result}") + + return result + + +class CloudDatastoreAllocateIdsOperator(BaseOperator): + """ + Allocate IDs for incomplete keys. Return list of keys. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDatastoreAllocateIdsOperator` + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/rest/v1/projects/allocateIds + + :param partial_keys: a list of partial keys. + :type partial_keys: list + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. 
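# Illustrative only (editor-added): restoring a backup with the
# CloudDatastoreImportEntitiesOperator defined above; the metadata file path is
# a hypothetical placeholder, same DAG context as the export sketch.
from airflow.providers.google.cloud.operators.datastore import CloudDatastoreImportEntitiesOperator

import_entities = CloudDatastoreImportEntitiesOperator(
    task_id="import_entities",
    project_id="example-project",
    bucket="example-backup-bucket",
    file="nightly/2022-07-28/example.overall_export_metadata",
)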
+ :type delegate_to: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "partial_keys", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + partial_keys: List, + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.partial_keys = partial_keys + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> list: + hook = DatastoreHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + keys = hook.allocate_ids( + partial_keys=self.partial_keys, + project_id=self.project_id, + ) + return keys + + +class CloudDatastoreBeginTransactionOperator(BaseOperator): + """ + Begins a new transaction. Returns a transaction handle. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDatastoreBeginTransactionOperator` + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/rest/v1/projects/beginTransaction + + :param transaction_options: Options for a new transaction. + :type transaction_options: Dict[str, Any] + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "transaction_options", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + transaction_options: Dict[str, Any], + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.transaction_options = transaction_options + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> str: + hook = DatastoreHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + handle = hook.begin_transaction( + transaction_options=self.transaction_options, + project_id=self.project_id, + ) + return handle + + +class CloudDatastoreCommitOperator(BaseOperator): + """ + Commit a transaction, optionally creating, deleting or modifying some entities. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDatastoreCommitOperator` + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/rest/v1/projects/commit + + :param body: the body of the commit request. + :type body: dict + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "body", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + body: Dict[str, Any], + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.body = body + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> dict: + hook = DatastoreHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + response = hook.commit( + body=self.body, + project_id=self.project_id, + ) + return response + + +class CloudDatastoreRollbackOperator(BaseOperator): + """ + Roll back a transaction. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDatastoreRollbackOperator` + + .. 
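# ---------------------------------------------------------------------------
# Illustrative usage sketch (editor-added): pairing
# CloudDatastoreBeginTransactionOperator with CloudDatastoreCommitOperator,
# both defined above. The transaction handle returned by the first task is
# pulled from XCom inside the templated `body` of the second. Project name,
# entity kind and property are hypothetical placeholders.
# ---------------------------------------------------------------------------
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.datastore import (
    CloudDatastoreBeginTransactionOperator,
    CloudDatastoreCommitOperator,
)

with DAG(
    dag_id="example_datastore_commit",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    begin_transaction = CloudDatastoreBeginTransactionOperator(
        task_id="begin_transaction",
        project_id="example-project",
        transaction_options={"readWrite": {}},
    )
    commit_transaction = CloudDatastoreCommitOperator(
        task_id="commit_transaction",
        project_id="example-project",
        body={
            "mode": "TRANSACTIONAL",
            "mutations": [
                {
                    "insert": {
                        "key": {
                            "partitionId": {"projectId": "example-project"},
                            "path": [{"kind": "Task"}],
                        },
                        "properties": {"done": {"booleanValue": False}},
                    }
                }
            ],
            # `body` is templated, so the handle from the previous task can be
            # pulled from XCom at render time.
            "transaction": "{{ task_instance.xcom_pull(task_ids='begin_transaction') }}",
        },
    )
    begin_transaction >> commit_transaction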
seealso:: + https://cloud.google.com/datastore/docs/reference/rest/v1/projects/rollback + + :param transaction: the transaction to roll back. + :type transaction: str + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "transaction", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + transaction: str, + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.transaction = transaction + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = DatastoreHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.rollback( + transaction=self.transaction, + project_id=self.project_id, + ) + + +class CloudDatastoreRunQueryOperator(BaseOperator): + """ + Run a query for entities. Returns the batch of query results. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDatastoreRunQueryOperator` + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/rest/v1/projects/runQuery + + :param body: the body of the query request. + :type body: dict + :param project_id: Google Cloud project ID against which to make the request. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "body", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + body: Dict[str, Any], + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.body = body + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> dict: + hook = DatastoreHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + response = hook.run_query( + body=self.body, + project_id=self.project_id, + ) + return response + + +class CloudDatastoreGetOperationOperator(BaseOperator): + """ + Gets the latest state of a long-running operation. + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects.operations/get + + :param name: the name of the operation resource. + :type name: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + name: str, + delegate_to: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.name = name + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = DatastoreHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + op = hook.get_operation(name=self.name) + return op + + +class CloudDatastoreDeleteOperationOperator(BaseOperator): + """ + Deletes the long-running operation. + + .. seealso:: + https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects.operations/delete + + :param name: the name of the operation resource. + :type name: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
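# Illustrative only (editor-added): running a GQL query with the
# CloudDatastoreRunQueryOperator defined above; the body follows the REST
# `runQuery` request shape and the kind name is hypothetical, same DAG context
# as the earlier Datastore sketches.
from airflow.providers.google.cloud.operators.datastore import CloudDatastoreRunQueryOperator

run_query = CloudDatastoreRunQueryOperator(
    task_id="run_query",
    project_id="example-project",
    body={
        "partitionId": {"projectId": "example-project"},
        "gqlQuery": {"queryString": "SELECT * FROM Task LIMIT 10", "allowLiterals": True},
    },
)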
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + name: str, + delegate_to: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.name = name + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = DatastoreHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.delete_operation(name=self.name) diff --git a/reference/providers/google/cloud/operators/dlp.py b/reference/providers/google/cloud/operators/dlp.py new file mode 100644 index 0000000..7270c7c --- /dev/null +++ b/reference/providers/google/cloud/operators/dlp.py @@ -0,0 +1,2899 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=R0913, C0302 +""" +This module contains various Google Cloud DLP operators +which allow you to perform basic operations using +Cloud DLP. +""" +from typing import Dict, Optional, Sequence, Tuple, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.dlp import CloudDLPHook +from airflow.utils.decorators import apply_defaults +from google.api_core.exceptions import AlreadyExists, InvalidArgument, NotFound +from google.api_core.retry import Retry +from google.cloud.dlp_v2.types import ( + ByteContentItem, + ContentItem, + DeidentifyConfig, + DeidentifyTemplate, + FieldMask, + InspectConfig, + InspectJobConfig, + InspectTemplate, + JobTrigger, + RedactImageRequest, + RiskAnalysisJobConfig, + StoredInfoTypeConfig, +) +from google.protobuf.json_format import MessageToDict + + +class CloudDLPCancelDLPJobOperator(BaseOperator): + """ + Starts asynchronous cancellation on a long-running DlpJob. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPCancelDLPJobOperator` + + :param dlp_job_id: ID of the DLP job resource to be cancelled. 
+ :type dlp_job_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default project_id + from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "dlp_job_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + dlp_job_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dlp_job_id = dlp_job_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.cancel_dlp_job( + dlp_job_id=self.dlp_job_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudDLPCreateDeidentifyTemplateOperator(BaseOperator): + """ + Creates a DeidentifyTemplate for re-using frequently used configuration for + de-identifying content, images, and storage. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPCreateDeidentifyTemplateOperator` + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param deidentify_template: (Optional) The DeidentifyTemplate to create. + :type deidentify_template: dict or google.cloud.dlp_v2.types.DeidentifyTemplate + :param template_id: (Optional) The template ID. + :type template_id: str + :param retry: (Optional) A retry object used to retry requests. 
+ If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate + """ + + template_fields = ( + "organization_id", + "project_id", + "deidentify_template", + "template_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + deidentify_template: Optional[Union[Dict, DeidentifyTemplate]] = None, + template_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.organization_id = organization_id + self.project_id = project_id + self.deidentify_template = deidentify_template + self.template_id = template_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + template = hook.create_deidentify_template( + organization_id=self.organization_id, + project_id=self.project_id, + deidentify_template=self.deidentify_template, + template_id=self.template_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + template = hook.get_deidentify_template( + organization_id=self.organization_id, + project_id=self.project_id, + template_id=self.template_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + return MessageToDict(template) + + +class CloudDLPCreateDLPJobOperator(BaseOperator): + """ + Creates a new job to inspect storage or calculate risk metrics. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPCreateDLPJobOperator` + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param inspect_job: (Optional) The configuration for the inspect job. 
+ :type inspect_job: dict or google.cloud.dlp_v2.types.InspectJobConfig + :param risk_job: (Optional) The configuration for the risk job. + :type risk_job: dict or google.cloud.dlp_v2.types.RiskAnalysisJobConfig + :param job_id: (Optional) The job ID. + :type job_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param wait_until_finished: (Optional) If true, it will keep polling the job state + until it is set to DONE. + :type wait_until_finished: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.DlpJob + """ + + template_fields = ( + "project_id", + "inspect_job", + "risk_job", + "job_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + project_id: Optional[str] = None, + inspect_job: Optional[Union[Dict, InspectJobConfig]] = None, + risk_job: Optional[Union[Dict, RiskAnalysisJobConfig]] = None, + job_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + wait_until_finished: bool = True, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.inspect_job = inspect_job + self.risk_job = risk_job + self.job_id = job_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.wait_until_finished = wait_until_finished + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + job = hook.create_dlp_job( + project_id=self.project_id, + inspect_job=self.inspect_job, + risk_job=self.risk_job, + job_id=self.job_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + wait_until_finished=self.wait_until_finished, + ) + except AlreadyExists: + job = hook.get_dlp_job( + project_id=self.project_id, + dlp_job_id=self.job_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(job) + + +class CloudDLPCreateInspectTemplateOperator(BaseOperator): + """ + Creates an InspectTemplate for re-using frequently used configuration for + inspecting content, images, and storage. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPCreateInspectTemplateOperator` + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param inspect_template: (Optional) The InspectTemplate to create. + :type inspect_template: dict or google.cloud.dlp_v2.types.InspectTemplate + :param template_id: (Optional) The template ID. + :type template_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: google.cloud.dlp_v2.types.InspectTemplate
+    """
+
+    template_fields = (
+        "organization_id",
+        "project_id",
+        "inspect_template",
+        "template_id",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        organization_id: Optional[str] = None,
+        project_id: Optional[str] = None,
+        inspect_template: Optional[Union[Dict, InspectTemplate]] = None,
+        template_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.organization_id = organization_id
+        self.project_id = project_id
+        self.inspect_template = inspect_template
+        self.template_id = template_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = CloudDLPHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        try:
+            template = hook.create_inspect_template(
+                organization_id=self.organization_id,
+                project_id=self.project_id,
+                inspect_template=self.inspect_template,
+                template_id=self.template_id,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            )
+        except AlreadyExists:
+            template = hook.get_inspect_template(
+                organization_id=self.organization_id,
+                project_id=self.project_id,
+                template_id=self.template_id,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            )
+        return MessageToDict(template)
+
+
+class CloudDLPCreateJobTriggerOperator(BaseOperator):
+    """
+    Creates a job trigger to run DLP actions such as scanning storage for sensitive
+    information on a set schedule.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudDLPCreateJobTriggerOperator`
+
+    :param project_id: (Optional) Google Cloud project ID where the
+        DLP Instance exists. If set to None or missing, the default
+        project_id from the Google Cloud connection is used.
+    :type project_id: str
+    :param job_trigger: (Optional) The JobTrigger to create.
+    :type job_trigger: dict or google.cloud.dlp_v2.types.JobTrigger
+    :param trigger_id: (Optional) The JobTrigger ID.
+    :type trigger_id: str
+    :param retry: (Optional) A retry object used to retry requests.
+        If None is specified, requests will not be retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: (Optional) The amount of time, in seconds, to wait for the request
+        to complete. Note that if retry is specified, the timeout applies to each
+        individual attempt.
+    :type timeout: float
+    :param metadata: (Optional) Additional metadata that is provided to the method.
+    :type metadata: sequence[tuple[str, str]]
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.JobTrigger + """ + + template_fields = ( + "project_id", + "job_trigger", + "trigger_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + project_id: Optional[str] = None, + job_trigger: Optional[Union[Dict, JobTrigger]] = None, + trigger_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.job_trigger = job_trigger + self.trigger_id = trigger_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + trigger = hook.create_job_trigger( + project_id=self.project_id, + job_trigger=self.job_trigger, + trigger_id=self.trigger_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except InvalidArgument as e: + if "already in use" not in e.message: + raise + trigger = hook.get_job_trigger( + project_id=self.project_id, + job_trigger_id=self.trigger_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(trigger) + + +class CloudDLPCreateStoredInfoTypeOperator(BaseOperator): + """ + Creates a pre-built stored infoType to be used for inspection. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPCreateStoredInfoTypeOperator` + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param config: (Optional) The config for the StoredInfoType. + :type config: dict or google.cloud.dlp_v2.types.StoredInfoTypeConfig + :param stored_info_type_id: (Optional) The StoredInfoType ID. + :type stored_info_type_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.StoredInfoType + """ + + template_fields = ( + "organization_id", + "project_id", + "config", + "stored_info_type_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + config: Optional[StoredInfoTypeConfig] = None, + stored_info_type_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.organization_id = organization_id + self.project_id = project_id + self.config = config + self.stored_info_type_id = stored_info_type_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + info = hook.create_stored_info_type( + organization_id=self.organization_id, + project_id=self.project_id, + config=self.config, + stored_info_type_id=self.stored_info_type_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except InvalidArgument as e: + if "already exists" not in e.message: + raise + info = hook.get_stored_info_type( + organization_id=self.organization_id, + project_id=self.project_id, + stored_info_type_id=self.stored_info_type_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(info) + + +class CloudDLPDeidentifyContentOperator(BaseOperator): + """ + De-identifies potentially sensitive info from a ContentItem. This method has limits + on input size and output size. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPDeidentifyContentOperator` + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param deidentify_config: (Optional) Configuration for the de-identification of the + content item. Items specified here will override the template referenced by the + deidentify_template_name argument. + :type deidentify_config: dict or google.cloud.dlp_v2.types.DeidentifyConfig + :param inspect_config: (Optional) Configuration for the inspector. Items specified + here will override the template referenced by the inspect_template_name argument. + :type inspect_config: dict or google.cloud.dlp_v2.types.InspectConfig + :param item: (Optional) The item to de-identify. Will be treated as text. + :type item: dict or google.cloud.dlp_v2.types.ContentItem + :param inspect_template_name: (Optional) Optional template to use. Any configuration + directly specified in inspect_config will override those set in the template. 
+ :type inspect_template_name: str + :param deidentify_template_name: (Optional) Optional template to use. Any + configuration directly specified in deidentify_config will override those set + in the template. + :type deidentify_template_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.DeidentifyContentResponse + """ + + template_fields = ( + "project_id", + "deidentify_config", + "inspect_config", + "item", + "inspect_template_name", + "deidentify_template_name", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + project_id: Optional[str] = None, + deidentify_config: Optional[Union[Dict, DeidentifyConfig]] = None, + inspect_config: Optional[Union[Dict, InspectConfig]] = None, + item: Optional[Union[Dict, ContentItem]] = None, + inspect_template_name: Optional[str] = None, + deidentify_template_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.deidentify_config = deidentify_config + self.inspect_config = inspect_config + self.item = item + self.inspect_template_name = inspect_template_name + self.deidentify_template_name = deidentify_template_name + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> dict: + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + response = hook.deidentify_content( + project_id=self.project_id, + deidentify_config=self.deidentify_config, + inspect_config=self.inspect_config, + item=self.item, + inspect_template_name=self.inspect_template_name, + deidentify_template_name=self.deidentify_template_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(response) + + +class CloudDLPDeleteDeidentifyTemplateOperator(BaseOperator): + """ + Deletes a DeidentifyTemplate. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPDeleteDeidentifyTemplateOperator` + + :param template_id: The ID of deidentify template to be deleted. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "template_id", + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.template_id = template_id + self.organization_id = organization_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + hook.delete_deidentify_template( + template_id=self.template_id, + organization_id=self.organization_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.error("Template %s not found.", self.template_id) + + +class CloudDLPDeleteDLPJobOperator(BaseOperator): + """ + Deletes a long-running DlpJob. This method indicates that the client is no longer + interested in the DlpJob result. The job will be cancelled if possible. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPDeleteDLPJobOperator` + + :param dlp_job_id: The ID of the DLP job resource to be cancelled. + :type dlp_job_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "dlp_job_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + dlp_job_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dlp_job_id = dlp_job_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + hook.delete_dlp_job( + dlp_job_id=self.dlp_job_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.error("Job %s id not found.", self.dlp_job_id) + + +class CloudDLPDeleteInspectTemplateOperator(BaseOperator): + """ + Deletes an InspectTemplate. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPDeleteInspectTemplateOperator` + + :param template_id: The ID of the inspect template to be deleted. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. 
+ :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "template_id", + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.template_id = template_id + self.organization_id = organization_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + hook.delete_inspect_template( + template_id=self.template_id, + organization_id=self.organization_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.error("Template %s not found", self.template_id) + + +class CloudDLPDeleteJobTriggerOperator(BaseOperator): + """ + Deletes a job trigger. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPDeleteJobTriggerOperator` + + :param job_trigger_id: The ID of the DLP job trigger to be deleted. + :type job_trigger_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
+ :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "job_trigger_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + job_trigger_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_trigger_id = job_trigger_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + hook.delete_job_trigger( + job_trigger_id=self.job_trigger_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.error("Trigger %s not found", self.job_trigger_id) + + +class CloudDLPDeleteStoredInfoTypeOperator(BaseOperator): + """ + Deletes a stored infoType. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPDeleteStoredInfoTypeOperator` + + :param stored_info_type_id: The ID of the stored info type to be deleted. + :type stored_info_type_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "stored_info_type_id", + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + stored_info_type_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.stored_info_type_id = stored_info_type_id + self.organization_id = organization_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + hook.delete_stored_info_type( + stored_info_type_id=self.stored_info_type_id, + organization_id=self.organization_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except NotFound: + self.log.error("Stored info %s not found", self.stored_info_type_id) + + +class CloudDLPGetDeidentifyTemplateOperator(BaseOperator): + """ + Gets a DeidentifyTemplate. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPGetDeidentifyTemplateOperator` + + :param template_id: The ID of deidentify template to be read. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate + """ + + template_fields = ( + "template_id", + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.template_id = template_id + self.organization_id = organization_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + template = hook.get_deidentify_template( + template_id=self.template_id, + organization_id=self.organization_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(template) + + +class CloudDLPGetDLPJobOperator(BaseOperator): + """ + Gets the latest state of a long-running DlpJob. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPGetDLPJobOperator` + + :param dlp_job_id: The ID of the DLP job resource to be read. + :type dlp_job_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.DlpJob + """ + + template_fields = ( + "dlp_job_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + dlp_job_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dlp_job_id = dlp_job_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + job = hook.get_dlp_job( + dlp_job_id=self.dlp_job_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(job) + + +class CloudDLPGetInspectTemplateOperator(BaseOperator): + """ + Gets an InspectTemplate. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPGetInspectTemplateOperator` + + :param template_id: The ID of inspect template to be read. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.InspectTemplate + """ + + template_fields = ( + "template_id", + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.template_id = template_id + self.organization_id = organization_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + template = hook.get_inspect_template( + template_id=self.template_id, + organization_id=self.organization_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(template) + + +class CloudDLPGetDLPJobTriggerOperator(BaseOperator): + """ + Gets a job trigger. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPGetDLPJobTriggerOperator` + + :param job_trigger_id: The ID of the DLP job trigger to be read. + :type job_trigger_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.JobTrigger + """ + + template_fields = ( + "job_trigger_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + job_trigger_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_trigger_id = job_trigger_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + trigger = hook.get_job_trigger( + job_trigger_id=self.job_trigger_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(trigger) + + +class CloudDLPGetStoredInfoTypeOperator(BaseOperator): + """ + Gets a stored infoType. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPGetStoredInfoTypeOperator` + + :param stored_info_type_id: The ID of the stored info type to be read. + :type stored_info_type_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: google.cloud.dlp_v2.types.StoredInfoType
+    """
+
+    template_fields = (
+        "stored_info_type_id",
+        "organization_id",
+        "project_id",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        stored_info_type_id: str,
+        organization_id: Optional[str] = None,
+        project_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.stored_info_type_id = stored_info_type_id
+        self.organization_id = organization_id
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = CloudDLPHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        info = hook.get_stored_info_type(
+            stored_info_type_id=self.stored_info_type_id,
+            organization_id=self.organization_id,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return MessageToDict(info)
+
+
+class CloudDLPInspectContentOperator(BaseOperator):
+    """
+    Finds potentially sensitive info in content. This method has limits on
+    input size, processing time, and output size.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudDLPInspectContentOperator`
+
+    :param project_id: (Optional) Google Cloud project ID where the
+        DLP Instance exists. If set to None or missing, the default
+        project_id from the Google Cloud connection is used.
+    :type project_id: str
+    :param inspect_config: (Optional) Configuration for the inspector. Items specified
+        here will override the template referenced by the inspect_template_name argument.
+    :type inspect_config: dict or google.cloud.dlp_v2.types.InspectConfig
+    :param item: (Optional) The item to inspect. Will be treated as text.
+    :type item: dict or google.cloud.dlp_v2.types.ContentItem
+    :param inspect_template_name: (Optional) Optional template to use. Any configuration
+        directly specified in inspect_config will override those set in the template.
+    :type inspect_template_name: str
+    :param retry: (Optional) A retry object used to retry requests.
+        If None is specified, requests will not be retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: (Optional) The amount of time, in seconds, to wait for the request
+        to complete. Note that if retry is specified, the timeout applies to each
+        individual attempt.
+    :type timeout: float
+    :param metadata: (Optional) Additional metadata that is provided to the method.
+    :type metadata: sequence[tuple[str, str]]
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: google.cloud.dlp_v2.types.InspectContentResponse
+    """
+
+    template_fields = (
+        "project_id",
+        "inspect_config",
+        "item",
+        "inspect_template_name",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        project_id: Optional[str] = None,
+        inspect_config: Optional[Union[Dict, InspectConfig]] = None,
+        item: Optional[Union[Dict, ContentItem]] = None,
+        inspect_template_name: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.inspect_config = inspect_config
+        self.item = item
+        self.inspect_template_name = inspect_template_name
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = CloudDLPHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        response = hook.inspect_content(
+            project_id=self.project_id,
+            inspect_config=self.inspect_config,
+            item=self.item,
+            inspect_template_name=self.inspect_template_name,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return MessageToDict(response)
+
+
+class CloudDLPListDeidentifyTemplatesOperator(BaseOperator):
+    """
+    Lists DeidentifyTemplates.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudDLPListDeidentifyTemplatesOperator`
+
+    :param organization_id: (Optional) The organization ID. Required to set this
+        field if parent resource is an organization.
+    :type organization_id: str
+    :param project_id: (Optional) Google Cloud project ID where the
+        DLP Instance exists. Only set this field if the parent resource is
+        a project instead of an organization.
+    :type project_id: str
+    :param page_size: (Optional) The maximum number of resources contained in the
+        underlying API response.
+    :type page_size: int
+    :param order_by: (Optional) Optional comma separated list of fields to order by,
+        followed by asc or desc postfix.
+    :type order_by: str
+    :param retry: (Optional) A retry object used to retry requests.
+        If None is specified, requests will not be retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: (Optional) The amount of time, in seconds, to wait for the request
+        to complete. Note that if retry is specified, the timeout applies to each
+        individual attempt.
+    :type timeout: float
+    :param metadata: (Optional) Additional metadata that is provided to the method.
+    :type metadata: sequence[tuple[str, str]]
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: list[google.cloud.dlp_v2.types.DeidentifyTemplate]
+    """
+
+    template_fields = (
+        "organization_id",
+        "project_id",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        organization_id: Optional[str] = None,
+        project_id: Optional[str] = None,
+        page_size: Optional[int] = None,
+        order_by: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.organization_id = organization_id
+        self.project_id = project_id
+        self.page_size = page_size
+        self.order_by = order_by
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = CloudDLPHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        templates = hook.list_deidentify_templates(
+            organization_id=self.organization_id,
+            project_id=self.project_id,
+            page_size=self.page_size,
+            order_by=self.order_by,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return [MessageToDict(t) for t in templates]
+
+
+class CloudDLPListDLPJobsOperator(BaseOperator):
+    """
+    Lists DlpJobs that match the specified filter in the request.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudDLPListDLPJobsOperator`
+
+    :param project_id: (Optional) Google Cloud project ID where the
+        DLP Instance exists. If set to None or missing, the default
+        project_id from the Google Cloud connection is used.
+    :type project_id: str
+    :param results_filter: (Optional) Filter used to specify a subset of results.
+    :type results_filter: str
+    :param page_size: (Optional) The maximum number of resources contained in the
+        underlying API response.
+    :type page_size: int
+    :param job_type: (Optional) The type of job.
+    :type job_type: str
+    :param order_by: (Optional) Comma-separated list of fields to order by,
+        followed by asc or desc postfix.
+    :type order_by: str
+    :param retry: (Optional) A retry object used to retry requests.
+        If None is specified, requests will not be retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: (Optional) The amount of time, in seconds, to wait for the request
+        to complete. Note that if retry is specified, the timeout applies to each
+        individual attempt.
+    :type timeout: float
+    :param metadata: (Optional) Additional metadata that is provided to the method.
+    :type metadata: sequence[tuple[str, str]]
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: list[google.cloud.dlp_v2.types.DlpJob]
+    """
+
+    template_fields = (
+        "project_id",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        project_id: Optional[str] = None,
+        results_filter: Optional[str] = None,
+        page_size: Optional[int] = None,
+        job_type: Optional[str] = None,
+        order_by: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.results_filter = results_filter
+        self.page_size = page_size
+        self.job_type = job_type
+        self.order_by = order_by
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = CloudDLPHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        jobs = hook.list_dlp_jobs(
+            project_id=self.project_id,
+            results_filter=self.results_filter,
+            page_size=self.page_size,
+            job_type=self.job_type,
+            order_by=self.order_by,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return [MessageToDict(j) for j in jobs]
+
+
+class CloudDLPListInfoTypesOperator(BaseOperator):
+    """
+    Returns a list of the sensitive information types that the DLP API supports.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudDLPListInfoTypesOperator`
+
+    :param language_code: (Optional) BCP-47 language code for localized infoType
+        friendly names. If omitted, or if localized strings are not available, en-US
+        strings will be returned.
+    :type language_code: str
+    :param results_filter: (Optional) Filter used to specify a subset of results.
+    :type results_filter: str
+    :param retry: (Optional) A retry object used to retry requests.
+        If None is specified, requests will not be retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: (Optional) The amount of time, in seconds, to wait for the request
+        to complete. Note that if retry is specified, the timeout applies to each
+        individual attempt.
+    :type timeout: float
+    :param metadata: (Optional) Additional metadata that is provided to the method.
+    :type metadata: sequence[tuple[str, str]]
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: ListInfoTypesResponse + """ + + template_fields = ( + "language_code", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + language_code: Optional[str] = None, + results_filter: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.language_code = language_code + self.results_filter = results_filter + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + response = hook.list_info_types( + language_code=self.language_code, + results_filter=self.results_filter, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(response) + + +class CloudDLPListInspectTemplatesOperator(BaseOperator): + """ + Lists InspectTemplates. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPListInspectTemplatesOperator` + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param order_by: (Optional) Optional comma separated list of fields to order by, + followed by asc or desc postfix. + :type order_by: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: list[google.cloud.dlp_v2.types.InspectTemplate] + """ + + template_fields = ( + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + page_size: Optional[int] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.organization_id = organization_id + self.project_id = project_id + self.page_size = page_size + self.order_by = order_by + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + templates = hook.list_inspect_templates( + organization_id=self.organization_id, + project_id=self.project_id, + page_size=self.page_size, + order_by=self.order_by, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return [MessageToDict(t) for t in templates] + + +class CloudDLPListJobTriggersOperator(BaseOperator): + """ + Lists job triggers. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPListJobTriggersOperator` + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param order_by: (Optional) Optional comma separated list of fields to order by, + followed by asc or desc postfix. + :type order_by: str + :param results_filter: (Optional) Filter used to specify a subset of results. + :type results_filter: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: list[google.cloud.dlp_v2.types.JobTrigger] + """ + + template_fields = ( + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + project_id: Optional[str] = None, + page_size: Optional[int] = None, + order_by: Optional[str] = None, + results_filter: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.page_size = page_size + self.order_by = order_by + self.results_filter = results_filter + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + jobs = hook.list_job_triggers( + project_id=self.project_id, + page_size=self.page_size, + order_by=self.order_by, + results_filter=self.results_filter, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return [MessageToDict(j) for j in jobs] + + +class CloudDLPListStoredInfoTypesOperator(BaseOperator): + """ + Lists stored infoTypes. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPListStoredInfoTypesOperator` + + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param order_by: (Optional) Optional comma separated list of fields to order by, + followed by asc or desc postfix. + :type order_by: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: list[google.cloud.dlp_v2.types.StoredInfoType] + """ + + template_fields = ( + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + page_size: Optional[int] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.organization_id = organization_id + self.project_id = project_id + self.page_size = page_size + self.order_by = order_by + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + infos = hook.list_stored_info_types( + organization_id=self.organization_id, + project_id=self.project_id, + page_size=self.page_size, + order_by=self.order_by, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return [MessageToDict(i) for i in infos] + + +class CloudDLPRedactImageOperator(BaseOperator): + """ + Redacts potentially sensitive info from an image. This method has limits on + input size, processing time, and output size. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPRedactImageOperator` + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param inspect_config: (Optional) Configuration for the inspector. Items specified + here will override the template referenced by the inspect_template_name argument. + :type inspect_config: dict or google.cloud.dlp_v2.types.InspectConfig + :param image_redaction_configs: (Optional) The configuration for specifying what + content to redact from images. + :type image_redaction_configs: list[dict] or + list[google.cloud.dlp_v2.types.RedactImageRequest.ImageRedactionConfig] + :param include_findings: (Optional) Whether the response should include findings + along with the redacted image. + :type include_findings: bool + :param byte_item: (Optional) The content must be PNG, JPEG, SVG or BMP. + :type byte_item: dict or google.cloud.dlp_v2.types.ByteContentItem + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.RedactImageResponse + """ + + template_fields = ( + "project_id", + "inspect_config", + "image_redaction_configs", + "include_findings", + "byte_item", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + project_id: Optional[str] = None, + inspect_config: Optional[Union[Dict, InspectConfig]] = None, + image_redaction_configs: Optional[ + Union[Dict, RedactImageRequest.ImageRedactionConfig] + ] = None, + include_findings: Optional[bool] = None, + byte_item: Optional[Union[Dict, ByteContentItem]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.inspect_config = inspect_config + self.image_redaction_configs = image_redaction_configs + self.include_findings = include_findings + self.byte_item = byte_item + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + response = hook.redact_image( + project_id=self.project_id, + inspect_config=self.inspect_config, + image_redaction_configs=self.image_redaction_configs, + include_findings=self.include_findings, + byte_item=self.byte_item, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(response) + + +class CloudDLPReidentifyContentOperator(BaseOperator): + """ + Re-identifies content that has been de-identified. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPReidentifyContentOperator` + + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param reidentify_config: (Optional) Configuration for the re-identification of + the content item. + :type reidentify_config: dict or google.cloud.dlp_v2.types.DeidentifyConfig + :param inspect_config: (Optional) Configuration for the inspector. + :type inspect_config: dict or google.cloud.dlp_v2.types.InspectConfig + :param item: (Optional) The item to re-identify. Will be treated as text. + :type item: dict or google.cloud.dlp_v2.types.ContentItem + :param inspect_template_name: (Optional) Optional template to use. Any configuration + directly specified in inspect_config will override those set in the template. 
+ :type inspect_template_name: str + :param reidentify_template_name: (Optional) Optional template to use. References an + instance of DeidentifyTemplate. Any configuration directly specified in + reidentify_config or inspect_config will override those set in the template. + :type reidentify_template_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.ReidentifyContentResponse + """ + + template_fields = ( + "project_id", + "reidentify_config", + "inspect_config", + "item", + "inspect_template_name", + "reidentify_template_name", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + project_id: Optional[str] = None, + reidentify_config: Optional[Union[Dict, DeidentifyConfig]] = None, + inspect_config: Optional[Union[Dict, InspectConfig]] = None, + item: Optional[Union[Dict, ContentItem]] = None, + inspect_template_name: Optional[str] = None, + reidentify_template_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.reidentify_config = reidentify_config + self.inspect_config = inspect_config + self.item = item + self.inspect_template_name = inspect_template_name + self.reidentify_template_name = reidentify_template_name + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + response = hook.reidentify_content( + project_id=self.project_id, + reidentify_config=self.reidentify_config, + inspect_config=self.inspect_config, + item=self.item, + inspect_template_name=self.inspect_template_name, + reidentify_template_name=self.reidentify_template_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(response) + + +class CloudDLPUpdateDeidentifyTemplateOperator(BaseOperator): + """ + Updates the DeidentifyTemplate. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPUpdateDeidentifyTemplateOperator` + + :param template_id: The ID of deidentify template to be updated. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param deidentify_template: New DeidentifyTemplate value. + :type deidentify_template: dict or google.cloud.dlp_v2.types.DeidentifyTemplate + :param update_mask: Mask to control which fields get updated. + :type update_mask: dict or google.cloud.dlp_v2.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate + """ + + template_fields = ( + "template_id", + "organization_id", + "project_id", + "deidentify_template", + "update_mask", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + deidentify_template: Optional[Union[Dict, DeidentifyTemplate]] = None, + update_mask: Optional[Union[Dict, FieldMask]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.template_id = template_id + self.organization_id = organization_id + self.project_id = project_id + self.deidentify_template = deidentify_template + self.update_mask = update_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + template = hook.update_deidentify_template( + template_id=self.template_id, + organization_id=self.organization_id, + project_id=self.project_id, + deidentify_template=self.deidentify_template, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(template) + + +class CloudDLPUpdateInspectTemplateOperator(BaseOperator): + """ + Updates the InspectTemplate. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPUpdateInspectTemplateOperator` + + :param template_id: The ID of the inspect template to be updated. + :type template_id: str + :param organization_id: (Optional) The organization ID. Required to set this + field if parent resource is an organization. + :type organization_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. Only set this field if the parent resource is + a project instead of an organization. + :type project_id: str + :param inspect_template: New InspectTemplate value. + :type inspect_template: dict or google.cloud.dlp_v2.types.InspectTemplate + :param update_mask: Mask to control which fields get updated. + :type update_mask: dict or google.cloud.dlp_v2.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.InspectTemplate + """ + + template_fields = ( + "template_id", + "organization_id", + "project_id", + "inspect_template", + "update_mask", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + template_id: str, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + inspect_template: Optional[Union[Dict, InspectTemplate]] = None, + update_mask: Optional[Union[Dict, FieldMask]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.template_id = template_id + self.organization_id = organization_id + self.project_id = project_id + self.inspect_template = inspect_template + self.update_mask = update_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + template = hook.update_inspect_template( + template_id=self.template_id, + organization_id=self.organization_id, + project_id=self.project_id, + inspect_template=self.inspect_template, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(template) + + +class CloudDLPUpdateJobTriggerOperator(BaseOperator): + """ + Updates a job trigger. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDLPUpdateJobTriggerOperator` + + :param job_trigger_id: The ID of the DLP job trigger to be updated. + :type job_trigger_id: str + :param project_id: (Optional) Google Cloud project ID where the + DLP Instance exists. If set to None or missing, the default + project_id from the Google Cloud connection is used. + :type project_id: str + :param job_trigger: New JobTrigger value. + :type job_trigger: dict or google.cloud.dlp_v2.types.JobTrigger + :param update_mask: Mask to control which fields get updated. + :type update_mask: dict or google.cloud.dlp_v2.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. 
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: google.cloud.dlp_v2.types.JobTrigger
+    """
+
+    template_fields = (
+        "job_trigger_id",
+        "project_id",
+        "job_trigger",
+        "update_mask",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        job_trigger_id: str,
+        project_id: Optional[str] = None,
+        job_trigger: Optional[Union[Dict, JobTrigger]] = None,
+        update_mask: Optional[Union[Dict, FieldMask]] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.job_trigger_id = job_trigger_id
+        self.project_id = project_id
+        self.job_trigger = job_trigger
+        self.update_mask = update_mask
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = CloudDLPHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        trigger = hook.update_job_trigger(
+            job_trigger_id=self.job_trigger_id,
+            project_id=self.project_id,
+            job_trigger=self.job_trigger,
+            update_mask=self.update_mask,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return MessageToDict(trigger)
+
+
+class CloudDLPUpdateStoredInfoTypeOperator(BaseOperator):
+    """
+    Updates the stored infoType by creating a new version.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudDLPUpdateStoredInfoTypeOperator`
+
+    :param stored_info_type_id: The ID of the stored info type to be updated.
+    :type stored_info_type_id: str
+    :param organization_id: (Optional) The organization ID. Required to set this
+        field if parent resource is an organization.
+    :type organization_id: str
+    :param project_id: (Optional) Google Cloud project ID where the
+        DLP Instance exists. Only set this field if the parent resource is
+        a project instead of an organization.
+    :type project_id: str
+    :param config: Updated configuration for the storedInfoType. If not provided, a new
+        version of the storedInfoType will be created with the existing configuration.
+    :type config: dict or google.cloud.dlp_v2.types.StoredInfoTypeConfig
+    :param update_mask: Mask to control which fields get updated.
+    :type update_mask: dict or google.cloud.dlp_v2.types.FieldMask
+    :param retry: (Optional) A retry object used to retry requests.
+        If None is specified, requests will not be retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: (Optional) The amount of time, in seconds, to wait for the request
+        to complete. Note that if retry is specified, the timeout applies to each
+        individual attempt.
+ :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.dlp_v2.types.StoredInfoType + """ + + template_fields = ( + "stored_info_type_id", + "organization_id", + "project_id", + "config", + "update_mask", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + stored_info_type_id, + organization_id: Optional[str] = None, + project_id: Optional[str] = None, + config: Optional[Union[Dict, StoredInfoTypeConfig]] = None, + update_mask: Optional[Union[Dict, FieldMask]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.stored_info_type_id = stored_info_type_id + self.organization_id = organization_id + self.project_id = project_id + self.config = config + self.update_mask = update_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudDLPHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + info = hook.update_stored_info_type( + stored_info_type_id=self.stored_info_type_id, + organization_id=self.organization_id, + project_id=self.project_id, + config=self.config, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(info) diff --git a/reference/providers/google/cloud/operators/functions.py b/reference/providers/google/cloud/operators/functions.py new file mode 100644 index 0000000..a86fa2f --- /dev/null +++ b/reference/providers/google/cloud/operators/functions.py @@ -0,0 +1,511 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
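+
+# Usage sketch, kept as a comment so nothing runs on import: how the operators in
+# this module are typically wired into a DAG. It assumes an Airflow 2.x environment
+# with the default "google_cloud_default" connection; the project, location, function
+# name, and zip path below are illustrative placeholders, not values from this repo.
+#
+#     from airflow import DAG
+#     from airflow.utils.dates import days_ago
+#
+#     with DAG("example_gcf", start_date=days_ago(1), schedule_interval=None) as dag:
+#         deploy = CloudFunctionDeployFunctionOperator(
+#             task_id="deploy_function",
+#             location="us-central1",
+#             body={
+#                 "name": "projects/my-project/locations/us-central1/functions/hello",
+#                 "entryPoint": "hello",
+#                 "runtime": "python38",
+#                 "httpsTrigger": {},
+#                 "sourceUploadUrl": "",
+#             },
+#             zip_path="/opt/airflow/dags/hello.zip",
+#         )
+#         invoke = CloudFunctionInvokeFunctionOperator(
+#             task_id="invoke_function",
+#             function_id="hello",
+#             input_data={"data": "ping"},
+#             location="us-central1",
+#         )
+#         deploy >> invoke
+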
+"""This module contains Google Cloud Functions operators.""" + +import re +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.functions import CloudFunctionsHook +from airflow.providers.google.cloud.utils.field_validator import ( + GcpBodyFieldValidator, + GcpFieldValidationException, +) +from airflow.utils.decorators import apply_defaults +from airflow.version import version +from googleapiclient.errors import HttpError + + +def _validate_available_memory_in_mb(value): + if int(value) <= 0: + raise GcpFieldValidationException( + "The available memory has to be greater than 0" + ) + + +def _validate_max_instances(value): + if int(value) <= 0: + raise GcpFieldValidationException( + "The max instances parameter has to be greater than 0" + ) + + +CLOUD_FUNCTION_VALIDATION = [ + dict(name="name", regexp="^.+$"), + dict(name="description", regexp="^.+$", optional=True), + dict(name="entryPoint", regexp=r"^.+$", optional=True), + dict(name="runtime", regexp=r"^.+$", optional=True), + dict(name="timeout", regexp=r"^.+$", optional=True), + dict( + name="availableMemoryMb", + custom_validation=_validate_available_memory_in_mb, + optional=True, + ), + dict(name="labels", optional=True), + dict(name="environmentVariables", optional=True), + dict(name="network", regexp=r"^.+$", optional=True), + dict(name="maxInstances", optional=True, custom_validation=_validate_max_instances), + dict( + name="source_code", + type="union", + fields=[ + dict(name="sourceArchiveUrl", regexp=r"^.+$"), + dict(name="sourceRepositoryUrl", regexp=r"^.+$", api_version="v1beta2"), + dict( + name="sourceRepository", + type="dict", + fields=[dict(name="url", regexp=r"^.+$")], + ), + dict(name="sourceUploadUrl"), + ], + ), + dict( + name="trigger", + type="union", + fields=[ + dict( + name="httpsTrigger", + type="dict", + fields=[ + # This dict should be empty at input (url is added at output) + ], + ), + dict( + name="eventTrigger", + type="dict", + fields=[ + dict(name="eventType", regexp=r"^.+$"), + dict(name="resource", regexp=r"^.+$"), + dict(name="service", regexp=r"^.+$", optional=True), + dict( + name="failurePolicy", + type="dict", + optional=True, + fields=[dict(name="retry", type="dict", optional=True)], + ), + ], + ), + ], + ), +] # type: List[Dict[str, Any]] + + +class CloudFunctionDeployFunctionOperator(BaseOperator): + """ + Creates a function in Google Cloud Functions. + If a function with this name already exists, it will be updated. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudFunctionDeployFunctionOperator` + + :param location: Google Cloud region where the function should be created. + :type location: str + :param body: Body of the Cloud Functions definition. The body must be a + Cloud Functions dictionary as described in: + https://cloud.google.com/functions/docs/reference/rest/v1/projects.locations.functions + . Different API versions require different variants of the Cloud Functions + dictionary. + :type body: dict or google.cloud.functions.v1.CloudFunction + :param project_id: (Optional) Google Cloud project ID where the function + should be created. + :type project_id: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + Default 'google_cloud_default'. 
+ :type gcp_conn_id: str + :param api_version: (Optional) API version used (for example v1 - default - or + v1beta1). + :type api_version: str + :param zip_path: Path to zip file containing source code of the function. If the path + is set, the sourceUploadUrl should not be specified in the body or it should + be empty. Then the zip file will be uploaded using the upload URL generated + via generateUploadUrl from the Cloud Functions API. + :type zip_path: str + :param validate_body: If set to False, body validation is not performed. + :type validate_body: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcf_function_deploy_template_fields] + template_fields = ( + "body", + "project_id", + "location", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gcf_function_deploy_template_fields] + + @apply_defaults + def __init__( + self, + *, + location: str, + body: Dict, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + zip_path: Optional[str] = None, + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.project_id = project_id + self.location = location + self.body = body + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.zip_path = zip_path + self.zip_path_preprocessor = ZipPathPreprocessor(body, zip_path) + self._field_validator = None # type: Optional[GcpBodyFieldValidator] + self.impersonation_chain = impersonation_chain + if validate_body: + self._field_validator = GcpBodyFieldValidator( + CLOUD_FUNCTION_VALIDATION, api_version=api_version + ) + self._validate_inputs() + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if not self.location: + raise AirflowException("The required parameter 'location' is missing") + if not self.body: + raise AirflowException("The required parameter 'body' is missing") + self.zip_path_preprocessor.preprocess_body() + + def _validate_all_body_fields(self) -> None: + if self._field_validator: + self._field_validator.validate(self.body) + + def _create_new_function(self, hook) -> None: + hook.create_new_function( + project_id=self.project_id, location=self.location, body=self.body + ) + + def _update_function(self, hook) -> None: + hook.update_function(self.body["name"], self.body, self.body.keys()) + + def _check_if_function_exists(self, hook) -> bool: + name = self.body.get("name") + if not name: + raise GcpFieldValidationException( + f"The 'name' field should be present in body: '{self.body}'." 
+ ) + try: + hook.get_function(name) + except HttpError as e: + status = e.resp.status + if status == 404: + return False + raise e + return True + + def _upload_source_code(self, hook): + return hook.upload_function_zip( + project_id=self.project_id, location=self.location, zip_path=self.zip_path + ) + + def _set_airflow_version_label(self) -> None: + if "labels" not in self.body.keys(): + self.body["labels"] = {} + self.body["labels"].update( + {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")} + ) + + def execute(self, context): + hook = CloudFunctionsHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + if self.zip_path_preprocessor.should_upload_function(): + self.body[GCF_SOURCE_UPLOAD_URL] = self._upload_source_code(hook) + self._validate_all_body_fields() + self._set_airflow_version_label() + if not self._check_if_function_exists(hook): + self._create_new_function(hook) + else: + self._update_function(hook) + + +GCF_SOURCE_ARCHIVE_URL = "sourceArchiveUrl" +GCF_SOURCE_UPLOAD_URL = "sourceUploadUrl" +SOURCE_REPOSITORY = "sourceRepository" +GCF_ZIP_PATH = "zip_path" + + +class ZipPathPreprocessor: + """ + Pre-processes zip path parameter. + + Responsible for checking if the zip path parameter is correctly specified in + relation with source_code body fields. Non empty zip path parameter is special because + it is mutually exclusive with sourceArchiveUrl and sourceRepository body fields. + It is also mutually exclusive with non-empty sourceUploadUrl. + The pre-process modifies sourceUploadUrl body field in special way when zip_path + is not empty. An extra step is run when execute method is called and sourceUploadUrl + field value is set in the body with the value returned by generateUploadUrl Cloud + Function API method. + + :param body: Body passed to the create/update method calls. + :type body: dict + :param zip_path: (optional) Path to zip file containing source code of the function. If the path + is set, the sourceUploadUrl should not be specified in the body or it should + be empty. Then the zip file will be uploaded using the upload URL generated + via generateUploadUrl from the Cloud Functions API. + :type zip_path: str + + """ + + upload_function = None # type: Optional[bool] + + def __init__(self, body: dict, zip_path: Optional[str] = None) -> None: + self.body = body + self.zip_path = zip_path + + @staticmethod + def _is_present_and_empty(dictionary, field) -> bool: + return field in dictionary and not dictionary[field] + + def _verify_upload_url_and_no_zip_path(self) -> None: + if self._is_present_and_empty(self.body, GCF_SOURCE_UPLOAD_URL): + if not self.zip_path: + raise AirflowException( + "Parameter '{url}' is empty in the body and argument '{path}' " + "is missing or empty. You need to have non empty '{path}' " + "when '{url}' is present and empty.".format( + url=GCF_SOURCE_UPLOAD_URL, path=GCF_ZIP_PATH + ) + ) + + def _verify_upload_url_and_zip_path(self) -> None: + if GCF_SOURCE_UPLOAD_URL in self.body and self.zip_path: + if not self.body[GCF_SOURCE_UPLOAD_URL]: + self.upload_function = True + else: + raise AirflowException( + "Only one of '{}' in body or '{}' argument " + "allowed. Found both.".format(GCF_SOURCE_UPLOAD_URL, GCF_ZIP_PATH) + ) + + def _verify_archive_url_and_zip_path(self) -> None: + if GCF_SOURCE_ARCHIVE_URL in self.body and self.zip_path: + raise AirflowException( + "Only one of '{}' in body or '{}' argument " + "allowed. 
Found both.".format(GCF_SOURCE_ARCHIVE_URL, GCF_ZIP_PATH) + ) + + def should_upload_function(self) -> bool: + """ + Checks if function source should be uploaded. + + :rtype: bool + """ + if self.upload_function is None: + raise AirflowException( + "validate() method has to be invoked before " "should_upload_function" + ) + return self.upload_function + + def preprocess_body(self) -> None: + """ + Modifies sourceUploadUrl body field in special way when zip_path + is not empty. + """ + self._verify_archive_url_and_zip_path() + self._verify_upload_url_and_zip_path() + self._verify_upload_url_and_no_zip_path() + if self.upload_function is None: + self.upload_function = False + + +FUNCTION_NAME_PATTERN = "^projects/[^/]+/locations/[^/]+/functions/[^/]+$" +FUNCTION_NAME_COMPILED_PATTERN = re.compile(FUNCTION_NAME_PATTERN) + + +class CloudFunctionDeleteFunctionOperator(BaseOperator): + """ + Deletes the specified function from Google Cloud Functions. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudFunctionDeleteFunctionOperator` + + :param name: A fully-qualified function name, matching + the pattern: `^projects/[^/]+/locations/[^/]+/functions/[^/]+$` + :type name: str + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (for example v1 or v1beta1). + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcf_function_delete_template_fields] + template_fields = ( + "name", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gcf_function_delete_template_fields] + + @apply_defaults + def __init__( + self, + *, + name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.name = name + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.impersonation_chain = impersonation_chain + self._validate_inputs() + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if not self.name: + raise AttributeError("Empty parameter: name") + else: + pattern = FUNCTION_NAME_COMPILED_PATTERN + if not pattern.match(self.name): + raise AttributeError( + f"Parameter name must match pattern: {FUNCTION_NAME_PATTERN}" + ) + + def execute(self, context): + hook = CloudFunctionsHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + try: + return hook.delete_function(self.name) + except HttpError as e: + status = e.resp.status + if status == 404: + self.log.info("The function does not exist in this project") + return None + else: + self.log.error("An error occurred. Exiting.") + raise e + + +class CloudFunctionInvokeFunctionOperator(BaseOperator): + """ + Invokes a deployed Cloud Function. 
To be used for testing + purposes as very limited traffic is allowed. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudFunctionDeployFunctionOperator` + + :param function_id: ID of the function to be called + :type function_id: str + :param input_data: Input to be passed to the function + :type input_data: Dict + :param location: The location where the function is located. + :type location: str + :param project_id: Optional, Google Cloud Project project_id where the function belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :return: None + """ + + template_fields = ( + "function_id", + "input_data", + "location", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + function_id: str, + input_data: Dict, + location: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.function_id = function_id + self.input_data = input_data + self.location = location + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.impersonation_chain = impersonation_chain + + def execute(self, context: Dict): + hook = CloudFunctionsHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Calling function %s.", self.function_id) + result = hook.call_function( + function_id=self.function_id, + input_data=self.input_data, + location=self.location, + project_id=self.project_id, + ) + self.log.info( + "Function called successfully. Execution id %s", result.get("executionId") + ) + self.xcom_push( + context=context, key="execution_id", value=result.get("executionId") + ) + return result diff --git a/reference/providers/google/cloud/operators/gcs.py b/reference/providers/google/cloud/operators/gcs.py new file mode 100644 index 0000000..6758a28 --- /dev/null +++ b/reference/providers/google/cloud/operators/gcs.py @@ -0,0 +1,1113 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
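# Illustrative sketch (not part of the original patch) showing how the invoke
# operator above might be wired into a DAG. The DAG id, project, region, function
# name and payload are hypothetical placeholders; the import assumes the standard
# google provider package layout mirrored by this reference copy.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.functions import (
    CloudFunctionInvokeFunctionOperator,
)

with DAG(
    dag_id="example_gcf_invoke",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    invoke_fn = CloudFunctionInvokeFunctionOperator(
        task_id="invoke_hello_gcf",
        function_id="hello-gcf",      # short function id, not the fully-qualified name
        input_data={"data": "ping"},  # request payload forwarded to the function
        location="us-central1",
        project_id="my-project",      # optional; falls back to the connection's project
    )
    # The returned execution id is also pushed to XCom under the key "execution_id".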
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Google Cloud Storage Bucket operator."""
+import datetime
+import subprocess
+import sys
+import warnings
+from pathlib import Path
+from tempfile import NamedTemporaryFile, TemporaryDirectory
+from typing import Dict, Iterable, List, Optional, Sequence, Union
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.utils import timezone
+from airflow.utils.decorators import apply_defaults
+from google.api_core.exceptions import Conflict
+from google.cloud.exceptions import GoogleCloudError
+
+
+class GCSCreateBucketOperator(BaseOperator):
+    """
+    Creates a new bucket. Google Cloud Storage uses a flat namespace,
+    so you can't create a bucket with a name that is already in use.
+
+    .. seealso::
+        For more information, see Bucket Naming Guidelines:
+        https://cloud.google.com/storage/docs/bucketnaming.html#requirements
+
+    :param bucket_name: The name of the bucket. (templated)
+    :type bucket_name: str
+    :param resource: An optional dict with parameters for creating the bucket.
+        For information on available parameters, see Cloud Storage API doc:
+        https://cloud.google.com/storage/docs/json_api/v1/buckets/insert
+    :type resource: dict
+    :param storage_class: This defines how objects in the bucket are stored
+        and determines the SLA and the cost of storage (templated). Values include
+
+        - ``MULTI_REGIONAL``
+        - ``REGIONAL``
+        - ``STANDARD``
+        - ``NEARLINE``
+        - ``COLDLINE``.
+
+        If this value is not specified when the bucket is
+        created, it will default to STANDARD.
+    :type storage_class: str
+    :param location: The location of the bucket. (templated)
+        Object data for objects in the bucket resides in physical storage
+        within this region. Defaults to US.
+
+        .. seealso:: https://developers.google.com/storage/docs/bucket-locations
+
+    :type location: str
+    :param project_id: The ID of the Google Cloud Project. (templated)
+    :type project_id: str
+    :param labels: User-provided labels, in key/value pairs.
+    :type labels: dict
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
+        This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
+    :type google_cloud_storage_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    The following Operator would create a new bucket ``test-bucket``
+    with ``MULTI_REGIONAL`` storage class in ``EU`` region
+
+    .. code-block:: python
+
+        CreateBucket = GoogleCloudStorageCreateBucketOperator(
+            task_id='CreateNewBucket',
+            bucket_name='test-bucket',
+            storage_class='MULTI_REGIONAL',
+            location='EU',
+            labels={'env': 'dev', 'team': 'airflow'},
+            gcp_conn_id='airflow-conn-id'
+        )
+
+    """
+
+    template_fields = (
+        "bucket_name",
+        "storage_class",
+        "location",
+        "project_id",
+        "impersonation_chain",
+    )
+    ui_color = "#f0eee4"
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        bucket_name: str,
+        resource: Optional[Dict] = None,
+        storage_class: str = "MULTI_REGIONAL",
+        location: str = "US",
+        project_id: Optional[str] = None,
+        labels: Optional[Dict] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        google_cloud_storage_conn_id: Optional[str] = None,
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        if google_cloud_storage_conn_id:
+            warnings.warn(
+                "The google_cloud_storage_conn_id parameter has been deprecated. You should pass "
+                "the gcp_conn_id parameter.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            gcp_conn_id = google_cloud_storage_conn_id
+
+        self.bucket_name = bucket_name
+        self.resource = resource
+        self.storage_class = storage_class
+        self.location = location
+        self.project_id = project_id
+        self.labels = labels
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context) -> None:
+        hook = GCSHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            impersonation_chain=self.impersonation_chain,
+        )
+        try:
+            hook.create_bucket(
+                bucket_name=self.bucket_name,
+                resource=self.resource,
+                storage_class=self.storage_class,
+                location=self.location,
+                project_id=self.project_id,
+                labels=self.labels,
+            )
+        except Conflict:  # HTTP 409
+            self.log.warning("Bucket %s already exists", self.bucket_name)
+
+
+class GCSListObjectsOperator(BaseOperator):
+    """
+    List all objects from the bucket with the given string prefix and delimiter in name.
+
+    This operator returns a python list with the names of objects which can be used by
+    `xcom` in the downstream task.
+
+    :param bucket: The Google Cloud Storage bucket to find the objects. (templated)
+    :type bucket: str
+    :param prefix: Prefix string which filters objects whose name begin with
+        this prefix. (templated)
+    :type prefix: str
+    :param delimiter: The delimiter by which you want to filter the objects. (templated)
+        For example, to list the CSV files from a directory in GCS you would use
+        delimiter='.csv'.
+    :type delimiter: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
+        This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
+    :type google_cloud_storage_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any.
For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + **Example**: + The following Operator would list all the Avro files from ``sales/sales-2017`` + folder in ``data`` bucket. :: + + GCS_Files = GoogleCloudStorageListOperator( + task_id='GCS_Files', + bucket='data', + prefix='sales/sales-2017/', + delimiter='.avro', + gcp_conn_id=google_cloud_conn_id + ) + """ + + template_fields: Iterable[str] = ( + "bucket", + "prefix", + "delimiter", + "impersonation_chain", + ) + + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, + bucket: str, + prefix: Optional[str] = None, + delimiter: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + self.bucket = bucket + self.prefix = prefix + self.delimiter = delimiter + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> list: + + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info( + "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s", + self.bucket, + self.delimiter, + self.prefix, + ) + + return hook.list( + bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter + ) + + +class GCSDeleteObjectsOperator(BaseOperator): + """ + Deletes objects from a Google Cloud Storage bucket, either + from an explicit list of object names or all objects + matching a prefix. + + :param bucket_name: The GCS bucket to delete from + :type bucket_name: str + :param objects: List of objects to delete. These should be the names + of objects in the bucket, not including gs://bucket/ + :type objects: Iterable[str] + :param prefix: Prefix of objects to delete. All objects matching this + prefix in the bucket will be deleted. + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. 
For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "bucket_name", + "prefix", + "objects", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + bucket_name: str, + objects: Optional[Iterable[str]] = None, + prefix: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + self.bucket_name = bucket_name + self.objects = objects + self.prefix = prefix + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + if not objects and not prefix: + raise ValueError("Either object or prefix should be set. Both are None") + + super().__init__(**kwargs) + + def execute(self, context): + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + if self.objects: + objects = self.objects + else: + objects = hook.list(bucket_name=self.bucket_name, prefix=self.prefix) + + self.log.info("Deleting %s objects from %s", len(objects), self.bucket_name) + for object_name in objects: + hook.delete(bucket_name=self.bucket_name, object_name=object_name) + + +class GCSBucketCreateAclEntryOperator(BaseOperator): + """ + Creates a new ACL entry on the specified bucket. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSBucketCreateAclEntryOperator` + + :param bucket: Name of a bucket. + :type bucket: str + :param entity: The entity holding the permission, in one of the following forms: + user-userId, user-email, group-groupId, group-email, domain-domain, + project-team-projectId, allUsers, allAuthenticatedUsers + :type entity: str + :param role: The access permission for the entity. + Acceptable values are: "OWNER", "READER", "WRITER". + :type role: str + :param user_project: (Optional) The project to be billed for this request. + Required for Requester Pays buckets. + :type user_project: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. 
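# Illustrative sketch (not part of the original patch) of the two deletion modes
# described above; bucket and object names are hypothetical placeholders. At least
# one of `objects` or `prefix` must be given, otherwise the constructor raises ValueError.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.gcs import GCSDeleteObjectsOperator

with DAG(
    dag_id="example_gcs_delete_objects",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    delete_named_objects = GCSDeleteObjectsOperator(
        task_id="delete_named_objects",
        bucket_name="my-data-bucket",
        objects=["exports/2022-07-01.csv", "exports/2022-07-02.csv"],
    )
    delete_by_prefix = GCSDeleteObjectsOperator(
        task_id="delete_by_prefix",
        bucket_name="my-data-bucket",
        prefix="tmp/",  # everything under tmp/ is listed first, then deleted one by one
    )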
+ :type google_cloud_storage_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcs_bucket_create_acl_template_fields] + template_fields = ( + "bucket", + "entity", + "role", + "user_project", + "impersonation_chain", + ) + # [END gcs_bucket_create_acl_template_fields] + + @apply_defaults + def __init__( + self, + *, + bucket: str, + entity: str, + role: str, + user_project: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + google_cloud_storage_conn_id: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + self.bucket = bucket + self.entity = entity + self.role = role + self.user_project = user_project + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.insert_bucket_acl( + bucket_name=self.bucket, + entity=self.entity, + role=self.role, + user_project=self.user_project, + ) + + +class GCSObjectCreateAclEntryOperator(BaseOperator): + """ + Creates a new ACL entry on the specified object. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSObjectCreateAclEntryOperator` + + :param bucket: Name of a bucket. + :type bucket: str + :param object_name: Name of the object. For information about how to URL encode object + names to be path safe, see: + https://cloud.google.com/storage/docs/json_api/#encoding + :type object_name: str + :param entity: The entity holding the permission, in one of the following forms: + user-userId, user-email, group-groupId, group-email, domain-domain, + project-team-projectId, allUsers, allAuthenticatedUsers + :type entity: str + :param role: The access permission for the entity. + Acceptable values are: "OWNER", "READER". + :type role: str + :param generation: Optional. If present, selects a specific revision of this object. + :type generation: long + :param user_project: (Optional) The project to be billed for this request. + Required for Requester Pays buckets. + :type user_project: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. 
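# Illustrative sketch (not part of the original patch) of the bucket ACL operator
# above; the bucket name and group address are hypothetical placeholders, and the
# entity/role values follow the forms listed in the docstring.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.gcs import GCSBucketCreateAclEntryOperator

with DAG(
    dag_id="example_gcs_bucket_acl",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    grant_group_read = GCSBucketCreateAclEntryOperator(
        task_id="grant_group_read",
        bucket="my-data-bucket",
        entity="group-data-readers@example.com",  # "group-email" entity form
        role="READER",                            # OWNER, READER or WRITER for buckets
    )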
+ :type google_cloud_storage_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcs_object_create_acl_template_fields] + template_fields = ( + "bucket", + "object_name", + "entity", + "generation", + "role", + "user_project", + "impersonation_chain", + ) + # [END gcs_object_create_acl_template_fields] + + @apply_defaults + def __init__( + self, + *, + bucket: str, + object_name: str, + entity: str, + role: str, + generation: Optional[int] = None, + user_project: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + google_cloud_storage_conn_id: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + self.bucket = bucket + self.object_name = object_name + self.entity = entity + self.role = role + self.generation = generation + self.user_project = user_project + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.insert_object_acl( + bucket_name=self.bucket, + object_name=self.object_name, + entity=self.entity, + role=self.role, + generation=self.generation, + user_project=self.user_project, + ) + + +class GCSFileTransformOperator(BaseOperator): + """ + Copies data from a source GCS location to a temporary location on the + local filesystem. Runs a transformation on this file as specified by + the transformation script and uploads the output to a destination bucket. + If the output bucket is not specified the original file will be + overwritten. + + The locations of the source and the destination files in the local + filesystem is provided as an first and second arguments to the + transformation script. The transformation script is expected to read the + data from source, transform it and write the output to the local + destination file. + + :param source_bucket: The key to be retrieved from S3. (templated) + :type source_bucket: str + :param destination_bucket: The key to be written from S3. (templated) + :type destination_bucket: str + :param transform_script: location of the executable transformation script or list of arguments + passed to subprocess ex. `['python', 'script.py', 10]`. (templated) + :type transform_script: Union[str, List[str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "source_bucket", + "destination_bucket", + "transform_script", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + source_bucket: str, + source_object: str, + transform_script: Union[str, List[str]], + destination_bucket: Optional[str] = None, + destination_object: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.source_bucket = source_bucket + self.source_object = source_object + self.destination_bucket = destination_bucket or self.source_bucket + self.destination_object = destination_object or self.source_object + + self.gcp_conn_id = gcp_conn_id + self.transform_script = transform_script + self.output_encoding = sys.getdefaultencoding() + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + + with NamedTemporaryFile() as source_file, NamedTemporaryFile() as destination_file: + self.log.info("Downloading file from %s", self.source_bucket) + hook.download( + bucket_name=self.source_bucket, + object_name=self.source_object, + filename=source_file.name, + ) + + self.log.info("Starting the transformation") + cmd = ( + [self.transform_script] + if isinstance(self.transform_script, str) + else self.transform_script + ) + cmd += [source_file.name, destination_file.name] + process = subprocess.Popen( + args=cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + close_fds=True, + ) + self.log.info("Process output:") + if process.stdout: + for line in iter(process.stdout.readline, b""): + self.log.info(line.decode(self.output_encoding).rstrip()) + + process.wait() + if process.returncode: + raise AirflowException(f"Transform script failed: {process.returncode}") + + self.log.info( + "Transformation succeeded. Output temporarily located at %s", + destination_file.name, + ) + + self.log.info( + "Uploading file to %s as %s", + self.destination_bucket, + self.destination_object, + ) + hook.upload( + bucket_name=self.destination_bucket, + object_name=self.destination_object, + filename=destination_file.name, + ) + + +class GCSTimeSpanFileTransformOperator(BaseOperator): + """ + Determines a list of objects that were added or modified at a GCS source + location during a specific time-span, copies them to a temporary location + on the local file system, runs a transform on this file as specified by + the transformation script and uploads the output to the destination bucket. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSTimeSpanFileTransformOperator` + + The locations of the source and the destination files in the local + filesystem is provided as an first and second arguments to the + transformation script. The time-span is passed to the transform script as + third and fourth argument as UTC ISO 8601 string. + + The transformation script is expected to read the + data from source, transform it and write the output to the local + destination file. + + :param source_bucket: The bucket to fetch data from. (templated) + :type source_bucket: str + :param source_prefix: Prefix string which filters objects whose name begin with + this prefix. Can interpolate execution date and time components. (templated) + :type source_prefix: str + :param source_gcp_conn_id: The connection ID to use connecting to Google Cloud + to download files to be processed. + :type source_gcp_conn_id: str + :param source_impersonation_chain: Optional service account to impersonate using short-term + credentials (to download files to be processed), or chained list of accounts required to + get the access_token of the last account in the list, which will be impersonated in the + request. If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type source_impersonation_chain: Union[str, Sequence[str]] + + :param destination_bucket: The bucket to write data to. (templated) + :type destination_bucket: str + :param destination_prefix: Prefix string for the upload location. + Can interpolate execution date and time components. (templated) + :type destination_prefix: str + :param destination_gcp_conn_id: The connection ID to use connecting to Google Cloud + to upload processed files. + :type destination_gcp_conn_id: str + :param destination_impersonation_chain: Optional service account to impersonate using short-term + credentials (to upload processed files), or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type destination_impersonation_chain: Union[str, Sequence[str]] + + :param transform_script: location of the executable transformation script or list of arguments + passed to subprocess ex. `['python', 'script.py', 10]`. (templated) + :type transform_script: Union[str, List[str]] + + + :param chunk_size: The size of a chunk of data when downloading or uploading (in bytes). + This must be a multiple of 256 KB (per the google clout storage API specification). + :type chunk_size: Optional[int] + :param download_continue_on_fail: With this set to true, if a download fails the task does not error out + but will still continue. + :type download_num_attempts: int + :param upload_chunk_size: The size of a chunk of data when uploading (in bytes). 
+ This must be a multiple of 256 KB (per the google clout storage API specification). + :type download_chunk_size: Optional[int] + :param upload_continue_on_fail: With this set to true, if an upload fails the task does not error out + but will still continue. + :type download_chunk_size: Optional[bool] + :param upload_num_attempts: Number of attempts to try to upload a single file. + :type upload_num_attempts: int + """ + + template_fields = ( + "source_bucket", + "source_prefix", + "destination_bucket", + "destination_prefix", + "transform_script", + "source_impersonation_chain", + "destination_impersonation_chain", + ) + + @staticmethod + def interpolate_prefix( + prefix: str, dt: datetime.datetime + ) -> Optional[datetime.datetime]: + """Interpolate prefix with datetime. + + :param prefix: The prefix to interpolate + :type prefix: str + :param dt: The datetime to interpolate + :type dt: datetime + + """ + return dt.strftime(prefix) if prefix else None + + @apply_defaults + def __init__( + self, + *, + source_bucket: str, + source_prefix: str, + source_gcp_conn_id: str, + destination_bucket: str, + destination_prefix: str, + destination_gcp_conn_id: str, + transform_script: Union[str, List[str]], + source_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + destination_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + chunk_size: Optional[int] = None, + download_continue_on_fail: Optional[bool] = False, + download_num_attempts: int = 1, + upload_continue_on_fail: Optional[bool] = False, + upload_num_attempts: int = 1, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.source_bucket = source_bucket + self.source_prefix = source_prefix + self.source_gcp_conn_id = source_gcp_conn_id + self.source_impersonation_chain = source_impersonation_chain + + self.destination_bucket = destination_bucket + self.destination_prefix = destination_prefix + self.destination_gcp_conn_id = destination_gcp_conn_id + self.destination_impersonation_chain = destination_impersonation_chain + + self.transform_script = transform_script + self.output_encoding = sys.getdefaultencoding() + + self.chunk_size = chunk_size + self.download_continue_on_fail = download_continue_on_fail + self.download_num_attempts = download_num_attempts + self.upload_continue_on_fail = upload_continue_on_fail + self.upload_num_attempts = upload_num_attempts + + def execute(self, context: dict) -> None: + # Define intervals and prefixes. + timespan_start = context["execution_date"] + timespan_end = context["dag"].following_schedule(timespan_start) + if timespan_end is None: + self.log.warning( + "No following schedule found, setting timespan end to max %s", + timespan_end, + ) + timespan_end = datetime.datetime.max + + timespan_start = timespan_start.replace(tzinfo=timezone.utc) + timespan_end = timespan_end.replace(tzinfo=timezone.utc) + + source_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix( + self.source_prefix, + timespan_start, + ) + destination_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix( + self.destination_prefix, + timespan_start, + ) + + source_hook = GCSHook( + gcp_conn_id=self.source_gcp_conn_id, + impersonation_chain=self.source_impersonation_chain, + ) + destination_hook = GCSHook( + gcp_conn_id=self.destination_gcp_conn_id, + impersonation_chain=self.destination_impersonation_chain, + ) + + # Fetch list of files. 
+ blobs_to_transform = source_hook.list_by_timespan( + bucket_name=self.source_bucket, + prefix=source_prefix_interp, + timespan_start=timespan_start, + timespan_end=timespan_end, + ) + + with TemporaryDirectory() as temp_input_dir, TemporaryDirectory() as temp_output_dir: + temp_input_dir = Path(temp_input_dir) + temp_output_dir = Path(temp_output_dir) + + # TODO: download in parallel. + for blob_to_transform in blobs_to_transform: + destination_file = temp_input_dir / blob_to_transform + destination_file.parent.mkdir(parents=True, exist_ok=True) + try: + source_hook.download( + bucket_name=self.source_bucket, + object_name=blob_to_transform, + filename=str(destination_file), + chunk_size=self.chunk_size, + num_max_attempts=self.download_num_attempts, + ) + except GoogleCloudError: + if self.download_continue_on_fail: + continue + raise + + self.log.info("Starting the transformation") + cmd = ( + [self.transform_script] + if isinstance(self.transform_script, str) + else self.transform_script + ) + cmd += [ + str(temp_input_dir), + str(temp_output_dir), + timespan_start.replace(microsecond=0).isoformat(), + timespan_end.replace(microsecond=0).isoformat(), + ] + process = subprocess.Popen( + args=cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + close_fds=True, + ) + self.log.info("Process output:") + if process.stdout: + for line in iter(process.stdout.readline, b""): + self.log.info(line.decode(self.output_encoding).rstrip()) + + process.wait() + if process.returncode: + raise AirflowException(f"Transform script failed: {process.returncode}") + + self.log.info( + "Transformation succeeded. Output temporarily located at %s", + temp_output_dir, + ) + + files_uploaded = [] + + # TODO: upload in parallel. + for upload_file in temp_output_dir.glob("**/*"): + if upload_file.is_dir(): + continue + + upload_file_name = str(upload_file.relative_to(temp_output_dir)) + + if self.destination_prefix is not None: + upload_file_name = f"{destination_prefix_interp}/{upload_file_name}" + + self.log.info("Uploading file %s to %s", upload_file, upload_file_name) + + try: + destination_hook.upload( + bucket_name=self.destination_bucket, + object_name=upload_file_name, + filename=str(upload_file), + chunk_size=self.chunk_size, + num_max_attempts=self.upload_num_attempts, + ) + files_uploaded.append(str(upload_file_name)) + except GoogleCloudError: + if self.upload_continue_on_fail: + continue + raise + + return files_uploaded + + +class GCSDeleteBucketOperator(BaseOperator): + """ + Deletes bucket from a Google Cloud Storage. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSDeleteBucketOperator` + + :param bucket_name: name of the bucket which will be deleted + :type bucket_name: str + :param force: false not allow to delete non empty bucket, set force=True + allows to delete non empty bucket + :type: bool + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
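#!/usr/bin/env python
# Illustrative sketch (not part of the original patch) of the argv contract that
# GCSFileTransformOperator above expects from `transform_script`: argv[1] is the
# downloaded source file and argv[2] is the local destination file to write.
# GCSTimeSpanFileTransformOperator uses the same convention but passes input and
# output *directories* plus the timespan start/end as UTC ISO 8601 strings in
# argv[3] and argv[4]. The upper-casing "transform" here is just a placeholder.
import sys


def main() -> None:
    source_path, destination_path = sys.argv[1], sys.argv[2]
    with open(source_path) as src, open(destination_path, "w") as dst:
        for line in src:
            dst.write(line.upper())


if __name__ == "__main__":
    main()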
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "bucket_name", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + bucket_name: str, + force: bool = True, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.bucket_name = bucket_name + self.force: bool = force + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + hook.delete_bucket(bucket_name=self.bucket_name, force=self.force) + + +class GCSSynchronizeBucketsOperator(BaseOperator): + """ + Synchronizes the contents of the buckets or bucket's directories in the Google Cloud Services. + + Parameters ``source_object`` and ``destination_object`` describe the root sync directory. If they are + not passed, the entire bucket will be synchronized. They should point to directories. + + .. note:: + The synchronization of individual files is not supported. Only entire directories can be + synchronized. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSSynchronizeBuckets` + + :param source_bucket: The name of the bucket containing the source objects. + :type source_bucket: str + :param destination_bucket: The name of the bucket containing the destination objects. + :type destination_bucket: str + :param source_object: The root sync directory in the source bucket. + :type source_object: Optional[str] + :param destination_object: The root sync directory in the destination bucket. + :type destination_object: Optional[str] + :param recursive: If True, subdirectories will be considered + :type recursive: bool + :param allow_overwrite: if True, the files will be overwritten if a mismatched file is found. + By default, overwriting files is not allowed + :type allow_overwrite: bool + :param delete_extra_files: if True, deletes additional files from the source that not found in the + destination. By default extra files are not deleted. + + .. note:: + This option can delete data quickly if you specify the wrong source/destination combination. + + :type delete_extra_files: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "source_bucket", + "destination_bucket", + "source_object", + "destination_object", + "recursive", + "delete_extra_files", + "allow_overwrite", + "gcp_conn_id", + "delegate_to", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + source_bucket: str, + destination_bucket: str, + source_object: Optional[str] = None, + destination_object: Optional[str] = None, + recursive: bool = True, + delete_extra_files: bool = False, + allow_overwrite: bool = False, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.source_bucket = source_bucket + self.destination_bucket = destination_bucket + self.source_object = source_object + self.destination_object = destination_object + self.recursive = recursive + self.delete_extra_files = delete_extra_files + self.allow_overwrite = allow_overwrite + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + hook.sync( + source_bucket=self.source_bucket, + destination_bucket=self.destination_bucket, + source_object=self.source_object, + destination_object=self.destination_object, + recursive=self.recursive, + delete_extra_files=self.delete_extra_files, + allow_overwrite=self.allow_overwrite, + ) diff --git a/reference/providers/google/cloud/operators/kubernetes_engine.py b/reference/providers/google/cloud/operators/kubernetes_engine.py new file mode 100644 index 0000000..e50eeda --- /dev/null +++ b/reference/providers/google/cloud/operators/kubernetes_engine.py @@ -0,0 +1,351 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""This module contains Google Kubernetes Engine operators.""" + +import os +import tempfile +from typing import Dict, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import ( + KubernetesPodOperator, +) +from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.utils.decorators import apply_defaults +from airflow.utils.process_utils import execute_in_subprocess, patch_environ +from google.cloud.container_v1.types import Cluster + + +class GKEDeleteClusterOperator(BaseOperator): + """ + Deletes the cluster, including the Kubernetes endpoint and all worker nodes. 
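# Illustrative sketch (not part of the original patch) of the synchronization
# operator above; the bucket names and the "reports/" sync directory are hypothetical.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.gcs import GCSSynchronizeBucketsOperator

with DAG(
    dag_id="example_gcs_sync",
    start_date=datetime(2022, 7, 1),
    schedule_interval="@daily",
    catchup=False,
) as dag:
    sync_reports = GCSSynchronizeBucketsOperator(
        task_id="sync_reports",
        source_bucket="my-staging-bucket",
        destination_bucket="my-prod-bucket",
        source_object="reports/",       # root sync directory; omit to sync the whole bucket
        destination_object="reports/",
        recursive=True,                 # include subdirectories
        allow_overwrite=False,          # mismatched files are left untouched
        delete_extra_files=False,       # extra files are kept (the default)
    )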
+ + To delete a certain cluster, you must specify the ``project_id``, the ``name`` + of the cluster, the ``location`` that the cluster is in, and the ``task_id``. + + **Operator Creation**: :: + + operator = GKEClusterDeleteOperator( + task_id='cluster_delete', + project_id='my-project', + location='cluster-location' + name='cluster-name') + + .. seealso:: + For more detail about deleting clusters have a look at the reference: + https://google-cloud-python.readthedocs.io/en/latest/container/gapic/v1/api.html#google.cloud.container_v1.ClusterManagerClient.delete_cluster + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GKEDeleteClusterOperator` + + :param project_id: The Google Developers Console [project ID or project number] + :type project_id: str + :param name: The name of the resource to delete, in this case cluster name + :type name: str + :param location: The name of the Google Compute Engine zone in which the cluster + resides. + :type location: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param api_version: The api version to use + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "project_id", + "gcp_conn_id", + "name", + "location", + "api_version", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + name: str, + location: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v2", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.location = location + self.api_version = api_version + self.name = name + self.impersonation_chain = impersonation_chain + self._check_input() + + def _check_input(self) -> None: + if not all([self.project_id, self.name, self.location]): + self.log.error( + "One of (project_id, name, location) is missing or incorrect" + ) + raise AirflowException("Operator has incorrect or missing input.") + + def execute(self, context) -> Optional[str]: + hook = GKEHook( + gcp_conn_id=self.gcp_conn_id, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + delete_result = hook.delete_cluster(name=self.name, project_id=self.project_id) + return delete_result + + +class GKECreateClusterOperator(BaseOperator): + """ + Create a Google Kubernetes Engine Cluster of specified dimensions + The operator will wait until the cluster is created. 
+ + The **minimum** required to define a cluster to create is: + + ``dict()`` :: + cluster_def = {'name': 'my-cluster-name', + 'initial_node_count': 1} + + or + + ``Cluster`` proto :: + from google.cloud.container_v1.types import Cluster + + cluster_def = Cluster(name='my-cluster-name', initial_node_count=1) + + **Operator Creation**: :: + + operator = GKEClusterCreateOperator( + task_id='cluster_create', + project_id='my-project', + location='my-location' + body=cluster_def) + + .. seealso:: + For more detail on about creating clusters have a look at the reference: + :class:`google.cloud.container_v1.types.Cluster` + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GKECreateClusterOperator` + + :param project_id: The Google Developers Console [project ID or project number] + :type project_id: str + :param location: The name of the Google Compute Engine zone in which the cluster + resides. + :type location: str + :param body: The Cluster definition to create, can be protobuf or python dict, if + dict it must match protobuf message Cluster + :type body: dict or google.cloud.container_v1.types.Cluster + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param api_version: The api version to use + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "project_id", + "gcp_conn_id", + "location", + "api_version", + "body", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + location: str, + body: Optional[Union[Dict, Cluster]], + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v2", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.location = location + self.api_version = api_version + self.body = body + self.impersonation_chain = impersonation_chain + self._check_input() + + def _check_input(self) -> None: + if not all([self.project_id, self.location, self.body]) or not ( + ( + isinstance(self.body, dict) + and "name" in self.body + and "initial_node_count" in self.body + ) + or ( + getattr(self.body, "name", None) + and getattr(self.body, "initial_node_count", None) + ) + ): + self.log.error( + "One of (project_id, location, body, body['name'], " + "body['initial_node_count']) is missing or incorrect" + ) + raise AirflowException("Operator has incorrect or missing input.") + + def execute(self, context) -> str: + hook = GKEHook( + gcp_conn_id=self.gcp_conn_id, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + create_op = hook.create_cluster(cluster=self.body, project_id=self.project_id) + return create_op + + +KUBE_CONFIG_ENV_VAR = "KUBECONFIG" + + +class GKEStartPodOperator(KubernetesPodOperator): + """ + Executes a task in a Kubernetes pod in the specified Google Kubernetes + Engine cluster + + This Operator assumes that the system has gcloud installed and has configured a + connection id with a service account. + + The **minimum** required to define a cluster to create are the variables + ``task_id``, ``project_id``, ``location``, ``cluster_name``, ``name``, + ``namespace``, and ``image`` + + .. seealso:: + For more detail about Kubernetes Engine authentication have a look at the reference: + https://cloud.google.com/kubernetes-engine/docs/how-to/cluster-access-for-kubectl#internal_ip + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GKEStartPodOperator` + + :param location: The name of the Google Kubernetes Engine zone in which the + cluster resides, e.g. 'us-central1-a' + :type location: str + :param cluster_name: The name of the Google Kubernetes Engine cluster the pod + should be spawned in + :type cluster_name: str + :param use_internal_ip: Use the internal IP address as the endpoint. + :param project_id: The Google Developers Console project id + :type project_id: str + :param gcp_conn_id: The google cloud connection id to use. This allows for + users to specify a service account. 
+ :type gcp_conn_id: str + """ + + template_fields = {"project_id", "location", "cluster_name"} | set( + KubernetesPodOperator.template_fields + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + cluster_name: str, + use_internal_ip: bool = False, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.cluster_name = cluster_name + self.gcp_conn_id = gcp_conn_id + self.use_internal_ip = use_internal_ip + + if self.gcp_conn_id is None: + raise AirflowException( + "The gcp_conn_id parameter has become required. If you want to use Application Default " + "Credentials (ADC) strategy for authorization, create an empty connection " + "called `google_cloud_default`.", + ) + + def execute(self, context) -> Optional[str]: + hook = GoogleBaseHook(gcp_conn_id=self.gcp_conn_id) + self.project_id = self.project_id or hook.project_id + + if not self.project_id: + raise AirflowException( + "The project id must be passed either as " + "keyword project_id parameter or as project_id extra " + "in Google Cloud connection definition. Both are not set!" + ) + + # Write config to a temp file and set the environment variable to point to it. + # This is to avoid race conditions of reading/writing a single file + with tempfile.NamedTemporaryFile() as conf_file, patch_environ( + {KUBE_CONFIG_ENV_VAR: conf_file.name} + ), hook.provide_authorized_gcloud(): + # Attempt to get/update credentials + # We call gcloud directly instead of using google-cloud-python api + # because there is no way to write kubernetes config to a file, which is + # required by KubernetesPodOperator. + # The gcloud command looks at the env variable `KUBECONFIG` for where to save + # the kubernetes config file. + cmd = [ + "gcloud", + "container", + "clusters", + "get-credentials", + self.cluster_name, + "--zone", + self.location, + "--project", + self.project_id, + ] + if self.use_internal_ip: + cmd.append("--internal-ip") + execute_in_subprocess(cmd) + + # Tell `KubernetesPodOperator` where the config file is located + self.config_file = os.environ[KUBE_CONFIG_ENV_VAR] + return super().execute(context) diff --git a/reference/providers/google/cloud/operators/life_sciences.py b/reference/providers/google/cloud/operators/life_sciences.py new file mode 100644 index 0000000..33d411b --- /dev/null +++ b/reference/providers/google/cloud/operators/life_sciences.py @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
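# Illustrative sketch (not part of the original patch) of GKEStartPodOperator
# above; the project, zone, cluster, image and pod details are hypothetical
# placeholders. The operator fetches cluster credentials via gcloud and then
# behaves like a regular KubernetesPodOperator.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.kubernetes_engine import GKEStartPodOperator

with DAG(
    dag_id="example_gke_pod",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    run_in_gke = GKEStartPodOperator(
        task_id="run_in_gke",
        project_id="my-project",
        location="us-central1-a",     # zone where the existing cluster lives
        cluster_name="my-gke-cluster",
        name="airflow-test-pod",      # pod name inside the cluster
        namespace="default",
        image="busybox",
        cmds=["echo", "hello from GKE"],
        use_internal_ip=False,
    )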
+"""Operators that interact with Google Cloud Life Sciences service.""" + +from typing import Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.life_sciences import LifeSciencesHook +from airflow.utils.decorators import apply_defaults + + +class LifeSciencesRunPipelineOperator(BaseOperator): + """ + Runs a Life Sciences Pipeline + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:LifeSciencesRunPipelineOperator` + + :param body: The request body + :type body: dict + :param location: The location of the project + :type location: str + :param project_id: ID of the Google Cloud project if None then + default project_id is used. + :type project_id: str + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (for example v2beta). + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + body: dict, + location: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v2beta", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.body = body + self.location = location + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self._validate_inputs() + self.impersonation_chain = impersonation_chain + + def _validate_inputs(self) -> None: + if not self.body: + raise AirflowException("The required parameter 'body' is missing") + if not self.location: + raise AirflowException("The required parameter 'location' is missing") + + def execute(self, context) -> dict: + hook = LifeSciencesHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + + return hook.run_pipeline( + body=self.body, location=self.location, project_id=self.project_id + ) diff --git a/reference/providers/google/cloud/operators/mlengine.py b/reference/providers/google/cloud/operators/mlengine.py new file mode 100644 index 0000000..15f1bb0 --- /dev/null +++ b/reference/providers/google/cloud/operators/mlengine.py @@ -0,0 +1,1437 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Cloud MLEngine operators.""" +import logging +import re +import warnings +from typing import Dict, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.models.taskinstance import TaskInstance +from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook +from airflow.utils.decorators import apply_defaults + +log = logging.getLogger(__name__) + + +def _normalize_mlengine_job_id(job_id: str) -> str: + """ + Replaces invalid MLEngine job_id characters with '_'. + + This also adds a leading 'z' in case job_id starts with an invalid + character. + + :param job_id: A job_id str that may have invalid characters. + :type job_id: str: + :return: A valid job_id representation. + :rtype: str + """ + # Add a prefix when a job_id starts with a digit or a template + match = re.search(r"\d|\{{2}", job_id) + if match and match.start() == 0: + job = f"z_{job_id}" + else: + job = job_id + + # Clean up 'bad' characters except templates + tracker = 0 + cleansed_job_id = "" + for match in re.finditer(r"\{{2}.+?\}{2}", job): + cleansed_job_id += re.sub(r"[^0-9a-zA-Z]+", "_", job[tracker : match.start()]) + cleansed_job_id += job[match.start() : match.end()] + tracker = match.end() + + # Clean up last substring or the full string if no templates + cleansed_job_id += re.sub(r"[^0-9a-zA-Z]+", "_", job[tracker:]) + + return cleansed_job_id + + +# pylint: disable=too-many-instance-attributes +class MLEngineStartBatchPredictionJobOperator(BaseOperator): + """ + Start a Google Cloud ML Engine prediction job. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MLEngineStartBatchPredictionJobOperator` + + NOTE: For model origin, users should consider exactly one from the + three options below: + + 1. Populate ``uri`` field only, which should be a GCS location that + points to a tensorflow savedModel directory. + 2. Populate ``model_name`` field only, which refers to an existing + model, and the default version of the model will be used. + 3. Populate both ``model_name`` and ``version_name`` fields, which + refers to a specific version of a specific model. + + In options 2 and 3, both model and version name should contain the + minimal identifier. For instance, call:: + + MLEngineBatchPredictionOperator( + ..., + model_name='my_model', + version_name='my_version', + ...) + + if the desired model version is + ``projects/my_project/models/my_model/versions/my_version``. + + See https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs + for further documentation on the parameters. + + :param job_id: A unique id for the prediction job on Google Cloud + ML Engine. (templated) + :type job_id: str + :param data_format: The format of the input data. + It will default to 'DATA_FORMAT_UNSPECIFIED' if is not provided + or is not one of ["TEXT", "TF_RECORD", "TF_RECORD_GZIP"]. + :type data_format: str + :param input_paths: A list of GCS paths of input data for batch + prediction. 
Accepting wildcard operator ``*``, but only at the end. (templated) + :type input_paths: list[str] + :param output_path: The GCS path where the prediction results are + written to. (templated) + :type output_path: str + :param region: The Google Compute Engine region to run the + prediction job in. (templated) + :type region: str + :param model_name: The Google Cloud ML Engine model to use for prediction. + If version_name is not provided, the default version of this + model will be used. + Should not be None if version_name is provided. + Should be None if uri is provided. (templated) + :type model_name: str + :param version_name: The Google Cloud ML Engine model version to use for + prediction. + Should be None if uri is provided. (templated) + :type version_name: str + :param uri: The GCS path of the saved model to use for prediction. + Should be None if model_name is provided. + It should be a GCS path pointing to a tensorflow SavedModel. (templated) + :type uri: str + :param max_worker_count: The maximum number of workers to be used + for parallel processing. Defaults to 10 if not specified. Should be a + string representing the worker count ("10" instead of 10, "50" instead + of 50, etc.) + :type max_worker_count: str + :param runtime_version: The Google Cloud ML Engine runtime version to use + for batch prediction. + :type runtime_version: str + :param signature_name: The name of the signature defined in the SavedModel + to use for this job. + :type signature_name: str + :param project_id: The Google Cloud project name where the prediction job is submitted. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param gcp_conn_id: The connection ID used for connection to Google + Cloud Platform. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param labels: a dictionary containing labels for the job; passed to BigQuery + :type labels: Dict[str, str] + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :raises: ``ValueError``: if a unique model/version origin cannot be + determined. 
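+
+    A minimal instantiation sketch (every value below is an illustrative
+    placeholder, not a default)::
+
+        predict = MLEngineStartBatchPredictionJobOperator(
+            task_id="batch_predict",
+            job_id="predict_{{ ds_nodash }}",
+            region="us-central1",
+            data_format="TEXT",
+            input_paths=["gs://my-bucket/batch/*"],
+            output_path="gs://my-bucket/predictions/",
+            model_name="my_model",
+            project_id="my-project",
+        )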
+ """ + + template_fields = [ + "_project_id", + "_job_id", + "_region", + "_input_paths", + "_output_path", + "_model_name", + "_version_name", + "_uri", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, # pylint: disable=too-many-arguments + *, + job_id: str, + region: str, + data_format: str, + input_paths: List[str], + output_path: str, + model_name: Optional[str] = None, + version_name: Optional[str] = None, + uri: Optional[str] = None, + max_worker_count: Optional[int] = None, + runtime_version: Optional[str] = None, + signature_name: Optional[str] = None, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self._project_id = project_id + self._job_id = job_id + self._region = region + self._data_format = data_format + self._input_paths = input_paths + self._output_path = output_path + self._model_name = model_name + self._version_name = version_name + self._uri = uri + self._max_worker_count = max_worker_count + self._runtime_version = runtime_version + self._signature_name = signature_name + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._labels = labels + self._impersonation_chain = impersonation_chain + + if not self._project_id: + raise AirflowException("Google Cloud project id is required.") + if not self._job_id: + raise AirflowException( + "An unique job id is required for Google MLEngine prediction job." + ) + + if self._uri: + if self._model_name or self._version_name: + raise AirflowException( + "Ambiguous model origin: Both uri and model/version name are provided." + ) + + if self._version_name and not self._model_name: + raise AirflowException( + "Missing model: Batch prediction expects a model name when a version name is provided." + ) + + if not (self._uri or self._model_name): + raise AirflowException( + "Missing model origin: Batch prediction expects a model, " + "a model & version combination, or a URI to a savedModel." + ) + + def execute(self, context): + job_id = _normalize_mlengine_job_id(self._job_id) + prediction_request = { + "jobId": job_id, + "predictionInput": { + "dataFormat": self._data_format, + "inputPaths": self._input_paths, + "outputPath": self._output_path, + "region": self._region, + }, + } + if self._labels: + prediction_request["labels"] = self._labels + + if self._uri: + prediction_request["predictionInput"]["uri"] = self._uri + elif self._model_name: + origin_name = f"projects/{self._project_id}/models/{self._model_name}" + if not self._version_name: + prediction_request["predictionInput"]["modelName"] = origin_name + else: + prediction_request["predictionInput"][ + "versionName" + ] = origin_name + "/versions/{}".format(self._version_name) + + if self._max_worker_count: + prediction_request["predictionInput"][ + "maxWorkerCount" + ] = self._max_worker_count + + if self._runtime_version: + prediction_request["predictionInput"][ + "runtimeVersion" + ] = self._runtime_version + + if self._signature_name: + prediction_request["predictionInput"][ + "signatureName" + ] = self._signature_name + + hook = MLEngineHook( + self._gcp_conn_id, + self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + + # Helper method to check if the existing job's prediction input is the + # same as the request we get here. 
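+        # (create_job only consults this callback when a job with the given id
+        # already exists; returning True lets the operator reattach to and wait
+        # for that job rather than fail, which keeps retried runs idempotent.)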
+ def check_existing_job(existing_job): + return ( + existing_job.get("predictionInput") + == prediction_request["predictionInput"] + ) + + finished_prediction_job = hook.create_job( + project_id=self._project_id, + job=prediction_request, + use_existing_job_fn=check_existing_job, + ) + + if finished_prediction_job["state"] != "SUCCEEDED": + self.log.error( + "MLEngine batch prediction job failed: %s", str(finished_prediction_job) + ) + raise RuntimeError(finished_prediction_job["errorMessage"]) + + return finished_prediction_job["predictionOutput"] + + +class MLEngineManageModelOperator(BaseOperator): + """ + Operator for managing a Google Cloud ML Engine model. + + .. warning:: + This operator is deprecated. Consider using operators for specific operations: + MLEngineCreateModelOperator, MLEngineGetModelOperator. + + :param model: A dictionary containing the information about the model. + If the `operation` is `create`, then the `model` parameter should + contain all the information about this model such as `name`. + + If the `operation` is `get`, the `model` parameter + should contain the `name` of the model. + :type model: dict + :param operation: The operation to perform. Available operations are: + + * ``create``: Creates a new model as provided by the `model` parameter. + * ``get``: Gets a particular model where the name is specified in `model`. + :type operation: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_model", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + model: dict, + operation: str = "create", + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + warnings.warn( + "This operator is deprecated. 
Consider using operators for specific operations: " + "MLEngineCreateModelOperator, MLEngineGetModelOperator.", + DeprecationWarning, + stacklevel=3, + ) + + self._project_id = project_id + self._model = model + self._operation = operation + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._impersonation_chain = impersonation_chain + + def execute(self, context): + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + if self._operation == "create": + return hook.create_model(project_id=self._project_id, model=self._model) + elif self._operation == "get": + return hook.get_model( + project_id=self._project_id, model_name=self._model["name"] + ) + else: + raise ValueError(f"Unknown operation: {self._operation}") + + +class MLEngineCreateModelOperator(BaseOperator): + """ + Creates a new model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MLEngineCreateModelOperator` + + The model should be provided by the `model` parameter. + + :param model: A dictionary containing the information about the model. + :type model: dict + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_model", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + model: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self._project_id = project_id + self._model = model + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._impersonation_chain = impersonation_chain + + def execute(self, context): + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + return hook.create_model(project_id=self._project_id, model=self._model) + + +class MLEngineGetModelOperator(BaseOperator): + """ + Gets a particular model + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MLEngineGetModelOperator` + + The name of model should be specified in `model_name`. + + :param model_name: The name of the model. 
+ :type model_name: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_model_name", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + model_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self._project_id = project_id + self._model_name = model_name + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._impersonation_chain = impersonation_chain + + def execute(self, context): + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + return hook.get_model(project_id=self._project_id, model_name=self._model_name) + + +class MLEngineDeleteModelOperator(BaseOperator): + """ + Deletes a model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MLEngineDeleteModelOperator` + + The model should be provided by the `model_name` parameter. + + :param model_name: The name of the model. + :type model_name: str + :param delete_contents: (Optional) Whether to force the deletion even if the models is not empty. + Will delete all version (if any) in the dataset if set to True. + The default value is False. + :type delete_contents: bool + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_model_name", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + model_name: str, + delete_contents: bool = False, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self._project_id = project_id + self._model_name = model_name + self._delete_contents = delete_contents + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._impersonation_chain = impersonation_chain + + def execute(self, context): + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + + return hook.delete_model( + project_id=self._project_id, + model_name=self._model_name, + delete_contents=self._delete_contents, + ) + + +class MLEngineManageVersionOperator(BaseOperator): + """ + Operator for managing a Google Cloud ML Engine version. + + .. warning:: + This operator is deprecated. Consider using operators for specific operations: + MLEngineCreateVersionOperator, MLEngineSetDefaultVersionOperator, + MLEngineListVersionsOperator, MLEngineDeleteVersionOperator. + + :param model_name: The name of the Google Cloud ML Engine model that the version + belongs to. (templated) + :type model_name: str + :param version_name: A name to use for the version being operated upon. + If not None and the `version` argument is None or does not have a value for + the `name` key, then this will be populated in the payload for the + `name` key. (templated) + :type version_name: str + :param version: A dictionary containing the information about the version. + If the `operation` is `create`, `version` should contain all the + information about this version such as name, and deploymentUrl. + If the `operation` is `get` or `delete`, the `version` parameter + should contain the `name` of the version. + If it is None, the only `operation` possible would be `list`. (templated) + :type version: dict + :param operation: The operation to perform. Available operations are: + + * ``create``: Creates a new version in the model specified by `model_name`, + in which case the `version` parameter should contain all the + information to create that version + (e.g. `name`, `deploymentUrl`). + + * ``set_defaults``: Sets a version in the model specified by `model_name` to be the default. + The name of the version should be specified in the `version` + parameter. + + * ``list``: Lists all available versions of the model specified + by `model_name`. + + * ``delete``: Deletes the version specified in `version` parameter from the + model specified by `model_name`). + The name of the version should be specified in the `version` + parameter. + :type operation: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. 
+ :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_model_name", + "_version_name", + "_version", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + model_name: str, + version_name: Optional[str] = None, + version: Optional[dict] = None, + operation: str = "create", + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self._project_id = project_id + self._model_name = model_name + self._version_name = version_name + self._version = version or {} + self._operation = operation + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._impersonation_chain = impersonation_chain + + warnings.warn( + "This operator is deprecated. Consider using operators for specific operations: " + "MLEngineCreateVersion, MLEngineSetDefaultVersion, MLEngineListVersions, MLEngineDeleteVersion.", + DeprecationWarning, + stacklevel=3, + ) + + def execute(self, context): + if "name" not in self._version: + self._version["name"] = self._version_name + + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + + if self._operation == "create": + if not self._version: + raise ValueError( + f"version attribute of {self.__class__.__name__} could not be empty" + ) + return hook.create_version( + project_id=self._project_id, + model_name=self._model_name, + version_spec=self._version, + ) + elif self._operation == "set_default": + return hook.set_default_version( + project_id=self._project_id, + model_name=self._model_name, + version_name=self._version["name"], + ) + elif self._operation == "list": + return hook.list_versions( + project_id=self._project_id, model_name=self._model_name + ) + elif self._operation == "delete": + return hook.delete_version( + project_id=self._project_id, + model_name=self._model_name, + version_name=self._version["name"], + ) + else: + raise ValueError(f"Unknown operation: {self._operation}") + + +class MLEngineCreateVersionOperator(BaseOperator): + """ + Creates a new version in the model + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MLEngineCreateVersionOperator` + + Model should be specified by `model_name`, in which case the `version` parameter should contain all the + information to create that version + + :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. 
(templated) + :type model_name: str + :param version: A dictionary containing the information about the version. (templated) + :type version: dict + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_model_name", + "_version", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + model_name: str, + version: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self._project_id = project_id + self._model_name = model_name + self._version = version + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._impersonation_chain = impersonation_chain + self._validate_inputs() + + def _validate_inputs(self): + if not self._model_name: + raise AirflowException("The model_name parameter could not be empty.") + + if not self._version: + raise AirflowException("The version parameter could not be empty.") + + def execute(self, context): + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + + return hook.create_version( + project_id=self._project_id, + model_name=self._model_name, + version_spec=self._version, + ) + + +class MLEngineSetDefaultVersionOperator(BaseOperator): + """ + Sets a version in the model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MLEngineSetDefaultVersionOperator` + + The model should be specified by `model_name` to be the default. The name of the version should be + specified in the `version_name` parameter. + + :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. (templated) + :type model_name: str + :param version_name: A name to use for the version being operated upon. (templated) + :type version_name: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. 
+ :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_model_name", + "_version_name", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + model_name: str, + version_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self._project_id = project_id + self._model_name = model_name + self._version_name = version_name + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._impersonation_chain = impersonation_chain + self._validate_inputs() + + def _validate_inputs(self): + if not self._model_name: + raise AirflowException("The model_name parameter could not be empty.") + + if not self._version_name: + raise AirflowException("The version_name parameter could not be empty.") + + def execute(self, context): + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + + return hook.set_default_version( + project_id=self._project_id, + model_name=self._model_name, + version_name=self._version_name, + ) + + +class MLEngineListVersionsOperator(BaseOperator): + """ + Lists all available versions of the model + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MLEngineListVersionsOperator` + + The model should be specified by `model_name`. + + :param model_name: The name of the Google Cloud ML Engine model that the version + belongs to. (templated) + :type model_name: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_model_name", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + model_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self._project_id = project_id + self._model_name = model_name + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._impersonation_chain = impersonation_chain + self._validate_inputs() + + def _validate_inputs(self): + if not self._model_name: + raise AirflowException("The model_name parameter could not be empty.") + + def execute(self, context): + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + + return hook.list_versions( + project_id=self._project_id, + model_name=self._model_name, + ) + + +class MLEngineDeleteVersionOperator(BaseOperator): + """ + Deletes the version from the model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MLEngineDeleteVersionOperator` + + The name of the version should be specified in `version_name` parameter from the model specified + by `model_name`. + + :param model_name: The name of the Google Cloud ML Engine model that the version + belongs to. (templated) + :type model_name: str + :param version_name: A name to use for the version being operated upon. (templated) + :type version_name: str + :param project_id: The Google Cloud project name to which MLEngine + model belongs. + :type project_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_model_name", + "_version_name", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + model_name: str, + version_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self._project_id = project_id + self._model_name = model_name + self._version_name = version_name + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._impersonation_chain = impersonation_chain + self._validate_inputs() + + def _validate_inputs(self): + if not self._model_name: + raise AirflowException("The model_name parameter could not be empty.") + + if not self._version_name: + raise AirflowException("The version_name parameter could not be empty.") + + def execute(self, context): + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + + return hook.delete_version( + project_id=self._project_id, + model_name=self._model_name, + version_name=self._version_name, + ) + + +class AIPlatformConsoleLink(BaseOperatorLink): + """Helper class for constructing AI Platform Console link.""" + + name = "AI Platform Console" + + def get_link(self, operator, dttm): + task_instance = TaskInstance(task=operator, execution_date=dttm) + gcp_metadata_dict = task_instance.xcom_pull( + task_ids=operator.task_id, key="gcp_metadata" + ) + if not gcp_metadata_dict: + return "" + job_id = gcp_metadata_dict["job_id"] + project_id = gcp_metadata_dict["project_id"] + console_link = f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}" + return console_link + + +# pylint: disable=too-many-instance-attributes +class MLEngineStartTrainingJobOperator(BaseOperator): + """ + Operator for launching a MLEngine training job. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MLEngineStartTrainingJobOperator` + + :param job_id: A unique templated id for the submitted Google MLEngine + training job. (templated) + :type job_id: str + :param region: The Google Compute Engine region to run the MLEngine training + job in (templated). + :type region: str + :param package_uris: A list of Python package locations for the training + job, which should include the main training program and any additional + dependencies. This is mutually exclusive with a custom image specified + via master_config. (templated) + :type package_uris: List[str] + :param training_python_module: The name of the Python module to run within + the training job after installing the packages. This is mutually + exclusive with a custom image specified via master_config. (templated) + :type training_python_module: str + :param training_args: A list of command-line arguments to pass to the + training program. (templated) + :type training_args: List[str] + :param scale_tier: Resource tier for MLEngine training job. (templated) + :type scale_tier: str + :param master_type: The type of virtual machine to use for the master + worker. It must be set whenever scale_tier is CUSTOM. (templated) + :type master_type: str + :param master_config: The configuration for the master worker. If this is + provided, master_type must be set as well. 
If a custom image is + specified, this is mutually exclusive with package_uris and + training_python_module. (templated) + :type master_config: dict + :param runtime_version: The Google Cloud ML runtime version to use for + training. (templated) + :type runtime_version: str + :param python_version: The version of Python used in training. (templated) + :type python_version: str + :param job_dir: A Google Cloud Storage path in which to store training + outputs and other data needed for training. (templated) + :type job_dir: str + :param service_account: Optional service account to use when running the training application. + (templated) + The specified service account must have the `iam.serviceAccounts.actAs` role. The + Google-managed Cloud ML Engine service account must have the `iam.serviceAccountAdmin` role + for the specified service account. + If set to None or missing, the Google-managed Cloud ML Engine service account will be used. + :type service_account: str + :param project_id: The Google Cloud project name within which MLEngine training job should run. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param mode: Can be one of 'DRY_RUN'/'CLOUD'. In 'DRY_RUN' mode, no real + training job will be launched, but the MLEngine training job request + will be printed out. In 'CLOUD' mode, a real MLEngine training job + creation request will be issued. + :type mode: str + :param labels: a dictionary containing labels for the job; passed to BigQuery + :type labels: Dict[str, str] + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_job_id", + "_region", + "_package_uris", + "_training_python_module", + "_training_args", + "_scale_tier", + "_master_type", + "_master_config", + "_runtime_version", + "_python_version", + "_job_dir", + "_service_account", + "_impersonation_chain", + ] + + operator_extra_links = (AIPlatformConsoleLink(),) + + @apply_defaults + def __init__( + self, # pylint: disable=too-many-arguments + *, + job_id: str, + region: str, + package_uris: List[str] = None, + training_python_module: str = None, + training_args: List[str] = None, + scale_tier: Optional[str] = None, + master_type: Optional[str] = None, + master_config: Optional[Dict] = None, + runtime_version: Optional[str] = None, + python_version: Optional[str] = None, + job_dir: Optional[str] = None, + service_account: Optional[str] = None, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + mode: str = "PRODUCTION", + labels: Optional[Dict[str, str]] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self._project_id = project_id + self._job_id = job_id + self._region = region + self._package_uris = package_uris + self._training_python_module = training_python_module + self._training_args = training_args + self._scale_tier = scale_tier + self._master_type = master_type + self._master_config = master_config + self._runtime_version = runtime_version + self._python_version = python_version + self._job_dir = job_dir + self._service_account = service_account + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._mode = mode + self._labels = labels + self._impersonation_chain = impersonation_chain + + custom = self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM" + custom_image = ( + custom + and self._master_config is not None + and self._master_config.get("imageUri", None) is not None + ) + + if not self._project_id: + raise AirflowException("Google Cloud project id is required.") + if not self._job_id: + raise AirflowException( + "An unique job id is required for Google MLEngine training job." + ) + if not self._region: + raise AirflowException("Google Compute Engine region is required.") + if custom and not self._master_type: + raise AirflowException("master_type must be set when scale_tier is CUSTOM") + if self._master_config and not self._master_type: + raise AirflowException( + "master_type must be set when master_config is provided" + ) + if not (package_uris and training_python_module) and not custom_image: + raise AirflowException( + "Either a Python package with a Python module or a custom Docker image should be provided." + ) + if (package_uris or training_python_module) and custom_image: + raise AirflowException( + "Either a Python package with a Python module or " + "a custom Docker image should be provided but not both." 
+ ) + + def execute(self, context): + job_id = _normalize_mlengine_job_id(self._job_id) + training_request = { + "jobId": job_id, + "trainingInput": { + "scaleTier": self._scale_tier, + "region": self._region, + }, + } + if self._package_uris: + training_request["trainingInput"]["packageUris"] = self._package_uris + + if self._training_python_module: + training_request["trainingInput"][ + "pythonModule" + ] = self._training_python_module + + if self._training_args: + training_request["trainingInput"]["args"] = self._training_args + + if self._master_type: + training_request["trainingInput"]["masterType"] = self._master_type + + if self._master_config: + training_request["trainingInput"]["masterConfig"] = self._master_config + + if self._runtime_version: + training_request["trainingInput"]["runtimeVersion"] = self._runtime_version + + if self._python_version: + training_request["trainingInput"]["pythonVersion"] = self._python_version + + if self._job_dir: + training_request["trainingInput"]["jobDir"] = self._job_dir + + if self._service_account: + training_request["trainingInput"]["serviceAccount"] = self._service_account + + if self._labels: + training_request["labels"] = self._labels + + if self._mode == "DRY_RUN": + self.log.info("In dry_run mode.") + self.log.info("MLEngine Training job request is: %s", training_request) + return + + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + + # Helper method to check if the existing job's training input is the + # same as the request we get here. + def check_existing_job(existing_job): + existing_training_input = existing_job.get("trainingInput") + requested_training_input = training_request["trainingInput"] + if "scaleTier" not in existing_training_input: + existing_training_input["scaleTier"] = None + + existing_training_input["args"] = existing_training_input.get("args") + requested_training_input["args"] = ( + requested_training_input["args"] + if requested_training_input["args"] + else None + ) + + return existing_training_input == requested_training_input + + finished_training_job = hook.create_job( + project_id=self._project_id, + job=training_request, + use_existing_job_fn=check_existing_job, + ) + + if finished_training_job["state"] != "SUCCEEDED": + self.log.error( + "MLEngine training job failed: %s", str(finished_training_job) + ) + raise RuntimeError(finished_training_job["errorMessage"]) + + gcp_metadata = { + "job_id": job_id, + "project_id": self._project_id, + } + context["task_instance"].xcom_push("gcp_metadata", gcp_metadata) + + +class MLEngineTrainingCancelJobOperator(BaseOperator): + """ + Operator for cleaning up failed MLEngine training job. + + :param job_id: A unique templated id for the submitted Google MLEngine + training job. (templated) + :type job_id: str + :param project_id: The Google Cloud project name within which MLEngine training job should run. + If set to None or missing, the default project_id from the Google Cloud connection is used. + (templated) + :type project_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. 
+ :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "_project_id", + "_job_id", + "_impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + job_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self._project_id = project_id + self._job_id = job_id + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._impersonation_chain = impersonation_chain + + if not self._project_id: + raise AirflowException("Google Cloud project id is required.") + + def execute(self, context): + + hook = MLEngineHook( + gcp_conn_id=self._gcp_conn_id, + delegate_to=self._delegate_to, + impersonation_chain=self._impersonation_chain, + ) + + hook.cancel_job( + project_id=self._project_id, job_id=_normalize_mlengine_job_id(self._job_id) + ) diff --git a/reference/providers/google/cloud/operators/natural_language.py b/reference/providers/google/cloud/operators/natural_language.py new file mode 100644 index 0000000..89a3a32 --- /dev/null +++ b/reference/providers/google/cloud/operators/natural_language.py @@ -0,0 +1,358 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Cloud Language operators.""" +from typing import Optional, Sequence, Tuple, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.natural_language import ( + CloudNaturalLanguageHook, +) +from airflow.utils.decorators import apply_defaults +from google.api_core.retry import Retry +from google.cloud.language_v1 import enums +from google.cloud.language_v1.types import Document +from google.protobuf.json_format import MessageToDict + +MetaData = Sequence[Tuple[str, str]] + + +class CloudNaturalLanguageAnalyzeEntitiesOperator(BaseOperator): + """ + Finds named entities in the text along with entity types, + salience, mentions for each entity, and other properties. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudNaturalLanguageAnalyzeEntitiesOperator` + + :param document: Input document. + If a dict is provided, it must be of the same form as the protobuf message Document + :type document: dict or google.cloud.language_v1.types.Document + :param encoding_type: The encoding type used by the API to calculate offsets. + :type encoding_type: google.cloud.language_v1.enums.EncodingType + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START natural_language_analyze_entities_template_fields] + template_fields = ( + "document", + "gcp_conn_id", + "impersonation_chain", + ) + # [END natural_language_analyze_entities_template_fields] + + @apply_defaults + def __init__( + self, + *, + document: Union[dict, Document], + encoding_type: Optional[enums.EncodingType] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.document = document + self.encoding_type = encoding_type + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudNaturalLanguageHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Start analyzing entities") + response = hook.analyze_entities( + document=self.document, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Finished analyzing entities") + + return MessageToDict(response) + + +class CloudNaturalLanguageAnalyzeEntitySentimentOperator(BaseOperator): + """ + Finds entities, similar to AnalyzeEntities in the text and analyzes sentiment associated with each + entity and its mentions. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudNaturalLanguageAnalyzeEntitySentimentOperator` + + :param document: Input document. 
+ If a dict is provided, it must be of the same form as the protobuf message Document + :type document: dict or google.cloud.language_v1.types.Document + :param encoding_type: The encoding type used by the API to calculate offsets. + :type encoding_type: google.cloud.language_v1.enums.EncodingType + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.language_v1.types.AnalyzeEntitiesResponse + """ + + # [START natural_language_analyze_entity_sentiment_template_fields] + template_fields = ( + "document", + "gcp_conn_id", + "impersonation_chain", + ) + # [END natural_language_analyze_entity_sentiment_template_fields] + + @apply_defaults + def __init__( + self, + *, + document: Union[dict, Document], + encoding_type: Optional[enums.EncodingType] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.document = document + self.encoding_type = encoding_type + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudNaturalLanguageHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Start entity sentiment analyze") + response = hook.analyze_entity_sentiment( + document=self.document, + encoding_type=self.encoding_type, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Finished entity sentiment analyze") + + return MessageToDict(response) + + +class CloudNaturalLanguageAnalyzeSentimentOperator(BaseOperator): + """ + Analyzes the sentiment of the provided text. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudNaturalLanguageAnalyzeSentimentOperator` + + :param document: Input document. + If a dict is provided, it must be of the same form as the protobuf message Document + :type document: dict or google.cloud.language_v1.types.Document + :param encoding_type: The encoding type used by the API to calculate offsets. + :type encoding_type: google.cloud.language_v1.enums.EncodingType + :param retry: A retry object used to retry requests. 
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        retry is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: google.cloud.language_v1.types.AnalyzeSentimentResponse
+    """
+
+    # [START natural_language_analyze_sentiment_template_fields]
+    template_fields = (
+        "document",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+    # [END natural_language_analyze_sentiment_template_fields]
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        document: Union[dict, Document],
+        encoding_type: Optional[enums.EncodingType] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[MetaData] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.document = document
+        self.encoding_type = encoding_type
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = CloudNaturalLanguageHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        self.log.info("Start sentiment analysis")
+        response = hook.analyze_sentiment(
+            document=self.document,
+            encoding_type=self.encoding_type,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        self.log.info("Finished sentiment analysis")
+
+        return MessageToDict(response)
+
+
+class CloudNaturalLanguageClassifyTextOperator(BaseOperator):
+    """
+    Classifies a document into categories.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudNaturalLanguageClassifyTextOperator`
+
+    :param document: Input document.
+        If a dict is provided, it must be of the same form as the protobuf message Document
+    :type document: dict or google.cloud.language_v1.types.Document
+    :param retry: A retry object used to retry requests. If None is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        retry is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START natural_language_classify_text_template_fields] + template_fields = ( + "document", + "gcp_conn_id", + "impersonation_chain", + ) + # [END natural_language_classify_text_template_fields] + + @apply_defaults + def __init__( + self, + *, + document: Union[dict, Document], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.document = document + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudNaturalLanguageHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Start text classify") + response = hook.classify_text( + document=self.document, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Finished text classify") + + return MessageToDict(response) diff --git a/reference/providers/google/cloud/operators/pubsub.py b/reference/providers/google/cloud/operators/pubsub.py new file mode 100644 index 0000000..4ea77ac --- /dev/null +++ b/reference/providers/google/cloud/operators/pubsub.py @@ -0,0 +1,965 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google PubSub operators.""" +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.pubsub import PubSubHook +from airflow.utils.decorators import apply_defaults +from google.api_core.retry import Retry +from google.cloud.pubsub_v1.types import ( + DeadLetterPolicy, + Duration, + ExpirationPolicy, + MessageStoragePolicy, + PushConfig, + ReceivedMessage, + RetryPolicy, +) + + +class PubSubCreateTopicOperator(BaseOperator): + """Create a PubSub topic. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:PubSubCreateTopicOperator` + + By default, if the topic already exists, this operator will + not cause the DAG to fail. :: + + with DAG('successful DAG') as dag: + ( + PubSubTopicCreateOperator(project='my-project', + topic='my_new_topic') + >> PubSubTopicCreateOperator(project='my-project', + topic='my_new_topic') + ) + + The operator can be configured to fail if the topic already exists. :: + + with DAG('failing DAG') as dag: + ( + PubSubTopicCreateOperator(project='my-project', + topic='my_new_topic') + >> PubSubTopicCreateOperator(project='my-project', + topic='my_new_topic', + fail_if_exists=True) + ) + + Both ``project`` and ``topic`` are templated so you can use + variables in them. + + :param project_id: Optional, the Google Cloud project ID where the topic will be created. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param topic: the topic to create. Do not include the + full topic path. In other words, instead of + ``projects/{project}/topics/{topic}``, provide only + ``{topic}``. (templated) + :type topic: str + :param gcp_conn_id: The connection ID to use connecting to + Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param labels: Client-assigned labels; see + https://cloud.google.com/pubsub/docs/labels + :type labels: Dict[str, str] + :param message_storage_policy: Policy constraining the set + of Google Cloud regions where messages published to + the topic may be stored. If not present, then no constraints + are in effect. + :type message_storage_policy: + Union[Dict, google.cloud.pubsub_v1.types.MessageStoragePolicy] + :param kms_key_name: The resource name of the Cloud KMS CryptoKey + to be used to protect access to messages published on this topic. + The expected format is + ``projects/*/locations/*/keyRings/*/cryptoKeys/*``. + :type kms_key_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + :param project: (Deprecated) the Google Cloud project ID where the topic will be created + :type project: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "project_id", + "topic", + "impersonation_chain", + ] + ui_color = "#0273d4" + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + topic: str, + project_id: Optional[str] = None, + fail_if_exists: bool = False, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + message_storage_policy: Union[Dict, MessageStoragePolicy] = None, + kms_key_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + project: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + # To preserve backward compatibility + # TODO: remove one day + if project: + warnings.warn( + "The project parameter has been deprecated. You should pass the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) + project_id = project + + super().__init__(**kwargs) + self.project_id = project_id + self.topic = topic + self.fail_if_exists = fail_if_exists + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.labels = labels + self.message_storage_policy = message_storage_policy + self.kms_key_name = kms_key_name + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = PubSubHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Creating topic %s", self.topic) + hook.create_topic( + project_id=self.project_id, + topic=self.topic, + fail_if_exists=self.fail_if_exists, + labels=self.labels, + message_storage_policy=self.message_storage_policy, + kms_key_name=self.kms_key_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Created topic %s", self.topic) + + +# pylint: disable=too-many-instance-attributes +class PubSubCreateSubscriptionOperator(BaseOperator): + """Create a PubSub subscription. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:PubSubCreateSubscriptionOperator` + + By default, the subscription will be created in ``topic_project``. If + ``subscription_project`` is specified and the Google Cloud credentials allow, the + Subscription can be created in a different project from its topic. + + By default, if the subscription already exists, this operator will + not cause the DAG to fail. However, the topic must exist in the project. :: + + with DAG('successful DAG') as dag: + ( + PubSubSubscriptionCreateOperator( + topic_project='my-project', topic='my-topic', + subscription='my-subscription') + >> PubSubSubscriptionCreateOperator( + topic_project='my-project', topic='my-topic', + subscription='my-subscription') + ) + + The operator can be configured to fail if the subscription already exists. + :: + + with DAG('failing DAG') as dag: + ( + PubSubSubscriptionCreateOperator( + topic_project='my-project', topic='my-topic', + subscription='my-subscription') + >> PubSubSubscriptionCreateOperator( + topic_project='my-project', topic='my-topic', + subscription='my-subscription', fail_if_exists=True) + ) + + Finally, subscription is not required. If not passed, the operator will + generated a universally unique identifier for the subscription's name. 
:: + + with DAG('DAG') as dag: + ( + PubSubSubscriptionCreateOperator( + topic_project='my-project', topic='my-topic') + ) + + ``topic_project``, ``topic``, ``subscription``, and + ``subscription`` are templated so you can use variables in them. + + :param project_id: Optional, the Google Cloud project ID where the topic exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param topic: the topic to create. Do not include the + full topic path. In other words, instead of + ``projects/{project}/topics/{topic}``, provide only + ``{topic}``. (templated) + :type topic: str + :param subscription: the Pub/Sub subscription name. If empty, a random + name will be generated using the uuid module + :type subscription: str + :param subscription_project_id: the Google Cloud project ID where the subscription + will be created. If empty, ``topic_project`` will be used. + :type subscription_project_id: str + :param ack_deadline_secs: Number of seconds that a subscriber has to + acknowledge each message pulled from the subscription + :type ack_deadline_secs: int + :param gcp_conn_id: The connection ID to use connecting to + Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param push_config: If push delivery is used with this subscription, + this field is used to configure it. An empty ``pushConfig`` signifies + that the subscriber will pull and ack messages using API methods. + :type push_config: Union[Dict, google.cloud.pubsub_v1.types.PushConfig] + :param retain_acked_messages: Indicates whether to retain acknowledged + messages. If true, then messages are not expunged from the subscription's + backlog, even if they are acknowledged, until they fall out of the + ``message_retention_duration`` window. This must be true if you would + like to Seek to a timestamp. + :type retain_acked_messages: bool + :param message_retention_duration: How long to retain unacknowledged messages + in the subscription's backlog, from the moment a message is published. If + ``retain_acked_messages`` is true, then this also configures the + retention of acknowledged messages, and thus configures how far back in + time a ``Seek`` can be done. Defaults to 7 days. Cannot be more than 7 + days or less than 10 minutes. + :type message_retention_duration: Union[Dict, google.cloud.pubsub_v1.types.Duration] + :param labels: Client-assigned labels; see + https://cloud.google.com/pubsub/docs/labels + :type labels: Dict[str, str] + :param enable_message_ordering: If true, messages published with the same + ordering_key in PubsubMessage will be delivered to the subscribers in the order + in which they are received by the Pub/Sub system. Otherwise, they may be + delivered in any order. + :type enable_message_ordering: bool + :param expiration_policy: A policy that specifies the conditions for this + subscription’s expiration. A subscription is considered active as long as any + connected subscriber is successfully consuming messages from the subscription or + is issuing operations on the subscription. If expiration_policy is not set, + a default policy with ttl of 31 days will be used. The minimum allowed value for + expiration_policy.ttl is 1 day. 
+ :type expiration_policy: Union[Dict, google.cloud.pubsub_v1.types.ExpirationPolicy`] + :param filter_: An expression written in the Cloud Pub/Sub filter language. If + non-empty, then only PubsubMessages whose attributes field matches the filter are + delivered on this subscription. If empty, then no messages are filtered out. + :type filter_: str + :param dead_letter_policy: A policy that specifies the conditions for dead lettering + messages in this subscription. If dead_letter_policy is not set, dead lettering is + disabled. + :type dead_letter_policy: Union[Dict, google.cloud.pubsub_v1.types.DeadLetterPolicy] + :param retry_policy: A policy that specifies how Pub/Sub retries message delivery + for this subscription. If not set, the default retry policy is applied. This + generally implies that messages will be retried as soon as possible for healthy + subscribers. RetryPolicy will be triggered on NACKs or acknowledgement deadline + exceeded events for a given message. + :type retry_policy: Union[Dict, google.cloud.pubsub_v1.types.RetryPolicy] + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + :param topic_project: (Deprecated) the Google Cloud project ID where the topic exists + :type topic_project: str + :param subscription_project: (Deprecated) the Google Cloud project ID where the subscription + will be created. If empty, ``topic_project`` will be used. + :type subscription_project: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "project_id", + "topic", + "subscription", + "subscription_project_id", + "impersonation_chain", + ] + ui_color = "#0273d4" + + # pylint: disable=too-many-arguments, too-many-locals + @apply_defaults + def __init__( + self, + *, + topic: str, + project_id: Optional[str] = None, + subscription: Optional[str] = None, + subscription_project_id: Optional[str] = None, + ack_deadline_secs: int = 10, + fail_if_exists: bool = False, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + push_config: Optional[Union[Dict, PushConfig]] = None, + retain_acked_messages: Optional[bool] = None, + message_retention_duration: Optional[Union[Dict, Duration]] = None, + labels: Optional[Dict[str, str]] = None, + enable_message_ordering: bool = False, + expiration_policy: Optional[Union[Dict, ExpirationPolicy]] = None, + filter_: Optional[str] = None, + dead_letter_policy: Optional[Union[Dict, DeadLetterPolicy]] = None, + retry_policy: Optional[Union[Dict, RetryPolicy]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + topic_project: Optional[str] = None, + subscription_project: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + # To preserve backward compatibility + # TODO: remove one day + if topic_project: + warnings.warn( + "The topic_project parameter has been deprecated. You should pass " + "the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) + project_id = topic_project + if subscription_project: + warnings.warn( + "The project_id parameter has been deprecated. You should pass " + "the subscription_project parameter.", + DeprecationWarning, + stacklevel=2, + ) + subscription_project_id = subscription_project + + super().__init__(**kwargs) + self.project_id = project_id + self.topic = topic + self.subscription = subscription + self.subscription_project_id = subscription_project_id + self.ack_deadline_secs = ack_deadline_secs + self.fail_if_exists = fail_if_exists + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.push_config = push_config + self.retain_acked_messages = retain_acked_messages + self.message_retention_duration = message_retention_duration + self.labels = labels + self.enable_message_ordering = enable_message_ordering + self.expiration_policy = expiration_policy + self.filter_ = filter_ + self.dead_letter_policy = dead_letter_policy + self.retry_policy = retry_policy + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> str: + hook = PubSubHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Creating subscription for topic %s", self.topic) + result = hook.create_subscription( + project_id=self.project_id, + topic=self.topic, + subscription=self.subscription, + subscription_project_id=self.subscription_project_id, + ack_deadline_secs=self.ack_deadline_secs, + fail_if_exists=self.fail_if_exists, + push_config=self.push_config, + retain_acked_messages=self.retain_acked_messages, + message_retention_duration=self.message_retention_duration, + labels=self.labels, + enable_message_ordering=self.enable_message_ordering, + expiration_policy=self.expiration_policy, + filter_=self.filter_, + 
dead_letter_policy=self.dead_letter_policy, + retry_policy=self.retry_policy, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + self.log.info("Created subscription for topic %s", self.topic) + return result + + +class PubSubDeleteTopicOperator(BaseOperator): + """Delete a PubSub topic. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:PubSubDeleteTopicOperator` + + By default, if the topic does not exist, this operator will + not cause the DAG to fail. :: + + with DAG('successful DAG') as dag: + ( + PubSubTopicDeleteOperator(project='my-project', + topic='non_existing_topic') + ) + + The operator can be configured to fail if the topic does not exist. :: + + with DAG('failing DAG') as dag: + ( + PubSubTopicCreateOperator(project='my-project', + topic='non_existing_topic', + fail_if_not_exists=True) + ) + + Both ``project`` and ``topic`` are templated so you can use + variables in them. + + :param project_id: Optional, the Google Cloud project ID in which to work (templated). + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param topic: the topic to delete. Do not include the + full topic path. In other words, instead of + ``projects/{project}/topics/{topic}``, provide only + ``{topic}``. (templated) + :type topic: str + :param fail_if_not_exists: If True and the topic does not exist, fail + the task + :type fail_if_not_exists: bool + :param gcp_conn_id: The connection ID to use connecting to + Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + :param project: (Deprecated) the Google Cloud project ID where the topic will be created + :type project: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "project_id", + "topic", + "impersonation_chain", + ] + ui_color = "#cb4335" + + @apply_defaults + def __init__( + self, + *, + topic: str, + project_id: Optional[str] = None, + fail_if_not_exists: bool = False, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + project: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + # To preserve backward compatibility + # TODO: remove one day + if project: + warnings.warn( + "The project parameter has been deprecated. You should pass the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) + project_id = project + + super().__init__(**kwargs) + self.project_id = project_id + self.topic = topic + self.fail_if_not_exists = fail_if_not_exists + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = PubSubHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Deleting topic %s", self.topic) + hook.delete_topic( + project_id=self.project_id, + topic=self.topic, + fail_if_not_exists=self.fail_if_not_exists, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Deleted topic %s", self.topic) + + +class PubSubDeleteSubscriptionOperator(BaseOperator): + """Delete a PubSub subscription. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:PubSubDeleteSubscriptionOperator` + + By default, if the subscription does not exist, this operator will + not cause the DAG to fail. :: + + with DAG('successful DAG') as dag: + ( + PubSubSubscriptionDeleteOperator(project='my-project', + subscription='non-existing') + ) + + The operator can be configured to fail if the subscription already exists. + + :: + + with DAG('failing DAG') as dag: + ( + PubSubSubscriptionDeleteOperator( + project='my-project', subscription='non-existing', + fail_if_not_exists=True) + ) + + ``project``, and ``subscription`` are templated so you can use + variables in them. + + :param project_id: Optional, the Google Cloud project ID in which to work (templated). + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param subscription: the subscription to delete. Do not include the + full subscription path. In other words, instead of + ``projects/{project}/subscription/{subscription}``, provide only + ``{subscription}``. (templated) + :type subscription: str + :param fail_if_not_exists: If True and the subscription does not exist, + fail the task + :type fail_if_not_exists: bool + :param gcp_conn_id: The connection ID to use connecting to + Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. 
+ :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + :param project: (Deprecated) the Google Cloud project ID where the topic will be created + :type project: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "project_id", + "subscription", + "impersonation_chain", + ] + ui_color = "#cb4335" + + @apply_defaults + def __init__( + self, + *, + subscription: str, + project_id: Optional[str] = None, + fail_if_not_exists: bool = False, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + project: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + # To preserve backward compatibility + # TODO: remove one day + if project: + warnings.warn( + "The project parameter has been deprecated. You should pass the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) + project_id = project + + super().__init__(**kwargs) + self.project_id = project_id + self.subscription = subscription + self.fail_if_not_exists = fail_if_not_exists + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = PubSubHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Deleting subscription %s", self.subscription) + hook.delete_subscription( + project_id=self.project_id, + subscription=self.subscription, + fail_if_not_exists=self.fail_if_not_exists, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Deleted subscription %s", self.subscription) + + +class PubSubPublishMessageOperator(BaseOperator): + """Publish messages to a PubSub topic. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:PubSubPublishMessageOperator` + + Each Task publishes all provided messages to the same topic + in a single Google Cloud project. If the topic does not exist, this + task will fail. 
:: + + m1 = {'data': b'Hello, World!', + 'attributes': {'type': 'greeting'} + } + m2 = {'data': b'Knock, knock'} + m3 = {'attributes': {'foo': ''}} + + t1 = PubSubPublishOperator( + project='my-project',topic='my_topic', + messages=[m1, m2, m3], + create_topic=True, + dag=dag) + + ``project`` , ``topic``, and ``messages`` are templated so you can use + variables in them. + + :param project_id: Optional, the Google Cloud project ID in which to work (templated). + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param topic: the topic to which to publish. Do not include the + full topic path. In other words, instead of + ``projects/{project}/topics/{topic}``, provide only + ``{topic}``. (templated) + :type topic: str + :param messages: a list of messages to be published to the + topic. Each message is a dict with one or more of the + following keys-value mappings: + * 'data': a bytestring (utf-8 encoded) + * 'attributes': {'key1': 'value1', ...} + Each message must contain at least a non-empty 'data' value + or an attribute dict with at least one key (templated). See + https://cloud.google.com/pubsub/docs/reference/rest/v1/PubsubMessage + :type messages: list + :param gcp_conn_id: The connection ID to use connecting to + Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param project: (Deprecated) the Google Cloud project ID where the topic will be created + :type project: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "project_id", + "topic", + "messages", + "impersonation_chain", + ] + ui_color = "#0273d4" + + @apply_defaults + def __init__( + self, + *, + topic: str, + messages: List, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + project: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + # To preserve backward compatibility + # TODO: remove one day + if project: + warnings.warn( + "The project parameter has been deprecated. 
You should pass the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) + project_id = project + + super().__init__(**kwargs) + self.project_id = project_id + self.topic = topic + self.messages = messages + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + hook = PubSubHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Publishing to topic %s", self.topic) + hook.publish( + project_id=self.project_id, topic=self.topic, messages=self.messages + ) + self.log.info("Published to topic %s", self.topic) + + +class PubSubPullOperator(BaseOperator): + """Pulls messages from a PubSub subscription and passes them through XCom. + If the queue is empty, returns empty list - never waits for messages. + If you do need to wait, please use :class:`airflow.providers.google.cloud.sensors.PubSubPullSensor` + instead. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:PubSubPullSensor` + + This sensor operator will pull up to ``max_messages`` messages from the + specified PubSub subscription. When the subscription returns messages, + the poke method's criteria will be fulfilled and the messages will be + returned from the operator and passed through XCom for downstream tasks. + + If ``ack_messages`` is set to True, messages will be immediately + acknowledged before being returned, otherwise, downstream tasks will be + responsible for acknowledging them. + + ``project`` and ``subscription`` are templated so you can use + variables in them. + + :param project: the Google Cloud project ID for the subscription (templated) + :type project: str + :param subscription: the Pub/Sub subscription name. Do not include the + full subscription path. + :type subscription: str + :param max_messages: The maximum number of messages to retrieve per + PubSub pull request + :type max_messages: int + :param ack_messages: If True, each message will be acknowledged + immediately rather than by any downstream tasks + :type ack_messages: bool + :param gcp_conn_id: The connection ID to use connecting to + Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param messages_callback: (Optional) Callback to process received messages. + It's return value will be saved to XCom. + If you are pulling large messages, you probably want to provide a custom callback. + If not provided, the default implementation will convert `ReceivedMessage` objects + into JSON-serializable dicts using `google.protobuf.json_format.MessageToDict` function. + :type messages_callback: Optional[Callable[[List[ReceivedMessage], Dict[str, Any]], Any]] + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "project_id", + "subscription", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + project_id: str, + subscription: str, + max_messages: int = 5, + ack_messages: bool = False, + messages_callback: Optional[ + Callable[[List[ReceivedMessage], Dict[str, Any]], Any] + ] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.project_id = project_id + self.subscription = subscription + self.max_messages = max_messages + self.ack_messages = ack_messages + self.messages_callback = messages_callback + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> list: + hook = PubSubHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + pulled_messages = hook.pull( + project_id=self.project_id, + subscription=self.subscription, + max_messages=self.max_messages, + return_immediately=True, + ) + + handle_messages = self.messages_callback or self._default_message_callback + + ret = handle_messages(pulled_messages, context) + + if pulled_messages and self.ack_messages: + hook.acknowledge( + project_id=self.project_id, + subscription=self.subscription, + messages=pulled_messages, + ) + + return ret + + def _default_message_callback( + self, + pulled_messages: List[ReceivedMessage], + context: Dict[str, Any], # pylint: disable=unused-argument + ) -> list: + """ + This method can be overridden by subclasses or by `messages_callback` constructor argument. + This default implementation converts `ReceivedMessage` objects into JSON-serializable dicts. + + :param pulled_messages: messages received from the topic. + :type pulled_messages: List[ReceivedMessage] + :param context: same as in `execute` + :return: value to be saved to XCom. + """ + messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages] + + return messages_json diff --git a/reference/providers/google/cloud/operators/spanner.py b/reference/providers/google/cloud/operators/spanner.py new file mode 100644 index 0000000..b0c0b79 --- /dev/null +++ b/reference/providers/google/cloud/operators/spanner.py @@ -0,0 +1,635 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
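+# ---------------------------------------------------------------------------
+# Illustrative usage sketch (editorial addition, not part of the upstream
+# module). It shows how the Spanner operators defined below might be chained
+# in a DAG, assuming the standard provider import path; the project, instance,
+# and database IDs and the DDL/DML statements are made-up placeholders.
+#
+#     from datetime import datetime
+#
+#     from airflow import DAG
+#     from airflow.providers.google.cloud.operators.spanner import (
+#         SpannerDeleteDatabaseInstanceOperator,
+#         SpannerDeployDatabaseInstanceOperator,
+#         SpannerDeployInstanceOperator,
+#         SpannerQueryDatabaseInstanceOperator,
+#     )
+#
+#     with DAG("example_spanner", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
+#         deploy_instance = SpannerDeployInstanceOperator(
+#             task_id="deploy_instance",
+#             project_id="example-project",  # placeholder project
+#             instance_id="example-instance",
+#             configuration_name="projects/example-project/instanceConfigs/regional-us-central1",
+#             node_count=1,
+#             display_name="Example instance",
+#         )
+#         deploy_database = SpannerDeployDatabaseInstanceOperator(
+#             task_id="deploy_database",
+#             instance_id="example-instance",
+#             database_id="example-db",
+#             ddl_statements=["CREATE TABLE users (id INT64 NOT NULL) PRIMARY KEY (id)"],
+#         )
+#         run_query = SpannerQueryDatabaseInstanceOperator(
+#             task_id="run_query",
+#             instance_id="example-instance",
+#             database_id="example-db",
+#             query="DELETE FROM users WHERE id = 0",
+#         )
+#         delete_database = SpannerDeleteDatabaseInstanceOperator(
+#             task_id="delete_database",
+#             instance_id="example-instance",
+#             database_id="example-db",
+#         )
+#         deploy_instance >> deploy_database >> run_query >> delete_database
+# ---------------------------------------------------------------------------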
+"""This module contains Google Spanner operators.""" +from typing import List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.spanner import SpannerHook +from airflow.utils.decorators import apply_defaults + + +class SpannerDeployInstanceOperator(BaseOperator): + """ + Creates a new Cloud Spanner instance, or if an instance with the same instance_id + exists in the specified project, updates the Cloud Spanner instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SpannerDeployInstanceOperator` + + :param instance_id: Cloud Spanner instance ID. + :type instance_id: str + :param configuration_name: The name of the Cloud Spanner instance configuration + defining how the instance will be created. Required for + instances that do not yet exist. + :type configuration_name: str + :param node_count: (Optional) The number of nodes allocated to the Cloud Spanner + instance. + :type node_count: int + :param display_name: (Optional) The display name for the Cloud Spanner instance in + the Google Cloud Console. (Must be between 4 and 30 characters.) If this value is not set + in the constructor, the name is the same as the instance ID. + :type display_name: str + :param project_id: Optional, the ID of the project which owns the Cloud Spanner + Database. If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_spanner_deploy_template_fields] + template_fields = ( + "project_id", + "instance_id", + "configuration_name", + "display_name", + "gcp_conn_id", + "impersonation_chain", + ) + # [END gcp_spanner_deploy_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + configuration_name: str, + node_count: int, + display_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.instance_id = instance_id + self.project_id = project_id + self.configuration_name = configuration_name + self.node_count = node_count + self.display_name = display_name + self.gcp_conn_id = gcp_conn_id + self._validate_inputs() + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if self.project_id == "": + raise AirflowException("The required parameter 'project_id' is empty") + if not self.instance_id: + raise AirflowException( + "The required parameter 'instance_id' is empty or None" + ) + + def execute(self, context) -> None: + hook = SpannerHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + if not hook.get_instance( + project_id=self.project_id, instance_id=self.instance_id + ): + self.log.info("Creating Cloud Spanner instance '%s'", self.instance_id) + func = hook.create_instance + else: + self.log.info("Updating Cloud Spanner instance '%s'", self.instance_id) + func = hook.update_instance + func( + project_id=self.project_id, + instance_id=self.instance_id, + configuration_name=self.configuration_name, + node_count=self.node_count, + display_name=self.display_name, + ) + + +class SpannerDeleteInstanceOperator(BaseOperator): + """ + Deletes a Cloud Spanner instance. If an instance does not exist, + no action is taken and the operator succeeds. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SpannerDeleteInstanceOperator` + + :param instance_id: The Cloud Spanner instance ID. + :type instance_id: str + :param project_id: Optional, the ID of the project that owns the Cloud Spanner + Database. If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_spanner_delete_template_fields] + template_fields = ( + "project_id", + "instance_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END gcp_spanner_delete_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.instance_id = instance_id + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self._validate_inputs() + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if self.project_id == "": + raise AirflowException("The required parameter 'project_id' is empty") + if not self.instance_id: + raise AirflowException( + "The required parameter 'instance_id' is empty or None" + ) + + def execute(self, context) -> Optional[bool]: + hook = SpannerHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + if hook.get_instance(project_id=self.project_id, instance_id=self.instance_id): + return hook.delete_instance( + project_id=self.project_id, instance_id=self.instance_id + ) + else: + self.log.info( + "Instance '%s' does not exist in project '%s'. Aborting delete.", + self.instance_id, + self.project_id, + ) + return True + + +class SpannerQueryDatabaseInstanceOperator(BaseOperator): + """ + Executes an arbitrary DML query (INSERT, UPDATE, DELETE). + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SpannerQueryDatabaseInstanceOperator` + + :param instance_id: The Cloud Spanner instance ID. + :type instance_id: str + :param database_id: The Cloud Spanner database ID. + :type database_id: str + :param query: The query or list of queries to be executed. Can be a path to a SQL + file. + :type query: str or list + :param project_id: Optional, the ID of the project that owns the Cloud Spanner + Database. If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_spanner_query_template_fields] + template_fields = ( + "project_id", + "instance_id", + "database_id", + "query", + "gcp_conn_id", + "impersonation_chain", + ) + template_ext = (".sql",) + # [END gcp_spanner_query_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + database_id: str, + query: Union[str, List[str]], + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.instance_id = instance_id + self.project_id = project_id + self.database_id = database_id + self.query = query + self.gcp_conn_id = gcp_conn_id + self._validate_inputs() + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if self.project_id == "": + raise AirflowException("The required parameter 'project_id' is empty") + if not self.instance_id: + raise AirflowException( + "The required parameter 'instance_id' is empty or None" + ) + if not self.database_id: + raise AirflowException( + "The required parameter 'database_id' is empty or None" + ) + if not self.query: + raise AirflowException("The required parameter 'query' is empty") + + def execute(self, context): + hook = SpannerHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + queries = self.query + if isinstance(self.query, str): + queries = [x.strip() for x in self.query.split(";")] + self.sanitize_queries(queries) + self.log.info( + "Executing DML query(-ies) on projects/%s/instances/%s/databases/%s", + self.project_id, + self.instance_id, + self.database_id, + ) + self.log.info(queries) + hook.execute_dml( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + queries=queries, + ) + + @staticmethod + def sanitize_queries(queries: List[str]) -> None: + """ + Drops empty query in queries. + + :param queries: queries + :type queries: List[str] + :rtype: None + """ + if queries and queries[-1] == "": + del queries[-1] + + +class SpannerDeployDatabaseInstanceOperator(BaseOperator): + """ + Creates a new Cloud Spanner database, or if database exists, + the operator does nothing. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SpannerDeployDatabaseInstanceOperator` + + :param instance_id: The Cloud Spanner instance ID. + :type instance_id: str + :param database_id: The Cloud Spanner database ID. + :type database_id: str + :param ddl_statements: The string list containing DDL for the new database. + :type ddl_statements: list[str] + :param project_id: Optional, the ID of the project that owns the Cloud Spanner + Database. If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_spanner_database_deploy_template_fields] + template_fields = ( + "project_id", + "instance_id", + "database_id", + "ddl_statements", + "gcp_conn_id", + "impersonation_chain", + ) + template_ext = (".sql",) + # [END gcp_spanner_database_deploy_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + database_id: str, + ddl_statements: List[str], + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.instance_id = instance_id + self.project_id = project_id + self.database_id = database_id + self.ddl_statements = ddl_statements + self.gcp_conn_id = gcp_conn_id + self._validate_inputs() + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if self.project_id == "": + raise AirflowException("The required parameter 'project_id' is empty") + if not self.instance_id: + raise AirflowException( + "The required parameter 'instance_id' is empty or None" + ) + if not self.database_id: + raise AirflowException( + "The required parameter 'database_id' is empty or None" + ) + + def execute(self, context) -> Optional[bool]: + hook = SpannerHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + if not hook.get_database( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + ): + self.log.info( + "Creating Cloud Spanner database '%s' in project '%s' and instance '%s'", + self.database_id, + self.project_id, + self.instance_id, + ) + return hook.create_database( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + ddl_statements=self.ddl_statements, + ) + else: + self.log.info( + "The database '%s' in project '%s' and instance '%s'" + " already exists. Nothing to do. Exiting.", + self.database_id, + self.project_id, + self.instance_id, + ) + return True + + +class SpannerUpdateDatabaseInstanceOperator(BaseOperator): + """ + Updates a Cloud Spanner database with the specified DDL statement. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SpannerUpdateDatabaseInstanceOperator` + + :param instance_id: The Cloud Spanner instance ID. + :type instance_id: str + :param database_id: The Cloud Spanner database ID. + :type database_id: str + :param ddl_statements: The string list containing DDL to apply to the database. + :type ddl_statements: list[str] + :param project_id: Optional, the ID of the project that owns the Cloud Spanner + Database. If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param operation_id: (Optional) Unique per database operation id that can + be specified to implement idempotency check. + :type operation_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_spanner_database_update_template_fields] + template_fields = ( + "project_id", + "instance_id", + "database_id", + "ddl_statements", + "gcp_conn_id", + "impersonation_chain", + ) + template_ext = (".sql",) + # [END gcp_spanner_database_update_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + database_id: str, + ddl_statements: List[str], + project_id: Optional[str] = None, + operation_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.instance_id = instance_id + self.project_id = project_id + self.database_id = database_id + self.ddl_statements = ddl_statements + self.operation_id = operation_id + self.gcp_conn_id = gcp_conn_id + self._validate_inputs() + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if self.project_id == "": + raise AirflowException("The required parameter 'project_id' is empty") + if not self.instance_id: + raise AirflowException( + "The required parameter 'instance_id' is empty or None" + ) + if not self.database_id: + raise AirflowException( + "The required parameter 'database_id' is empty or None" + ) + if not self.ddl_statements: + raise AirflowException( + "The required parameter 'ddl_statements' is empty or None" + ) + + def execute(self, context) -> None: + hook = SpannerHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + if not hook.get_database( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + ): + raise AirflowException( + "The Cloud Spanner database '{}' in project '{}' and " + "instance '{}' is missing. Create the database first " + "before you can update it.".format( + self.database_id, self.project_id, self.instance_id + ) + ) + else: + return hook.update_database( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + ddl_statements=self.ddl_statements, + operation_id=self.operation_id, + ) + + +class SpannerDeleteDatabaseInstanceOperator(BaseOperator): + """ + Deletes a Cloud Spanner database. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SpannerDeleteDatabaseInstanceOperator` + + :param instance_id: Cloud Spanner instance ID. + :type instance_id: str + :param database_id: Cloud Spanner database ID. + :type database_id: str + :param project_id: Optional, the ID of the project that owns the Cloud Spanner + Database. If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_spanner_database_delete_template_fields] + template_fields = ( + "project_id", + "instance_id", + "database_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END gcp_spanner_database_delete_template_fields] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + database_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.instance_id = instance_id + self.project_id = project_id + self.database_id = database_id + self.gcp_conn_id = gcp_conn_id + self._validate_inputs() + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if self.project_id == "": + raise AirflowException("The required parameter 'project_id' is empty") + if not self.instance_id: + raise AirflowException( + "The required parameter 'instance_id' is empty or None" + ) + if not self.database_id: + raise AirflowException( + "The required parameter 'database_id' is empty or None" + ) + + def execute(self, context) -> bool: + hook = SpannerHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + database = hook.get_database( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + ) + if not database: + self.log.info( + "The Cloud Spanner database was missing: " + "'%s' in project '%s' and instance '%s'. Assuming success.", + self.database_id, + self.project_id, + self.instance_id, + ) + return True + else: + return hook.delete_database( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + ) diff --git a/reference/providers/google/cloud/operators/speech_to_text.py b/reference/providers/google/cloud/operators/speech_to_text.py new file mode 100644 index 0000000..4a34b46 --- /dev/null +++ b/reference/providers/google/cloud/operators/speech_to_text.py @@ -0,0 +1,119 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
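+#
+# Illustrative usage sketch: the commented snippet below shows one way the
+# CloudSpeechToTextRecognizeSpeechOperator defined in this module might be
+# wired into a DAG. The DAG id, task id, bucket URI and recognition settings
+# are placeholder assumptions, not values used elsewhere in this repository.
+#
+#   from airflow import DAG
+#   from airflow.utils.dates import days_ago
+#
+#   with DAG("example_speech_to_text", start_date=days_ago(1), schedule_interval=None) as dag:
+#       recognize_speech = CloudSpeechToTextRecognizeSpeechOperator(
+#           task_id="recognize_speech",
+#           config={"encoding": "LINEAR16", "language_code": "en-US"},
+#           audio={"uri": "gs://example-bucket/sample.raw"},
+#       )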
+"""This module contains a Google Speech to Text operator.""" +from typing import Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.speech_to_text import ( + CloudSpeechToTextHook, + RecognitionAudio, +) +from airflow.utils.decorators import apply_defaults +from google.api_core.retry import Retry +from google.cloud.speech_v1.types import RecognitionConfig +from google.protobuf.json_format import MessageToDict + + +class CloudSpeechToTextRecognizeSpeechOperator(BaseOperator): + """ + Recognizes speech from audio file and returns it as text. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudSpeechToTextRecognizeSpeechOperator` + + :param config: information to the recognizer that specifies how to process the request. See more: + https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/types.html#google.cloud.speech_v1.types.RecognitionConfig + :type config: dict or google.cloud.speech_v1.types.RecognitionConfig + :param audio: audio data to be recognized. See more: + https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/types.html#google.cloud.speech_v1.types.RecognitionAudio + :type audio: dict or google.cloud.speech_v1.types.RecognitionAudio + :param project_id: Optional, Google Cloud Project ID where the Compute + Engine Instance exists. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param retry: (Optional) A retry object used to retry requests. If None is specified, + requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to complete. + Note that if retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_speech_to_text_synthesize_template_fields] + template_fields = ( + "audio", + "config", + "project_id", + "gcp_conn_id", + "timeout", + "impersonation_chain", + ) + # [END gcp_speech_to_text_synthesize_template_fields] + + @apply_defaults + def __init__( + self, + *, + audio: RecognitionAudio, + config: RecognitionConfig, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.audio = audio + self.config = config + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.retry = retry + self.timeout = timeout + self._validate_inputs() + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + if self.audio == "": + raise AirflowException("The required parameter 'audio' is empty") + if self.config == "": + raise AirflowException("The required parameter 'config' is empty") + + def execute(self, context): + hook = CloudSpeechToTextHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + response = hook.recognize_speech( + config=self.config, audio=self.audio, retry=self.retry, timeout=self.timeout + ) + return MessageToDict(response) diff --git a/reference/providers/google/cloud/operators/stackdriver.py b/reference/providers/google/cloud/operators/stackdriver.py new file mode 100644 index 0000000..22e99f2 --- /dev/null +++ b/reference/providers/google/cloud/operators/stackdriver.py @@ -0,0 +1,1035 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.stackdriver import StackdriverHook +from airflow.utils.decorators import apply_defaults +from google.api_core.gapic_v1.method import DEFAULT +from google.cloud.monitoring_v3 import AlertPolicy, NotificationChannel + + +class StackdriverListAlertPoliciesOperator(BaseOperator): + """ + Fetches all the Alert Policies identified by the filter passed as + filter parameter. The desired return type can be specified by the + format parameter, the supported formats are "dict", "json" and None + which returns python dictionary, stringified JSON and protobuf + respectively. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:StackdriverListAlertPoliciesOperator` + + :param format_: (Optional) Desired output format of the result. The + supported formats are "dict", "json" and None which returns + python dictionary, stringified JSON and protobuf respectively. 
+ :type format_: str + :param filter_: If provided, this field specifies the criteria that must be met by alert + policies to be included in the response. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param order_by: A comma-separated list of fields by which to sort the result. + Supports the same set of field references as the ``filter`` field. Entries + can be prefixed with a minus sign to sort by the field in descending order. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type order_by: str + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per- + resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number + of resources in a page. + :type page_size: int + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google + Cloud Platform. + :type gcp_conn_id: str + :param project_id: The project to fetch alerts from. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "filter_", + "impersonation_chain", + ) + ui_color = "#e5ffcc" + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + format_: Optional[str] = None, + filter_: Optional[str] = None, + order_by: Optional[str] = None, + page_size: Optional[int] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.format_ = format_ + self.filter_ = filter_ + self.order_by = order_by + self.page_size = page_size + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook = None + + def execute(self, context): + self.log.info( + "List Alert Policies: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %s", + self.project_id, + self.format_, + self.filter_, + self.order_by, + self.page_size, + ) + if self.hook is None: + self.hook = StackdriverHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + result = self.hook.list_alert_policies( + project_id=self.project_id, + format_=self.format_, + filter_=self.filter_, + order_by=self.order_by, + page_size=self.page_size, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return [AlertPolicy.to_dict(policy) for policy in result] + + +class StackdriverEnableAlertPoliciesOperator(BaseOperator): + """ + Enables one or more disabled alerting policies identified by filter + parameter. Inoperative in case the policy is already enabled. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:StackdriverEnableAlertPoliciesOperator` + + :param filter_: If provided, this field specifies the criteria that + must be met by alert policies to be enabled. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google + Cloud Platform. + :type gcp_conn_id: str + :param project_id: The project in which alert needs to be enabled. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + ui_color = "#e5ffcc" + template_fields = ( + "filter_", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + filter_: Optional[str] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.filter_ = filter_ + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.impersonation_chain = impersonation_chain + self.hook = None + + def execute(self, context): + self.log.info( + "Enable Alert Policies: Project id: %s Filter: %s", + self.project_id, + self.filter_, + ) + if self.hook is None: + self.hook = StackdriverHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + self.hook.enable_alert_policies( + filter_=self.filter_, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +# Disable Alert Operator +class StackdriverDisableAlertPoliciesOperator(BaseOperator): + """ + Disables one or more enabled alerting policies identified by filter + parameter. Inoperative in case the policy is already disabled. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:StackdriverDisableAlertPoliciesOperator` + + :param filter_: If provided, this field specifies the criteria that + must be met by alert policies to be disabled. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google + Cloud Platform. + :type gcp_conn_id: str + :param project_id: The project in which alert needs to be disabled. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + ui_color = "#e5ffcc" + template_fields = ( + "filter_", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + filter_: Optional[str] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.filter_ = filter_ + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.impersonation_chain = impersonation_chain + self.hook = None + + def execute(self, context): + self.log.info( + "Disable Alert Policies: Project id: %s Filter: %s", + self.project_id, + self.filter_, + ) + if self.hook is None: + self.hook = StackdriverHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + self.hook.disable_alert_policies( + filter_=self.filter_, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class StackdriverUpsertAlertOperator(BaseOperator): + """ + Creates a new alert or updates an existing policy identified + the name field in the alerts parameter. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:StackdriverUpsertAlertOperator` + + :param alerts: A JSON string or file that specifies all the alerts that needs + to be either created or updated. For more details, see + https://cloud.google.com/monitoring/api/ref_v3/rest/v3/projects.alertPolicies#AlertPolicy. + (templated) + :type alerts: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google + Cloud Platform. + :type gcp_conn_id: str + :param project_id: The project in which alert needs to be created/updated. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "alerts", + "impersonation_chain", + ) + template_ext = (".json",) + + ui_color = "#e5ffcc" + + @apply_defaults + def __init__( + self, + *, + alerts: str, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.alerts = alerts + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook = None + + def execute(self, context): + self.log.info( + "Upsert Alert Policies: Alerts: %s Project id: %s", + self.alerts, + self.project_id, + ) + if self.hook is None: + self.hook = StackdriverHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + self.hook.upsert_alert( + alerts=self.alerts, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class StackdriverDeleteAlertOperator(BaseOperator): + """ + Deletes an alerting policy. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:StackdriverDeleteAlertOperator` + + :param name: The alerting policy to delete. The format is: + ``projects/[PROJECT_ID]/alertPolicies/[ALERT_POLICY_ID]``. + :type name: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google + Cloud Platform. + :type gcp_conn_id: str + :param project_id: The project from which alert needs to be deleted. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "name", + "impersonation_chain", + ) + + ui_color = "#e5ffcc" + + @apply_defaults + def __init__( + self, + *, + name: str, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.name = name + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook = None + + def execute(self, context): + self.log.info( + "Delete Alert Policy: Project id: %s Name: %s", self.project_id, self.name + ) + if self.hook is None: + self.hook = StackdriverHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + self.hook.delete_alert_policy( + name=self.name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class StackdriverListNotificationChannelsOperator(BaseOperator): + """ + Fetches all the Notification Channels identified by the filter passed as + filter parameter. The desired return type can be specified by the + format parameter, the supported formats are "dict", "json" and None + which returns python dictionary, stringified JSON and protobuf + respectively. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:StackdriverListNotificationChannelsOperator` + + :param format_: (Optional) Desired output format of the result. The + supported formats are "dict", "json" and None which returns + python dictionary, stringified JSON and protobuf respectively. + :type format_: str + :param filter_: If provided, this field specifies the criteria that + must be met by notification channels to be included in the response. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param order_by: A comma-separated list of fields by which to sort the result. + Supports the same set of field references as the ``filter`` field. Entries + can be prefixed with a minus sign to sort by the field in descending order. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type order_by: str + :param page_size: The maximum number of resources contained in the + underlying API response. If page streaming is performed per- + resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number + of resources in a page. + :type page_size: int + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google + Cloud Platform. + :type gcp_conn_id: str + :param project_id: The project to fetch notification channels from. 
+ :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "filter_", + "impersonation_chain", + ) + + ui_color = "#e5ffcc" + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + format_: Optional[str] = None, + filter_: Optional[str] = None, + order_by: Optional[str] = None, + page_size: Optional[int] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.format_ = format_ + self.filter_ = filter_ + self.order_by = order_by + self.page_size = page_size + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook = None + + def execute(self, context): + self.log.info( + "List Notification Channels: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %s", + self.project_id, + self.format_, + self.filter_, + self.order_by, + self.page_size, + ) + if self.hook is None: + self.hook = StackdriverHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + channels = self.hook.list_notification_channels( + format_=self.format_, + project_id=self.project_id, + filter_=self.filter_, + order_by=self.order_by, + page_size=self.page_size, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = [NotificationChannel.to_dict(channel) for channel in channels] + return result + + +class StackdriverEnableNotificationChannelsOperator(BaseOperator): + """ + Enables one or more disabled alerting policies identified by filter + parameter. Inoperative in case the policy is already enabled. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:StackdriverEnableNotificationChannelsOperator` + + :param filter_: If provided, this field specifies the criteria that + must be met by notification channels to be enabled. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. 
Note that if ``retry`` is + specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google + Cloud Platform. + :type gcp_conn_id: str + :param project_id: The location used for the operation. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "filter_", + "impersonation_chain", + ) + + ui_color = "#e5ffcc" + + @apply_defaults + def __init__( + self, + *, + filter_: Optional[str] = None, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.filter_ = filter_ + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook = None + + def execute(self, context): + self.log.info( + "Enable Notification Channels: Project id: %s Filter: %s", + self.project_id, + self.filter_, + ) + if self.hook is None: + self.hook = StackdriverHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + self.hook.enable_notification_channels( + filter_=self.filter_, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class StackdriverDisableNotificationChannelsOperator(BaseOperator): + """ + Disables one or more enabled notification channels identified by filter + parameter. Inoperative in case the policy is already disabled. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:StackdriverDisableNotificationChannelsOperator` + + :param filter_: If provided, this field specifies the criteria that + must be met by alert policies to be disabled. + For more details, see https://cloud.google.com/monitoring/api/v3/sorting-and-filtering. + :type filter_: str + :param retry: A retry object used to retry requests. If ``None`` is + specified, requests will be retried using a default configuration. + :type retry: str + :param timeout: The amount of time, in seconds, to wait + for the request to complete. Note that if ``retry`` is + specified, the timeout applies to each individual attempt. 
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google
+        Cloud Platform.
+    :type gcp_conn_id: str
+    :param project_id: The project in which notification channels need to be disabled.
+    :type project_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        "filter_",
+        "impersonation_chain",
+    )
+
+    ui_color = "#e5ffcc"
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        filter_: Optional[str] = None,
+        retry: Optional[str] = DEFAULT,
+        timeout: Optional[float] = DEFAULT,
+        metadata: Optional[str] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        project_id: Optional[str] = None,
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.filter_ = filter_
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.project_id = project_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        self.hook = None
+
+    def execute(self, context):
+        self.log.info(
+            "Disable Notification Channels: Project id: %s Filter: %s",
+            self.project_id,
+            self.filter_,
+        )
+        if self.hook is None:
+            self.hook = StackdriverHook(
+                gcp_conn_id=self.gcp_conn_id,
+                delegate_to=self.delegate_to,
+                impersonation_chain=self.impersonation_chain,
+            )
+        self.hook.disable_notification_channels(
+            filter_=self.filter_,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+
+class StackdriverUpsertNotificationChannelOperator(BaseOperator):
+    """
+    Creates a new notification channel or updates an existing notification channel
+    identified by the name field in the channels parameter.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:StackdriverUpsertNotificationChannelOperator`
+
+    :param channels: A JSON string or file that specifies all the notification channels that need
+        to be either created or updated. For more details, see
+        https://cloud.google.com/monitoring/api/ref_v3/rest/v3/projects.notificationChannels.
+        (templated)
+    :type channels: str
+    :param retry: A retry object used to retry requests. If ``None`` is
+        specified, requests will be retried using a default configuration.
+    :type retry: str
+    :param timeout: The amount of time, in seconds, to wait
+        for the request to complete. Note that if ``retry`` is
+        specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google
+        Cloud Platform.
+    :type gcp_conn_id: str
+    :param project_id: The project in which notification channels need to be created/updated.
+    :type project_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        "channels",
+        "impersonation_chain",
+    )
+    template_ext = (".json",)
+
+    ui_color = "#e5ffcc"
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        channels: str,
+        retry: Optional[str] = DEFAULT,
+        timeout: Optional[float] = DEFAULT,
+        metadata: Optional[str] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        project_id: Optional[str] = None,
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.channels = channels
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.project_id = project_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        self.hook = None
+
+    def execute(self, context):
+        self.log.info(
+            "Upsert Notification Channels: Channels: %s Project id: %s",
+            self.channels,
+            self.project_id,
+        )
+        if self.hook is None:
+            self.hook = StackdriverHook(
+                gcp_conn_id=self.gcp_conn_id,
+                delegate_to=self.delegate_to,
+                impersonation_chain=self.impersonation_chain,
+            )
+        self.hook.upsert_channel(
+            channels=self.channels,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+
+class StackdriverDeleteNotificationChannelOperator(BaseOperator):
+    """
+    Deletes a notification channel.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:StackdriverDeleteNotificationChannelOperator`
+
+    :param name: The notification channel to delete. The format is:
+        ``projects/[PROJECT_ID]/notificationChannels/[CHANNEL_ID]``.
+    :type name: str
+    :param retry: A retry object used to retry requests. If ``None`` is
+        specified, requests will be retried using a default configuration.
+    :type retry: str
+    :param timeout: The amount of time, in seconds, to wait
+        for the request to complete. Note that if ``retry`` is
+        specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google
+        Cloud Platform.
+ :type gcp_conn_id: str + :param project_id: The project from which notification channel needs to be deleted. + :type project_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "name", + "impersonation_chain", + ) + + ui_color = "#e5ffcc" + + @apply_defaults + def __init__( + self, + *, + name: str, + retry: Optional[str] = DEFAULT, + timeout: Optional[float] = DEFAULT, + metadata: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + project_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.name = name + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.project_id = project_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook = None + + def execute(self, context): + self.log.info( + "Delete Notification Channel: Project id: %s Name: %s", + self.project_id, + self.name, + ) + if self.hook is None: + self.hook = StackdriverHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + self.hook.delete_notification_channel( + name=self.name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) diff --git a/reference/providers/google/cloud/operators/tasks.py b/reference/providers/google/cloud/operators/tasks.py new file mode 100644 index 0000000..a812599 --- /dev/null +++ b/reference/providers/google/cloud/operators/tasks.py @@ -0,0 +1,1212 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This module contains various Google Cloud Tasks operators +which allow you to perform basic operations using +Cloud Tasks queues/tasks. 
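+
+As an illustrative sketch only, a queue-creation task built from the operators
+below might look like this (the task id, location, queue name and sampling
+ratio are placeholder assumptions, not values used elsewhere in this
+repository)::
+
+    from google.cloud.tasks_v2.types import Queue
+
+    create_queue = CloudTasksQueueCreateOperator(
+        task_id="create_queue",
+        location="europe-west1",
+        queue_name="example-queue",
+        task_queue=Queue(stackdriver_logging_config=dict(sampling_ratio=0.5)),
+    )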
+""" +from typing import Dict, Optional, Sequence, Tuple, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook +from airflow.utils.decorators import apply_defaults +from google.api_core.exceptions import AlreadyExists +from google.api_core.retry import Retry +from google.cloud.tasks_v2.types import Queue, Task +from google.protobuf.field_mask_pb2 import FieldMask + +MetaData = Sequence[Tuple[str, str]] + + +class CloudTasksQueueCreateOperator(BaseOperator): + """ + Creates a queue in Cloud Tasks. + + :param location: The location name in which the queue will be created. + :type location: str + :param task_queue: The task queue to create. + Queue's name cannot be the same as an existing queue. + If a dict is provided, it must be of the same form as the protobuf message Queue. + :type task_queue: dict or google.cloud.tasks_v2.types.Queue + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param queue_name: (Optional) The queue's name. + If provided, it will be used to construct the full queue path. + :type queue_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    :rtype: google.cloud.tasks_v2.types.Queue
+    """
+
+    template_fields = (
+        "task_queue",
+        "project_id",
+        "location",
+        "queue_name",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        location: str,
+        task_queue: Queue,
+        project_id: Optional[str] = None,
+        queue_name: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[MetaData] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.location = location
+        self.task_queue = task_queue
+        self.project_id = project_id
+        self.queue_name = queue_name
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = CloudTasksHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        try:
+            queue = hook.create_queue(
+                location=self.location,
+                task_queue=self.task_queue,
+                project_id=self.project_id,
+                queue_name=self.queue_name,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            )
+        except AlreadyExists:
+            queue = hook.get_queue(
+                location=self.location,
+                project_id=self.project_id,
+                queue_name=self.queue_name,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            )
+
+        return Queue.to_dict(queue)
+
+
+class CloudTasksQueueUpdateOperator(BaseOperator):
+    """
+    Updates a queue in Cloud Tasks.
+
+    :param task_queue: The task queue to update.
+        This method creates the queue if it does not exist and updates the queue if
+        it does exist. The queue's name must be specified.
+    :type task_queue: dict or google.cloud.tasks_v2.types.Queue
+    :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks.
+        If set to None or missing, the default project_id from the Google Cloud connection is used.
+    :type project_id: str
+    :param location: (Optional) The location name in which the queue will be updated.
+        If provided, it will be used to construct the full queue path.
+    :type location: str
+    :param queue_name: (Optional) The queue's name.
+        If provided, it will be used to construct the full queue path.
+    :type queue_name: str
+    :param update_mask: A mask used to specify which fields of the queue are being updated.
+        If empty, then all fields will be updated.
+        If a dict is provided, it must be of the same form as the protobuf message.
+    :type update_mask: dict or google.protobuf.field_mask_pb2.FieldMask
+    :param retry: (Optional) A retry object used to retry requests.
+        If None is specified, requests will not be retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: (Optional) The amount of time, in seconds, to wait for the request
+        to complete. Note that if retry is specified, the timeout applies to each
+        individual attempt.
+    :type timeout: float
+    :param metadata: (Optional) Additional metadata that is provided to the method.
+    :type metadata: sequence[tuple[str, str]]
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.tasks_v2.types.Queue + """ + + template_fields = ( + "task_queue", + "project_id", + "location", + "queue_name", + "update_mask", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + task_queue: Queue, + project_id: Optional[str] = None, + location: Optional[str] = None, + queue_name: Optional[str] = None, + update_mask: Union[Dict, FieldMask] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.task_queue = task_queue + self.project_id = project_id + self.location = location + self.queue_name = queue_name + self.update_mask = update_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + queue = hook.update_queue( + task_queue=self.task_queue, + project_id=self.project_id, + location=self.location, + queue_name=self.queue_name, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Queue.to_dict(queue) + + +class CloudTasksQueueGetOperator(BaseOperator): + """ + Gets a queue from Cloud Tasks. + + :param location: The location name in which the queue was created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.tasks_v2.types.Queue + """ + + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + queue_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + queue = hook.get_queue( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Queue.to_dict(queue) + + +class CloudTasksQueuesListOperator(BaseOperator): + """ + Lists queues from Cloud Tasks. + + :param location: The location name in which the queues were created. + :type location: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param results_filter: (Optional) Filter used to specify a subset of queues. + :type results_filter: str + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + + template_fields = ( + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + project_id: Optional[str] = None, + results_filter: Optional[str] = None, + page_size: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.project_id = project_id + self.results_filter = results_filter + self.page_size = page_size + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + queues = hook.list_queues( + location=self.location, + project_id=self.project_id, + results_filter=self.results_filter, + page_size=self.page_size, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return [Queue.to_dict(q) for q in queues] + + +class CloudTasksQueueDeleteOperator(BaseOperator): + """ + Deletes a queue from Cloud Tasks, even if it has tasks in it. + + :param location: The location name in which the queue will be deleted. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + queue_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.delete_queue( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksQueuePurgeOperator(BaseOperator): + """ + Purges a queue by deleting all of its tasks from Cloud Tasks. + + :param location: The location name in which the queue will be purged. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + queue_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + queue = hook.purge_queue( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Queue.to_dict(queue) + + +class CloudTasksQueuePauseOperator(BaseOperator): + """ + Pauses a queue in Cloud Tasks. + + :param location: The location name in which the queue will be paused. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + queue_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + queue = hook.pause_queue( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Queue.to_dict(queue) + + +class CloudTasksQueueResumeOperator(BaseOperator): + """ + Resumes a queue in Cloud Tasks. + + :param location: The location name in which the queue will be resumed. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: list[google.cloud.tasks_v2.types.Queue] + """ + + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + queue_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + queue = hook.resume_queue( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Queue.to_dict(queue) + + +class CloudTasksTaskCreateOperator(BaseOperator): + """ + Creates a task in Cloud Tasks. + + :param location: The location name in which the task will be created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task: The task to add. + If a dict is provided, it must be of the same form as the protobuf message Task. + :type task: dict or google.cloud.tasks_v2.types.Task + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param task_name: (Optional) The task's name. + If provided, it will be used to construct the full task path. + :type task_name: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.enums.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.tasks_v2.types.Task + """ + + template_fields = ( + "task", + "project_id", + "location", + "queue_name", + "task_name", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + location: str, + queue_name: str, + task: Union[Dict, Task], + project_id: Optional[str] = None, + task_name: Optional[str] = None, + response_view: Optional = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.queue_name = queue_name + self.task = task + self.project_id = project_id + self.task_name = task_name + self.response_view = response_view + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + task = hook.create_task( + location=self.location, + queue_name=self.queue_name, + task=self.task, + project_id=self.project_id, + task_name=self.task_name, + response_view=self.response_view, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Task.to_dict(task) + + +class CloudTasksTaskGetOperator(BaseOperator): + """ + Gets a task from Cloud Tasks. + + :param location: The location name in which the task was created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task_name: The task's name. + :type task_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.enums.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.tasks_v2.types.Task + """ + + template_fields = ( + "location", + "queue_name", + "task_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + queue_name: str, + task_name: str, + project_id: Optional[str] = None, + response_view: Optional = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.queue_name = queue_name + self.task_name = task_name + self.project_id = project_id + self.response_view = response_view + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + task = hook.get_task( + location=self.location, + queue_name=self.queue_name, + task_name=self.task_name, + project_id=self.project_id, + response_view=self.response_view, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Task.to_dict(task) + + +class CloudTasksTasksListOperator(BaseOperator): + """ + Lists the tasks in Cloud Tasks. + + :param location: The location name in which the tasks were created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.enums.Task.View + :param page_size: (Optional) The maximum number of resources contained in the + underlying API response. + :type page_size: int + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: list[google.cloud.tasks_v2.types.Task] + """ + + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + queue_name: str, + project_id: Optional[str] = None, + response_view: Optional = None, + page_size: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.queue_name = queue_name + self.project_id = project_id + self.response_view = response_view + self.page_size = page_size + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + tasks = hook.list_tasks( + location=self.location, + queue_name=self.queue_name, + project_id=self.project_id, + response_view=self.response_view, + page_size=self.page_size, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return [Task.to_dict(t) for t in tasks] + + +class CloudTasksTaskDeleteOperator(BaseOperator): + """ + Deletes a task from Cloud Tasks. + + :param location: The location name in which the task will be deleted. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task_name: The task's name. + :type task_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "location", + "queue_name", + "task_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + queue_name: str, + task_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.queue_name = queue_name + self.task_name = task_name + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.delete_task( + location=self.location, + queue_name=self.queue_name, + task_name=self.task_name, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudTasksTaskRunOperator(BaseOperator): + """ + Forces to run a task in Cloud Tasks. + + :param location: The location name in which the task was created. + :type location: str + :param queue_name: The queue's name. + :type queue_name: str + :param task_name: The task's name. + :type task_name: str + :param project_id: (Optional) The ID of the Google Cloud project that owns the Cloud Tasks. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param response_view: (Optional) This field specifies which subset of the Task will + be returned. + :type response_view: google.cloud.tasks_v2.Task.View + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + :rtype: google.cloud.tasks_v2.types.Task + """ + + template_fields = ( + "location", + "queue_name", + "task_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + location: str, + queue_name: str, + task_name: str, + project_id: Optional[str] = None, + response_view: Optional = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.queue_name = queue_name + self.task_name = task_name + self.project_id = project_id + self.response_view = response_view + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + task = hook.run_task( + location=self.location, + queue_name=self.queue_name, + task_name=self.task_name, + project_id=self.project_id, + response_view=self.response_view, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Task.to_dict(task) diff --git a/reference/providers/google/cloud/operators/text_to_speech.py b/reference/providers/google/cloud/operators/text_to_speech.py new file mode 100644 index 0000000..8c81f7d --- /dev/null +++ b/reference/providers/google/cloud/operators/text_to_speech.py @@ -0,0 +1,156 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Text to Speech operator.""" + +from tempfile import NamedTemporaryFile +from typing import Dict, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.cloud.hooks.text_to_speech import CloudTextToSpeechHook +from airflow.utils.decorators import apply_defaults +from google.api_core.retry import Retry +from google.cloud.texttospeech_v1.types import ( + AudioConfig, + SynthesisInput, + VoiceSelectionParams, +) + + +class CloudTextToSpeechSynthesizeOperator(BaseOperator): + """ + Synthesizes text to speech and stores it in Google Cloud Storage + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudTextToSpeechSynthesizeOperator` + + :param input_data: text input to be synthesized. 
See more: + https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.SynthesisInput + :type input_data: dict or google.cloud.texttospeech_v1.types.SynthesisInput + :param voice: configuration of voice to be used in synthesis. See more: + https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.VoiceSelectionParams + :type voice: dict or google.cloud.texttospeech_v1.types.VoiceSelectionParams + :param audio_config: configuration of the synthesized audio. See more: + https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.AudioConfig + :type audio_config: dict or google.cloud.texttospeech_v1.types.AudioConfig + :param target_bucket_name: name of the GCS bucket in which output file should be stored + :type target_bucket_name: str + :param target_filename: filename of the output file. + :type target_filename: str + :param project_id: Optional, Google Cloud Project ID where the Compute + Engine Instance exists. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + :param retry: (Optional) A retry object used to retry requests. If None is specified, + requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to complete. + Note that if retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_text_to_speech_synthesize_template_fields] + template_fields = ( + "input_data", + "voice", + "audio_config", + "project_id", + "gcp_conn_id", + "target_bucket_name", + "target_filename", + "impersonation_chain", + ) + # [END gcp_text_to_speech_synthesize_template_fields] + + @apply_defaults + def __init__( + self, + *, + input_data: Union[Dict, SynthesisInput], + voice: Union[Dict, VoiceSelectionParams], + audio_config: Union[Dict, AudioConfig], + target_bucket_name: str, + target_filename: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.input_data = input_data + self.voice = voice + self.audio_config = audio_config + self.target_bucket_name = target_bucket_name + self.target_filename = target_filename + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.retry = retry + self.timeout = timeout + self._validate_inputs() + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def _validate_inputs(self) -> None: + for parameter in [ + "input_data", + "voice", + "audio_config", + "target_bucket_name", + "target_filename", + ]: + if getattr(self, parameter) == "": + raise AirflowException(f"The required parameter '{parameter}' is empty") + + def execute(self, context) -> None: + hook = CloudTextToSpeechHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + result = hook.synthesize_speech( + input_data=self.input_data, + voice=self.voice, + audio_config=self.audio_config, + retry=self.retry, + timeout=self.timeout, + ) + with NamedTemporaryFile() as temp_file: + temp_file.write(result.audio_content) + cloud_storage_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + cloud_storage_hook.upload( + bucket_name=self.target_bucket_name, + object_name=self.target_filename, + filename=temp_file.name, + ) diff --git a/reference/providers/google/cloud/operators/translate.py b/reference/providers/google/cloud/operators/translate.py new file mode 100644 index 0000000..30184a2 --- /dev/null +++ b/reference/providers/google/cloud/operators/translate.py @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Google Translate operators.""" +from typing import List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook +from airflow.utils.decorators import apply_defaults + + +class CloudTranslateTextOperator(BaseOperator): + """ + Translate a string or list of strings. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudTranslateTextOperator` + + See https://cloud.google.com/translate/docs/translating-text + + Execute method returns str or list. + + This is a list of dictionaries for each queried value. Each + dictionary typically contains three keys (though not + all will be present in all cases). + + * ``detectedSourceLanguage``: The detected language (as an + ISO 639-1 language code) of the text. + * ``translatedText``: The translation of the text into the + target language. + * ``input``: The corresponding input value. + * ``model``: The model used to translate the text. + + If only a single value is passed, then only a single + dictionary is set as XCom return value. + + :type values: str or list + :param values: String or list of strings to translate. + + :type target_language: str + :param target_language: The language to translate results into. This + is required by the API and defaults to + the target language of the current instance. + + :type format_: str or None + :param format_: (Optional) One of ``text`` or ``html``, to specify + if the input text is plain text or HTML. + + :type source_language: str or None + :param source_language: (Optional) The language of the text to + be translated. + + :type model: str or None + :param model: (Optional) The model used to translate the text, such + as ``'base'`` or ``'nmt'``. + + :type impersonation_chain: Union[str, Sequence[str]] + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ + """ + + # [START translate_template_fields] + template_fields = ( + "values", + "target_language", + "format_", + "source_language", + "model", + "gcp_conn_id", + "impersonation_chain", + ) + # [END translate_template_fields] + + @apply_defaults + def __init__( + self, + *, + values: Union[List[str], str], + target_language: str, + format_: str, + source_language: Optional[str], + model: str, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.values = values + self.target_language = target_language + self.format_ = format_ + self.source_language = source_language + self.model = model + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> dict: + hook = CloudTranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + translation = hook.translate( + values=self.values, + target_language=self.target_language, + format_=self.format_, + source_language=self.source_language, + model=self.model, + ) + self.log.debug("Translation %s", translation) + return translation + except ValueError as e: + self.log.error("An error has been thrown from translate method:") + self.log.error(e) + raise AirflowException(e) diff --git a/reference/providers/google/cloud/operators/translate_speech.py b/reference/providers/google/cloud/operators/translate_speech.py new file mode 100644 index 0000000..04908f8 --- /dev/null +++ b/reference/providers/google/cloud/operators/translate_speech.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Translate Speech operator.""" +from typing import Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook +from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook +from airflow.utils.decorators import apply_defaults +from google.cloud.speech_v1.types import RecognitionAudio, RecognitionConfig +from google.protobuf.json_format import MessageToDict + + +class CloudTranslateSpeechOperator(BaseOperator): + """ + Recognizes speech in audio input and translates it. + + Note that it uses the first result from the recognition api response - the one with the highest confidence + In order to see other possible results please use + :ref:`howto/operator:CloudSpeechToTextRecognizeSpeechOperator` + and + :ref:`howto/operator:CloudTranslateTextOperator` + separately + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudTranslateSpeechOperator` + + See https://cloud.google.com/translate/docs/translating-text + + Execute method returns string object with the translation + + This is a list of dictionaries queried value. + Dictionary typically contains three keys (though not + all will be present in all cases). + + * ``detectedSourceLanguage``: The detected language (as an + ISO 639-1 language code) of the text. + * ``translatedText``: The translation of the text into the + target language. + * ``input``: The corresponding input value. + * ``model``: The model used to translate the text. + + Dictionary is set as XCom return value. + + :param audio: audio data to be recognized. See more: + https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/types.html#google.cloud.speech_v1.types.RecognitionAudio + :type audio: dict or google.cloud.speech_v1.types.RecognitionAudio + + :param config: information to the recognizer that specifies how to process the request. See more: + https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/types.html#google.cloud.speech_v1.types.RecognitionConfig + :type config: dict or google.cloud.speech_v1.types.RecognitionConfig + + :param target_language: The language to translate results into. This is required by the API and defaults + to the target language of the current instance. + Check the list of available languages here: https://cloud.google.com/translate/docs/languages + :type target_language: str + + :param format_: (Optional) One of ``text`` or ``html``, to specify + if the input text is plain text or HTML. + :type format_: str or None + + :param source_language: (Optional) The language of the text to + be translated. + :type source_language: str or None + + :param model: (Optional) The model used to translate the text, such + as ``'base'`` or ``'nmt'``. + :type model: str or None + + :param project_id: Optional, Google Cloud Project ID where the Compute + Engine Instance exists. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to 'google_cloud_default'. + :type gcp_conn_id: str + + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + + """ + + # [START translate_speech_template_fields] + template_fields = ( + "target_language", + "format_", + "source_language", + "model", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END translate_speech_template_fields] + + @apply_defaults + def __init__( + self, + *, + audio: RecognitionAudio, + config: RecognitionConfig, + target_language: str, + format_: str, + source_language: Optional[str], + model: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.audio = audio + self.config = config + self.target_language = target_language + self.format_ = format_ + self.source_language = source_language + self.model = model + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> dict: + speech_to_text_hook = CloudSpeechToTextHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + translate_hook = CloudTranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + recognize_result = speech_to_text_hook.recognize_speech( + config=self.config, audio=self.audio + ) + recognize_dict = MessageToDict(recognize_result) + + self.log.info("Recognition operation finished") + + if not recognize_dict["results"]: + self.log.info("No recognition results") + return {} + self.log.debug("Recognition result: %s", recognize_dict) + + try: + transcript = recognize_dict["results"][0]["alternatives"][0]["transcript"] + except KeyError as key: + raise AirflowException( + f"Wrong response '{recognize_dict}' returned - it should contain {key} field" + ) + + try: + translation = translate_hook.translate( + values=transcript, + target_language=self.target_language, + format_=self.format_, + source_language=self.source_language, + model=self.model, + ) + self.log.info("Translated output: %s", translation) + return translation + except ValueError as e: + self.log.error("An error has been thrown from translate speech method:") + self.log.error(e) + raise AirflowException(e) diff --git a/reference/providers/google/cloud/operators/video_intelligence.py b/reference/providers/google/cloud/operators/video_intelligence.py new file mode 100644 index 0000000..0c056aa --- /dev/null +++ b/reference/providers/google/cloud/operators/video_intelligence.py @@ -0,0 +1,329 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
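+# A minimal usage sketch for the label-detection operator defined below,
+# assuming an active DAG context; the task id, GCS URI, and region are
+# illustrative assumptions only:
+#
+#     detect_labels = CloudVideoIntelligenceDetectVideoLabelsOperator(
+#         task_id="detect_video_labels",
+#         input_uri="gs://example-bucket/example-video.mp4",
+#         location="us-east1",
+#     )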
+"""This module contains Google Cloud Vision operators.""" +from typing import Dict, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.video_intelligence import ( + CloudVideoIntelligenceHook, +) +from airflow.utils.decorators import apply_defaults +from google.api_core.retry import Retry +from google.cloud.videointelligence_v1 import enums +from google.cloud.videointelligence_v1.types import VideoContext +from google.protobuf.json_format import MessageToDict + + +class CloudVideoIntelligenceDetectVideoLabelsOperator(BaseOperator): + """ + Performs video annotation, annotating video labels. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVideoIntelligenceDetectVideoLabelsOperator`. + + :param input_uri: Input video location. Currently, only Google Cloud Storage URIs are supported, + which must be specified in the following format: ``gs://bucket-id/object-id``. + :type input_uri: str + :param input_content: The video data bytes. + If unset, the input video(s) should be specified via ``input_uri``. + If set, ``input_uri`` should be unset. + :type input_content: bytes + :param output_uri: Optional, location where the output (in JSON format) should be stored. Currently, only + Google Cloud Storage URIs are supported, which must be specified in the following format: + ``gs://bucket-id/object-id``. + :type output_uri: str + :param video_context: Optional, Additional video context and/or feature-specific parameters. + :type video_context: dict or google.cloud.videointelligence_v1.types.VideoContext + :param location: Optional, cloud region where annotation should take place. Supported cloud regions: + us-east1, us-west1, europe-west1, asia-east1. If no region is specified, a region will be determined + based on video file location. + :type location: str + :param retry: Retry object used to determine when/if to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: Optional, The amount of time, in seconds, to wait for the request to complete. + Note that if retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to ``google_cloud_default``. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_video_intelligence_detect_labels_template_fields] + template_fields = ( + "input_uri", + "output_uri", + "gcp_conn_id", + "impersonation_chain", + ) + # [END gcp_video_intelligence_detect_labels_template_fields] + + @apply_defaults + def __init__( + self, + *, + input_uri: str, + input_content: Optional[bytes] = None, + output_uri: Optional[str] = None, + video_context: Union[Dict, VideoContext] = None, + location: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.input_uri = input_uri + self.input_content = input_content + self.output_uri = output_uri + self.video_context = video_context + self.location = location + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.timeout = timeout + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVideoIntelligenceHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + operation = hook.annotate_video( + input_uri=self.input_uri, + input_content=self.input_content, + video_context=self.video_context, + location=self.location, + retry=self.retry, + features=[enums.Feature.LABEL_DETECTION], + timeout=self.timeout, + ) + self.log.info("Processing video for label annotations") + result = MessageToDict(operation.result()) + self.log.info("Finished processing.") + return result + + +class CloudVideoIntelligenceDetectVideoExplicitContentOperator(BaseOperator): + """ + Performs video annotation, annotating explicit content. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVideoIntelligenceDetectVideoExplicitContentOperator` + + :param input_uri: Input video location. Currently, only Google Cloud Storage URIs are supported, + which must be specified in the following format: ``gs://bucket-id/object-id``. + :type input_uri: str + :param input_content: The video data bytes. + If unset, the input video(s) should be specified via ``input_uri``. + If set, ``input_uri`` should be unset. + :type input_content: bytes + :param output_uri: Optional, location where the output (in JSON format) should be stored. Currently, only + Google Cloud Storage URIs are supported, which must be specified in the following format: + ``gs://bucket-id/object-id``. + :type output_uri: str + :param video_context: Optional, Additional video context and/or feature-specific parameters. + :type video_context: dict or google.cloud.videointelligence_v1.types.VideoContext + :param location: Optional, cloud region where annotation should take place. Supported cloud regions: + us-east1, us-west1, europe-west1, asia-east1. If no region is specified, a region will be determined + based on video file location. + :type location: str + :param retry: Retry object used to determine when/if to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: Optional, The amount of time, in seconds, to wait for the request to complete. + Note that if retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud + Defaults to ``google_cloud_default``. 
+ :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_video_intelligence_detect_explicit_content_template_fields] + template_fields = ( + "input_uri", + "output_uri", + "gcp_conn_id", + "impersonation_chain", + ) + # [END gcp_video_intelligence_detect_explicit_content_template_fields] + + @apply_defaults + def __init__( + self, + *, + input_uri: str, + output_uri: Optional[str] = None, + input_content: Optional[bytes] = None, + video_context: Union[Dict, VideoContext] = None, + location: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.input_uri = input_uri + self.output_uri = output_uri + self.input_content = input_content + self.video_context = video_context + self.location = location + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.timeout = timeout + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVideoIntelligenceHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + operation = hook.annotate_video( + input_uri=self.input_uri, + input_content=self.input_content, + video_context=self.video_context, + location=self.location, + retry=self.retry, + features=[enums.Feature.EXPLICIT_CONTENT_DETECTION], + timeout=self.timeout, + ) + self.log.info("Processing video for explicit content annotations") + result = MessageToDict(operation.result()) + self.log.info("Finished processing.") + return result + + +class CloudVideoIntelligenceDetectVideoShotsOperator(BaseOperator): + """ + Performs video annotation, annotating video shots. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVideoIntelligenceDetectVideoShotsOperator` + + :param input_uri: Input video location. Currently, only Google Cloud Storage URIs are supported, + which must be specified in the following format: ``gs://bucket-id/object-id``. + :type input_uri: str + :param input_content: The video data bytes. + If unset, the input video(s) should be specified via ``input_uri``. + If set, ``input_uri`` should be unset. + :type input_content: bytes + :param output_uri: Optional, location where the output (in JSON format) should be stored. Currently, only + Google Cloud Storage URIs are supported, which must be specified in the following format: + ``gs://bucket-id/object-id``. + :type output_uri: str + :param video_context: Optional, Additional video context and/or feature-specific parameters. + :type video_context: dict or google.cloud.videointelligence_v1.types.VideoContext + :param location: Optional, cloud region where annotation should take place. Supported cloud regions: + us-east1, us-west1, europe-west1, asia-east1. 
If no region is specified, a region will be determined + based on video file location. + :type location: str + :param retry: Retry object used to determine when/if to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: Optional, The amount of time, in seconds, to wait for the request to complete. + Note that if retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud. + Defaults to ``google_cloud_default``. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_video_intelligence_detect_video_shots_template_fields] + template_fields = ( + "input_uri", + "output_uri", + "gcp_conn_id", + "impersonation_chain", + ) + # [END gcp_video_intelligence_detect_video_shots_template_fields] + + @apply_defaults + def __init__( + self, + *, + input_uri: str, + output_uri: Optional[str] = None, + input_content: Optional[bytes] = None, + video_context: Union[Dict, VideoContext] = None, + location: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.input_uri = input_uri + self.output_uri = output_uri + self.input_content = input_content + self.video_context = video_context + self.location = location + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.timeout = timeout + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVideoIntelligenceHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + operation = hook.annotate_video( + input_uri=self.input_uri, + input_content=self.input_content, + video_context=self.video_context, + location=self.location, + retry=self.retry, + features=[enums.Feature.SHOT_CHANGE_DETECTION], + timeout=self.timeout, + ) + self.log.info("Processing video for video shots annotations") + result = MessageToDict(operation.result()) + self.log.info("Finished processing.") + return result diff --git a/reference/providers/google/cloud/operators/vision.py b/reference/providers/google/cloud/operators/vision.py new file mode 100644 index 0000000..bf48760 --- /dev/null +++ b/reference/providers/google/cloud/operators/vision.py @@ -0,0 +1,1693 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Vision operator.""" + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.vision import CloudVisionHook +from airflow.utils.decorators import apply_defaults +from google.api_core.exceptions import AlreadyExists +from google.api_core.retry import Retry +from google.cloud.vision_v1.types import ( + AnnotateImageRequest, + FieldMask, + Image, + Product, + ProductSet, + ReferenceImage, +) + +MetaData = Sequence[Tuple[str, str]] + + +class CloudVisionCreateProductSetOperator(BaseOperator): + """ + Creates a new ProductSet resource. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionCreateProductSetOperator` + + :param product_set: (Required) The ProductSet to create. If a dict is provided, it must be of the same + form as the protobuf message `ProductSet`. + :type product_set: dict or google.cloud.vision_v1.types.ProductSet + :param location: (Required) The region where the ProductSet should be created. Valid regions + (as of 2019-02-05) are: us-east1, us-west1, europe-west1, asia-east1 + :type location: str + :param project_id: (Optional) The project in which the ProductSet should be created. If set to None or + missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param product_set_id: (Optional) A user-supplied resource id for this ProductSet. + If set, the server will attempt to use this value as the resource id. If it is + already in use, an error is returned with code ALREADY_EXISTS. Must be at most + 128 characters long. It cannot contain the character /. + :type product_set_id: str + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_productset_create_template_fields] + template_fields = ( + "location", + "project_id", + "product_set_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_productset_create_template_fields] + + @apply_defaults + def __init__( + self, + *, + product_set: Union[dict, ProductSet], + location: str, + project_id: Optional[str] = None, + product_set_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.project_id = project_id + self.product_set = product_set + self.product_set_id = product_set_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + return hook.create_product_set( + location=self.location, + project_id=self.project_id, + product_set=self.product_set, + product_set_id=self.product_set_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + self.log.info( + "Product set with id %s already exists. Exiting from the create operation.", + self.product_set_id, + ) + return self.product_set_id + + +class CloudVisionGetProductSetOperator(BaseOperator): + """ + Gets information associated with a ProductSet. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionGetProductSetOperator` + + :param location: (Required) The region where the ProductSet is located. Valid regions (as of 2019-02-05) + are: us-east1, us-west1, europe-west1, asia-east1 + :type location: str + :param product_set_id: (Required) The resource id of this ProductSet. + :type product_set_id: str + :param project_id: (Optional) The project in which the ProductSet is located. If set + to None or missing, the default `project_id` from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
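# --- Editor's usage sketch (illustrative, not part of the upstream diff) ---
# Creating a ProductSet idempotently inside a `with DAG(...)` block: when
# product_set_id is supplied and already exists, execute() above catches
# AlreadyExists and simply returns that id, so the task is safe to re-run.
# All ids and the region are placeholders.
create_product_set = CloudVisionCreateProductSetOperator(
    task_id="create_product_set",
    location="europe-west1",
    product_set_id="my-product-set",
    product_set={"display_name": "My product set"},  # dict form of the ProductSet message
)
# ----------------------------------------------------------------------------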
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_productset_get_template_fields] + template_fields = ( + "location", + "project_id", + "product_set_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_productset_get_template_fields] + + @apply_defaults + def __init__( + self, + *, + location: str, + product_set_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.project_id = project_id + self.product_set_id = product_set_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.get_product_set( + location=self.location, + product_set_id=self.product_set_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudVisionUpdateProductSetOperator(BaseOperator): + """ + Makes changes to a `ProductSet` resource. Only display_name can be updated currently. + + .. note:: To locate the `ProductSet` resource, its `name` in the form + `projects/PROJECT_ID/locations/LOC_ID/productSets/PRODUCT_SET_ID` is necessary. + + You can provide the `name` directly as an attribute of the `product_set` object. + However, you can leave it blank and provide `location` and `product_set_id` instead + (and optionally `project_id` - if not present, the connection default will be used) + and the `name` will be created by the operator itself. + + This mechanism exists for your convenience, to allow leaving the `project_id` empty + and having Airflow use the connection default `project_id`. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionUpdateProductSetOperator` + + :param product_set: (Required) The ProductSet resource which replaces the one on the + server. If a dict is provided, it must be of the same form as the protobuf + message `ProductSet`. + :type product_set: dict or google.cloud.vision_v1.types.ProductSet + :param location: (Optional) The region where the ProductSet is located. Valid regions (as of 2019-02-05) + are: us-east1, us-west1, europe-west1, asia-east1 + :type location: str + :param product_set_id: (Optional) The resource id of this ProductSet. + :type product_set_id: str + :param project_id: (Optional) The project in which the ProductSet should be created. If set to None or + missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param update_mask: (Optional) The `FieldMask` that specifies which fields to update. If update_mask + isn’t specified, all mutable fields are to be updated. Valid mask path is display_name. If a dict is + provided, it must be of the same form as the protobuf message `FieldMask`. + :type update_mask: dict or google.cloud.vision_v1.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. 
+ :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_productset_update_template_fields] + template_fields = ( + "location", + "project_id", + "product_set_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_productset_update_template_fields] + + @apply_defaults + def __init__( + self, + *, + product_set: Union[Dict, ProductSet], + location: Optional[str] = None, + product_set_id: Optional[str] = None, + project_id: Optional[str] = None, + update_mask: Union[Dict, FieldMask] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.product_set = product_set + self.update_mask = update_mask + self.location = location + self.project_id = project_id + self.product_set_id = product_set_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.update_product_set( + location=self.location, + product_set_id=self.product_set_id, + project_id=self.project_id, + product_set=self.product_set, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudVisionDeleteProductSetOperator(BaseOperator): + """ + Permanently deletes a `ProductSet`. `Products` and `ReferenceImages` in the + `ProductSet` are not deleted. The actual image files are not deleted from Google + Cloud Storage. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionDeleteProductSetOperator` + + :param location: (Required) The region where the ProductSet is located. Valid regions (as of 2019-02-05) + are: us-east1, us-west1, europe-west1, asia-east1 + :type location: str + :param product_set_id: (Required) The resource id of this ProductSet. + :type product_set_id: str + :param project_id: (Optional) The project in which the ProductSet should be created. + If set to None or missing, the default project_id from the Google Cloud connection is used. 
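# --- Editor's usage sketch (illustrative, not part of the upstream diff) ---
# Renaming the ProductSet without spelling out its full resource name: the
# operator assembles projects/<project>/locations/<location>/productSets/<id>
# from location, product_set_id and the connection's default project. Values
# below are placeholders.
update_product_set = CloudVisionUpdateProductSetOperator(
    task_id="update_product_set",
    location="europe-west1",
    product_set_id="my-product-set",
    product_set={"display_name": "My renamed product set"},
    update_mask={"paths": ["display_name"]},  # display_name is the only mutable field
)
# ----------------------------------------------------------------------------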
+ :type project_id: str + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_productset_delete_template_fields] + template_fields = ( + "location", + "project_id", + "product_set_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_productset_delete_template_fields] + + @apply_defaults + def __init__( + self, + *, + location: str, + product_set_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.project_id = project_id + self.product_set_id = product_set_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.delete_product_set( + location=self.location, + product_set_id=self.product_set_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudVisionCreateProductOperator(BaseOperator): + """ + Creates and returns a new product resource. + + Possible errors regarding the `Product` object provided: + + - Returns `INVALID_ARGUMENT` if `display_name` is missing or longer than 4096 characters. + - Returns `INVALID_ARGUMENT` if `description` is longer than 4096 characters. + - Returns `INVALID_ARGUMENT` if `product_category` is missing or invalid. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionCreateProductOperator` + + :param location: (Required) The region where the Product should be created. Valid regions + (as of 2019-02-05) are: us-east1, us-west1, europe-west1, asia-east1 + :type location: str + :param product: (Required) The product to create. If a dict is provided, it must be of the same form as + the protobuf message `Product`. + :type product: dict or google.cloud.vision_v1.types.Product + :param project_id: (Optional) The project in which the Product should be created. 
If set to None or + missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param product_id: (Optional) A user-supplied resource id for this Product. + If set, the server will attempt to use this value as the resource id. If it is + already in use, an error is returned with code ALREADY_EXISTS. Must be at most + 128 characters long. It cannot contain the character /. + :type product_id: str + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_product_create_template_fields] + template_fields = ( + "location", + "project_id", + "product_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_product_create_template_fields] + + @apply_defaults + def __init__( + self, + *, + location: str, + product: str, + project_id: Optional[str] = None, + product_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.product = product + self.project_id = project_id + self.product_id = product_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + return hook.create_product( + location=self.location, + product=self.product, + project_id=self.project_id, + product_id=self.product_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + self.log.info( + "Product with id %s already exists. Exiting from the create operation.", + self.product_id, + ) + return self.product_id + + +class CloudVisionGetProductOperator(BaseOperator): + """ + Gets information associated with a `Product`. + + Possible errors: + + - Returns `NOT_FOUND` if the `Product` does not exist. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionGetProductOperator` + + :param location: (Required) The region where the Product is located. 
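# --- Editor's usage sketch (illustrative, not part of the upstream diff) ---
# Creating a Product. As with the ProductSet operator, AlreadyExists is caught
# in execute(), so re-running with the same product_id is harmless. Placeholder
# ids; "apparel-v2" is one of the product categories the Vision API accepts.
create_product = CloudVisionCreateProductOperator(
    task_id="create_product",
    location="europe-west1",
    product_id="my-product",
    product={
        "display_name": "My product",
        "product_category": "apparel-v2",
    },
)
# ----------------------------------------------------------------------------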
Valid regions (as of 2019-02-05) are: + us-east1, us-west1, europe-west1, asia-east1 + :type location: str + :param product_id: (Required) The resource id of this Product. + :type product_id: str + :param project_id: (Optional) The project in which the Product is located. If set to + None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_product_get_template_fields] + template_fields = ( + "location", + "project_id", + "product_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_product_get_template_fields] + + @apply_defaults + def __init__( + self, + *, + location: str, + product_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.product_id = product_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.get_product( + location=self.location, + product_id=self.product_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudVisionUpdateProductOperator(BaseOperator): + """ + Makes changes to a Product resource. Only the display_name, description, and labels fields can be + updated right now. + + If labels are updated, the change will not be reflected in queries until the next index time. + + .. note:: To locate the `Product` resource, its `name` in the form + `projects/PROJECT_ID/locations/LOC_ID/products/PRODUCT_ID` is necessary. + + You can provide the `name` directly as an attribute of the `product` object. 
However, you can leave it + blank and provide `location` and `product_id` instead (and optionally `project_id` - if not present, + the connection default will be used) and the `name` will be created by the operator itself. + + This mechanism exists for your convenience, to allow leaving the `project_id` empty and having Airflow + use the connection default `project_id`. + + Possible errors related to the provided `Product`: + + - Returns `NOT_FOUND` if the Product does not exist. + - Returns `INVALID_ARGUMENT` if `display_name` is present in update_mask but is missing from the request + or longer than 4096 characters. + - Returns `INVALID_ARGUMENT` if `description` is present in update_mask but is longer than 4096 + characters. + - Returns `INVALID_ARGUMENT` if `product_category` is present in update_mask. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionUpdateProductOperator` + + :param product: (Required) The Product resource which replaces the one on the server. product.name is + immutable. If a dict is provided, it must be of the same form as the protobuf message `Product`. + :type product: dict or google.cloud.vision_v1.types.ProductSet + :param location: (Optional) The region where the Product is located. Valid regions (as of 2019-02-05) are: + us-east1, us-west1, europe-west1, asia-east1 + :type location: str + :param product_id: (Optional) The resource id of this Product. + :type product_id: str + :param project_id: (Optional) The project in which the Product is located. If set to None or + missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param update_mask: (Optional) The `FieldMask` that specifies which fields to update. If update_mask + isn’t specified, all mutable fields are to be updated. Valid mask paths include product_labels, + display_name, and description. If a dict is provided, it must be of the same form as the protobuf + message `FieldMask`. + :type update_mask: dict or google.cloud.vision_v1.types.FieldMask + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_product_update_template_fields] + template_fields = ( + "location", + "project_id", + "product_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_product_update_template_fields] + + @apply_defaults + def __init__( + self, + *, + product: Union[Dict, Product], + location: Optional[str] = None, + product_id: Optional[str] = None, + project_id: Optional[str] = None, + update_mask: Union[Dict, FieldMask] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.product = product + self.location = location + self.product_id = product_id + self.project_id = project_id + self.update_mask = update_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.update_product( + product=self.product, + location=self.location, + product_id=self.product_id, + project_id=self.project_id, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudVisionDeleteProductOperator(BaseOperator): + """ + Permanently deletes a product and its reference images. + + Metadata of the product and all its images will be deleted right away, but search queries against + ProductSets containing the product may still work until all related caches are refreshed. + + Possible errors: + + - Returns `NOT_FOUND` if the product does not exist. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionDeleteProductOperator` + + :param location: (Required) The region where the Product is located. Valid regions (as of 2019-02-05) are: + us-east1, us-west1, europe-west1, asia-east1 + :type location: str + :param product_id: (Required) The resource id of this Product. + :type product_id: str + :param project_id: (Optional) The project in which the Product is located. If set to None or + missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
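# --- Editor's usage sketch (illustrative, not part of the upstream diff) ---
# Updating only selected Product fields: update_mask restricts the change to
# the listed paths, and the resource name is again built from location and
# product_id. Placeholder values.
update_product = CloudVisionUpdateProductOperator(
    task_id="update_product",
    location="europe-west1",
    product_id="my-product",
    product={"display_name": "My product, renamed"},
    update_mask={"paths": ["display_name"]},
)
# ----------------------------------------------------------------------------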
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_product_delete_template_fields] + template_fields = ( + "location", + "project_id", + "product_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_product_delete_template_fields] + + @apply_defaults + def __init__( + self, + *, + location: str, + product_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.product_id = product_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.delete_product( + location=self.location, + product_id=self.product_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudVisionImageAnnotateOperator(BaseOperator): + """ + Run image detection and annotation for an image or a batch of images. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionImageAnnotateOperator` + + :param request: (Required) Annotation request for image or a batch. + If a dict is provided, it must be of the same form as the protobuf + message class:`google.cloud.vision_v1.types.AnnotateImageRequest` + :type request: list[dict or google.cloud.vision_v1.types.AnnotateImageRequest] for batch or + dict or google.cloud.vision_v1.types.AnnotateImageRequest for single image. + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_annotate_image_template_fields] + template_fields = ( + "request", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_annotate_image_template_fields] + + @apply_defaults + def __init__( + self, + *, + request: Union[Dict, AnnotateImageRequest], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.request = request + self.retry = retry + self.timeout = timeout + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + if not isinstance(self.request, list): + response = hook.annotate_image( + request=self.request, retry=self.retry, timeout=self.timeout + ) + else: + response = hook.batch_annotate_images( + requests=self.request, retry=self.retry, timeout=self.timeout + ) + + return response + + +class CloudVisionCreateReferenceImageOperator(BaseOperator): + """ + Creates and returns a new ReferenceImage ID resource. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionCreateReferenceImageOperator` + + :param location: (Required) The region where the Product is located. Valid regions (as of 2019-02-05) are: + us-east1, us-west1, europe-west1, asia-east1 + :type location: str + :param reference_image: (Required) The reference image to create. If an image ID is specified, it is + ignored. + If a dict is provided, it must be of the same form as the protobuf message + :class:`google.cloud.vision_v1.types.ReferenceImage` + :type reference_image: dict or google.cloud.vision_v1.types.ReferenceImage + :param reference_image_id: (Optional) A user-supplied resource id for the ReferenceImage to be added. + If set, the server will attempt to use this value as the resource id. If it is already in use, an + error is returned with code ALREADY_EXISTS. Must be at most 128 characters long. It cannot contain + the character `/`. + :type reference_image_id: str + :param product_id: (Optional) The resource id of this Product. + :type product_id: str + :param project_id: (Optional) The project in which the Product is located. If set to None or + missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
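# --- Editor's usage sketch (illustrative, not part of the upstream diff) ---
# execute() above dispatches on the type of `request`: a single dict goes to
# annotate_image(), a list of such dicts to batch_annotate_images(). The GCS
# URI is a placeholder; `enums` comes from the pre-2.0 vision client that this
# module already targets.
from google.cloud.vision_v1 import enums as vision_enums

annotate_logo = CloudVisionImageAnnotateOperator(
    task_id="annotate_logo",
    request={
        "image": {"source": {"image_uri": "gs://my-bucket/logo.png"}},
        "features": [{"type": vision_enums.Feature.Type.LOGO_DETECTION}],
    },
)
# Passing a list of such request dicts instead selects the batch code path.
# ----------------------------------------------------------------------------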
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_reference_image_create_template_fields] + template_fields = ( + "location", + "reference_image", + "product_id", + "reference_image_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_reference_image_create_template_fields] + + @apply_defaults + def __init__( + self, + *, + location: str, + reference_image: Union[Dict, ReferenceImage], + product_id: str, + reference_image_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.product_id = product_id + self.reference_image = reference_image + self.reference_image_id = reference_image_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + try: + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.create_reference_image( + location=self.location, + product_id=self.product_id, + reference_image=self.reference_image, + reference_image_id=self.reference_image_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + self.log.info( + "ReferenceImage with id %s already exists. Exiting from the create operation.", + self.product_id, + ) + return self.reference_image_id + + +class CloudVisionDeleteReferenceImageOperator(BaseOperator): + """ + Deletes a ReferenceImage ID resource. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionDeleteReferenceImageOperator` + + :param location: (Required) The region where the Product is located. Valid regions (as of 2019-02-05) are: + us-east1, us-west1, europe-west1, asia-east1 + :type location: str + :param reference_image_id: (Optional) A user-supplied resource id for the ReferenceImage to be added. + If set, the server will attempt to use this value as the resource id. If it is already in use, an + error is returned with code ALREADY_EXISTS. Must be at most 128 characters long. It cannot contain + the character `/`. + :type reference_image_id: str + :param product_id: (Optional) The resource id of this Product. + :type product_id: str + :param project_id: (Optional) The project in which the Product is located. If set to None or + missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. 
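# --- Editor's usage sketch (illustrative, not part of the upstream diff) ---
# Registering a reference image for an existing product; the image itself must
# already live in Cloud Storage. Ids and the URI are placeholders.
create_reference_image = CloudVisionCreateReferenceImageOperator(
    task_id="create_reference_image",
    location="europe-west1",
    product_id="my-product",
    reference_image_id="my-reference-image",
    reference_image={"uri": "gs://my-bucket/my-product-front.jpg"},
)
# ----------------------------------------------------------------------------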
+ :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_reference_image_create_template_fields] + template_fields = ( + "location", + "product_id", + "reference_image_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_reference_image_create_template_fields] + + @apply_defaults + def __init__( + self, + *, + location: str, + product_id: str, + reference_image_id: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.product_id = product_id + self.reference_image_id = reference_image_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.delete_reference_image( + location=self.location, + product_id=self.product_id, + reference_image_id=self.reference_image_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudVisionAddProductToProductSetOperator(BaseOperator): + """ + Adds a Product to the specified ProductSet. If the Product is already present, no change is made. + + One Product can be added to at most 100 ProductSets. + + Possible errors: + + - Returns `NOT_FOUND` if the Product or the ProductSet doesn’t exist. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionAddProductToProductSetOperator` + + :param product_set_id: (Required) The resource id for the ProductSet to modify. + :type product_set_id: str + :param product_id: (Required) The resource id of this Product. + :type product_id: str + :param location: (Required) The region where the ProductSet is located. Valid regions (as of 2019-02-05) + are: us-east1, us-west1, europe-west1, asia-east1 + :type: str + :param project_id: (Optional) The project in which the Product is located. If set to None or + missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. 
+ :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_add_product_to_product_set_template_fields] + template_fields = ( + "location", + "product_set_id", + "product_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_add_product_to_product_set_template_fields] + + @apply_defaults + def __init__( + self, + *, + product_set_id: str, + product_id: str, + location: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.product_set_id = product_set_id + self.product_id = product_id + self.location = location + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.add_product_to_product_set( + product_set_id=self.product_set_id, + product_id=self.product_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudVisionRemoveProductFromProductSetOperator(BaseOperator): + """ + Removes a Product from the specified ProductSet. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionRemoveProductFromProductSetOperator` + + :param product_set_id: (Required) The resource id for the ProductSet to modify. + :type product_set_id: str + :param product_id: (Required) The resource id of this Product. + :type product_id: str + :param location: (Required) The region where the ProductSet is located. Valid regions (as of 2019-02-05) + are: us-east1, us-west1, europe-west1, asia-east1 + :type: str + :param project_id: (Optional) The project in which the Product is located. If set to None or + missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. 
+ :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]] + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_remove_product_from_product_set_template_fields] + template_fields = ( + "location", + "product_set_id", + "product_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_remove_product_from_product_set_template_fields] + + @apply_defaults + def __init__( + self, + *, + product_set_id: str, + product_id: str, + location: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.product_set_id = product_set_id + self.product_id = product_id + self.location = location + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.remove_product_from_product_set( + product_set_id=self.product_set_id, + product_id=self.product_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudVisionDetectTextOperator(BaseOperator): + """ + Detects Text in the image + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionDetectTextOperator` + + :param image: (Required) The image to analyze. See more: + https://googleapis.github.io/google-cloud-python/latest/vision/gapic/v1/types.html#google.cloud.vision_v1.types.Image + :type image: dict or google.cloud.vision_v1.types.Image + :param max_results: (Optional) Number of results to return. + :type max_results: int + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: Number of seconds before timing out. + :type timeout: float + :param language_hints: List of languages to use for TEXT_DETECTION. + In most cases, an empty value yields the best results since it enables automatic language detection. + For languages based on the Latin alphabet, setting language_hints is not needed. + :type language_hints: str or list[str] + :param web_detection_params: Parameters for web detection. 
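# --- Editor's usage sketch (illustrative, not part of the upstream diff) ---
# Linking and later unlinking a Product and a ProductSet: both operators above
# take the same (location, product_set_id, product_id) triple, so they pair
# naturally in a DAG. Placeholder ids.
add_product_to_set = CloudVisionAddProductToProductSetOperator(
    task_id="add_product_to_set",
    location="europe-west1",
    product_set_id="my-product-set",
    product_id="my-product",
)
remove_product_from_set = CloudVisionRemoveProductFromProductSetOperator(
    task_id="remove_product_from_set",
    location="europe-west1",
    product_set_id="my-product-set",
    product_id="my-product",
)
add_product_to_set >> remove_product_from_set
# ----------------------------------------------------------------------------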
+ :type web_detection_params: dict + :param additional_properties: Additional properties to be set on the AnnotateImageRequest. See more: + :class:`google.cloud.vision_v1.types.AnnotateImageRequest` + :type additional_properties: dict + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_detect_text_set_template_fields] + template_fields = ( + "image", + "max_results", + "timeout", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_detect_text_set_template_fields] + + def __init__( + self, + image: Union[Dict, Image], + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + language_hints: Optional[Union[str, List[str]]] = None, + web_detection_params: Optional[Dict] = None, + additional_properties: Optional[Dict] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.image = image + self.max_results = max_results + self.retry = retry + self.timeout = timeout + self.gcp_conn_id = gcp_conn_id + self.kwargs = kwargs + self.additional_properties = prepare_additional_parameters( + additional_properties=additional_properties, + language_hints=language_hints, + web_detection_params=web_detection_params, + ) + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.text_detection( + image=self.image, + max_results=self.max_results, + retry=self.retry, + timeout=self.timeout, + additional_properties=self.additional_properties, + ) + + +class CloudVisionTextDetectOperator(BaseOperator): + """ + Detects Document Text in the image + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionTextDetectOperator` + + :param image: (Required) The image to analyze. See more: + https://googleapis.github.io/google-cloud-python/latest/vision/gapic/v1/types.html#google.cloud.vision_v1.types.Image + :type image: dict or google.cloud.vision_v1.types.Image + :param max_results: Number of results to return. + :type max_results: int + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: Number of seconds before timing out. + :type timeout: float + :param language_hints: List of languages to use for TEXT_DETECTION. + In most cases, an empty value yields the best results since it enables automatic language detection. + For languages based on the Latin alphabet, setting language_hints is not needed. 
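A hedged usage sketch for the text-detection operator defined above; the GCS image URI and the language hints are illustrative, not taken from this repository.

detect_text = CloudVisionDetectTextOperator(
    task_id="detect_text",
    image={"source": {"image_uri": "gs://my-bucket/receipt.png"}},  # hypothetical image
    max_results=5,
    language_hints=["en"],   # merged into image_context by prepare_additional_parameters
    timeout=30.0,
)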
+ :type language_hints: str or list[str] + :param web_detection_params: Parameters for web detection. + :type web_detection_params: dict + :param additional_properties: Additional properties to be set on the AnnotateImageRequest. See more: + https://googleapis.github.io/google-cloud-python/latest/vision/gapic/v1/types.html#google.cloud.vision_v1.types.AnnotateImageRequest + :type additional_properties: dict + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_document_detect_text_set_template_fields] + template_fields = ( + "image", + "max_results", + "timeout", + "gcp_conn_id", + "impersonation_chain", + ) # Iterable[str] + # [END vision_document_detect_text_set_template_fields] + + def __init__( + self, + image: Union[Dict, Image], + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + language_hints: Optional[Union[str, List[str]]] = None, + web_detection_params: Optional[Dict] = None, + additional_properties: Optional[Dict] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.image = image + self.max_results = max_results + self.retry = retry + self.timeout = timeout + self.gcp_conn_id = gcp_conn_id + self.additional_properties = prepare_additional_parameters( + additional_properties=additional_properties, + language_hints=language_hints, + web_detection_params=web_detection_params, + ) + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.document_text_detection( + image=self.image, + max_results=self.max_results, + retry=self.retry, + timeout=self.timeout, + additional_properties=self.additional_properties, + ) + + +class CloudVisionDetectImageLabelsOperator(BaseOperator): + """ + Detects Document Text in the image + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionDetectImageLabelsOperator` + + :param image: (Required) The image to analyze. See more: + https://googleapis.github.io/google-cloud-python/latest/vision/gapic/v1/types.html#google.cloud.vision_v1.types.Image + :type image: dict or google.cloud.vision_v1.types.Image + :param max_results: Number of results to return. + :type max_results: int + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: Number of seconds before timing out. + :type timeout: float + :param additional_properties: Additional properties to be set on the AnnotateImageRequest. 
See more: + https://googleapis.github.io/google-cloud-python/latest/vision/gapic/v1/types.html#google.cloud.vision_v1.types.AnnotateImageRequest + :type additional_properties: dict + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_detect_labels_template_fields] + template_fields = ( + "image", + "max_results", + "timeout", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_detect_labels_template_fields] + + def __init__( + self, + image: Union[Dict, Image], + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + additional_properties: Optional[Dict] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.image = image + self.max_results = max_results + self.retry = retry + self.timeout = timeout + self.gcp_conn_id = gcp_conn_id + self.additional_properties = additional_properties + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.label_detection( + image=self.image, + max_results=self.max_results, + retry=self.retry, + timeout=self.timeout, + additional_properties=self.additional_properties, + ) + + +class CloudVisionDetectImageSafeSearchOperator(BaseOperator): + """ + Detects Document Text in the image + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudVisionDetectImageSafeSearchOperator` + + :param image: (Required) The image to analyze. See more: + https://googleapis.github.io/google-cloud-python/latest/vision/gapic/v1/types.html#google.cloud.vision_v1.types.Image + :type image: dict or google.cloud.vision_v1.types.Image + :param max_results: Number of results to return. + :type max_results: int + :param retry: (Optional) A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: Number of seconds before timing out. + :type timeout: float + :param additional_properties: Additional properties to be set on the AnnotateImageRequest. See more: + https://googleapis.github.io/google-cloud-python/latest/vision/gapic/v1/types.html#google.cloud.vision_v1.types.AnnotateImageRequest + :type additional_properties: dict + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START vision_detect_safe_search_template_fields] + template_fields = ( + "image", + "max_results", + "timeout", + "gcp_conn_id", + "impersonation_chain", + ) + # [END vision_detect_safe_search_template_fields] + + def __init__( + self, + image: Union[Dict, Image], + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + additional_properties: Optional[Dict] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.image = image + self.max_results = max_results + self.retry = retry + self.timeout = timeout + self.gcp_conn_id = gcp_conn_id + self.additional_properties = additional_properties + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = CloudVisionHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + return hook.safe_search_detection( + image=self.image, + max_results=self.max_results, + retry=self.retry, + timeout=self.timeout, + additional_properties=self.additional_properties, + ) + + +def prepare_additional_parameters( + additional_properties: Optional[Dict], + language_hints: Any, + web_detection_params: Any, +) -> Optional[Dict]: + """ + Creates additional_properties parameter based on language_hints, web_detection_params and + additional_properties parameters specified by the user + """ + if language_hints is None and web_detection_params is None: + return additional_properties + + if additional_properties is None: + return {} + + merged_additional_parameters = deepcopy(additional_properties) + + if "image_context" not in merged_additional_parameters: + merged_additional_parameters["image_context"] = {} + + merged_additional_parameters["image_context"][ + "language_hints" + ] = merged_additional_parameters["image_context"].get( + "language_hints", language_hints + ) + merged_additional_parameters["image_context"][ + "web_detection_params" + ] = merged_additional_parameters["image_context"].get( + "web_detection_params", web_detection_params + ) + + return merged_additional_parameters diff --git a/reference/providers/google/cloud/operators/workflows.py b/reference/providers/google/cloud/operators/workflows.py new file mode 100644 index 0000000..466a7a7 --- /dev/null +++ b/reference/providers/google/cloud/operators/workflows.py @@ -0,0 +1,741 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
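For reference, a quick illustration of what prepare_additional_parameters (defined at the end of the vision module above) returns; the input values are made up.

params = prepare_additional_parameters(
    additional_properties={"image_context": {}},
    language_hints=["en", "de"],
    web_detection_params=None,
)
# params == {"image_context": {"language_hints": ["en", "de"], "web_detection_params": None}}
# Note: as written, passing additional_properties=None together with hints returns {}.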
See the License for the +# specific language governing permissions and limitations +# under the License. +import hashlib +import json +import re +import uuid +from datetime import datetime, timedelta +from typing import Dict, Optional, Sequence, Tuple, Union + +import pytz +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook +from google.api_core.exceptions import AlreadyExists +from google.api_core.retry import Retry + +# pylint: disable=no-name-in-module +from google.cloud.workflows.executions_v1beta import Execution +from google.cloud.workflows_v1beta import Workflow + +# pylint: enable=no-name-in-module +from google.protobuf.field_mask_pb2 import FieldMask + + +class WorkflowsCreateWorkflowOperator(BaseOperator): + """ + Creates a new workflow. If a workflow with the specified name + already exists in the specified project and location, the long + running operation will return + [ALREADY_EXISTS][google.rpc.Code.ALREADY_EXISTS] error. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:WorkflowsCreateWorkflowOperator` + + :param workflow: Required. Workflow to be created. + :type workflow: Dict + :param workflow_id: Required. The ID of the workflow to be created. + :type workflow_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("location", "workflow", "workflow_id") + template_fields_renderers = {"workflow": "json"} + + def __init__( + self, + *, + workflow: Dict, + workflow_id: str, + location: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + force_rerun: bool = False, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.workflow = workflow + self.workflow_id = workflow_id + self.location = location + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.force_rerun = force_rerun + + def _workflow_id(self, context): + if self.workflow_id and not self.force_rerun: + # If users provide workflow id then assuring the idempotency + # is on their side + return self.workflow_id + + if self.force_rerun: + hash_base = str(uuid.uuid4()) + else: + hash_base = json.dumps(self.workflow, sort_keys=True) + + # We are limited by allowed length of workflow_id so + # we use hash of whole information + exec_date = context["execution_date"].isoformat() + base = f"airflow_{self.dag_id}_{self.task_id}_{exec_date}_{hash_base}" + workflow_id = hashlib.md5(base.encode()).hexdigest() + return re.sub(r"[:\-+.]", "_", workflow_id) + + def execute(self, context): + hook = WorkflowsHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + workflow_id = self._workflow_id(context) + + self.log.info("Creating workflow") + try: + operation = hook.create_workflow( + workflow=self.workflow, + workflow_id=workflow_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + workflow = operation.result() + except AlreadyExists: + workflow = hook.get_workflow( + workflow_id=workflow_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Workflow.to_dict(workflow) + + +class WorkflowsUpdateWorkflowOperator(BaseOperator): + """ + Updates an existing workflow. + Running this method has no impact on already running + executions of the workflow. A new revision of the + workflow may be created as a result of a successful + update operation. In that case, such revision will be + used in new workflow executions. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:WorkflowsUpdateWorkflowOperator` + + :param workflow_id: Required. The ID of the workflow to be updated. + :type workflow_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param update_mask: List of fields to be updated. If not present, + the entire workflow will be updated. + :type update_mask: FieldMask + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. 
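The idempotency scheme in _workflow_id above can be reproduced outside the operator in a few lines; the DAG id, task id and execution date below are placeholders.

import hashlib
import json
import re

workflow = {"name": "demo"}                      # hypothetical workflow body
hash_base = json.dumps(workflow, sort_keys=True)
base = f"airflow_my_dag_create_workflow_2021-01-01T00:00:00+00:00_{hash_base}"
workflow_id = re.sub(r"[:\-+.]", "_", hashlib.md5(base.encode()).hexdigest())
# Identical DAG/task/execution-date/workflow inputs always yield the same id, so a
# retried task hits the AlreadyExists branch in execute() instead of creating a duplicate.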
+ :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("workflow_id", "update_mask") + template_fields_renderers = {"update_mask": "json"} + + def __init__( + self, + *, + workflow_id: str, + location: str, + project_id: Optional[str] = None, + update_mask: Optional[FieldMask] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.workflow_id = workflow_id + self.location = location + self.project_id = project_id + self.update_mask = update_mask + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = WorkflowsHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + + workflow = hook.get_workflow( + workflow_id=self.workflow_id, + project_id=self.project_id, + location=self.location, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Updating workflow") + operation = hook.update_workflow( + workflow=workflow, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + workflow = operation.result() + return Workflow.to_dict(workflow) + + +class WorkflowsDeleteWorkflowOperator(BaseOperator): + """ + Deletes a workflow with the specified name. + This method also cancels and deletes all running + executions of the workflow. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:WorkflowsDeleteWorkflowOperator` + + :param workflow_id: Required. The ID of the workflow to be created. + :type workflow_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
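A hedged sketch of a partial update with the operator above; the "source_contents" field path is an assumption about the Workflow message, not something this file asserts.

from google.protobuf.field_mask_pb2 import FieldMask

update_workflow = WorkflowsUpdateWorkflowOperator(
    task_id="update_workflow",
    workflow_id="my-workflow",                          # hypothetical
    location="us-central1",
    update_mask=FieldMask(paths=["source_contents"]),   # assumed field path
)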
+ :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("location", "workflow_id") + + def __init__( + self, + *, + workflow_id: str, + location: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.workflow_id = workflow_id + self.location = location + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = WorkflowsHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Deleting workflow %s", self.workflow_id) + operation = hook.delete_workflow( + workflow_id=self.workflow_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + operation.result() + + +class WorkflowsListWorkflowsOperator(BaseOperator): + """ + Lists Workflows in a given project and location. + The default order is not specified. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:WorkflowsListWorkflowsOperator` + + :param filter_: Filter to restrict results to specific workflows. + :type filter_: str + :param order_by: Comma-separated list of fields that that + specify the order of the results. Default sorting order for a field is ascending. + To specify descending order for a field, append a "desc" suffix. + If not specified, the results will be returned in an unspecified order. + :type order_by: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("location", "order_by", "filter_") + + def __init__( + self, + *, + location: str, + project_id: Optional[str] = None, + filter_: Optional[str] = None, + order_by: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.filter_ = filter_ + self.order_by = order_by + self.location = location + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = WorkflowsHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Retrieving workflows") + workflows_iter = hook.list_workflows( + filter_=self.filter_, + order_by=self.order_by, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return [Workflow.to_dict(w) for w in workflows_iter] + + +class WorkflowsGetWorkflowOperator(BaseOperator): + """ + Gets details of a single Workflow. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:WorkflowsGetWorkflowOperator` + + :param workflow_id: Required. The ID of the workflow to be created. + :type workflow_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("location", "workflow_id") + + def __init__( + self, + *, + workflow_id: str, + location: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.workflow_id = workflow_id + self.location = location + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = WorkflowsHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Retrieving workflow") + workflow = hook.get_workflow( + workflow_id=self.workflow_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Workflow.to_dict(workflow) + + +class WorkflowsCreateExecutionOperator(BaseOperator): + """ + Creates a new execution using the latest revision of + the given workflow. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:WorkflowsCreateExecutionOperator` + + :param execution: Required. Execution to be created. + :type execution: Dict + :param workflow_id: Required. The ID of the workflow. + :type workflow_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("location", "workflow_id", "execution") + template_fields_renderers = {"execution": "json"} + + def __init__( + self, + *, + workflow_id: str, + execution: Dict, + location: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.workflow_id = workflow_id + self.execution = execution + self.location = location + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = WorkflowsHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Creating execution") + execution = hook.create_execution( + workflow_id=self.workflow_id, + execution=self.execution, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + execution_id = execution.name.split("/")[-1] + self.xcom_push(context, key="execution_id", value=execution_id) + return Execution.to_dict(execution) + + +class WorkflowsCancelExecutionOperator(BaseOperator): + """ + Cancels an execution using the given ``workflow_id`` and ``execution_id``. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:WorkflowsCancelExecutionOperator` + + :param workflow_id: Required. The ID of the workflow. + :type workflow_id: str + :param execution_id: Required. The ID of the execution. + :type execution_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
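Because the operator above pushes "execution_id" to XCom, a downstream task can pick it up through templating; a minimal, illustrative pairing with WorkflowsGetExecutionOperator (defined later in this file) might look like the following, with the workflow id and execution payload as placeholders.

create_execution = WorkflowsCreateExecutionOperator(
    task_id="create_execution",
    workflow_id="my-workflow",                       # hypothetical
    execution={"argument": '{"key": "value"}'},      # assumed Execution payload shape
    location="us-central1",
)
get_execution = WorkflowsGetExecutionOperator(
    task_id="get_execution",
    workflow_id="my-workflow",
    execution_id="{{ task_instance.xcom_pull(task_ids='create_execution', key='execution_id') }}",
    location="us-central1",
)
create_execution >> get_execution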
+ :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("location", "workflow_id", "execution_id") + + def __init__( + self, + *, + workflow_id: str, + execution_id: str, + location: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.workflow_id = workflow_id + self.execution_id = execution_id + self.location = location + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = WorkflowsHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Canceling execution %s", self.execution_id) + execution = hook.cancel_execution( + workflow_id=self.workflow_id, + execution_id=self.execution_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Execution.to_dict(execution) + + +class WorkflowsListExecutionsOperator(BaseOperator): + """ + Returns a list of executions which belong to the + workflow with the given name. The method returns + executions of all workflow revisions. Returned + executions are ordered by their start time (newest + first). + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:WorkflowsListExecutionsOperator` + + :param workflow_id: Required. The ID of the workflow to be created. + :type workflow_id: str + :param start_date_filter: If passed only executions older that this date will be returned. + By default operators return executions from last 60 minutes + :type start_date_filter: datetime + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("location", "workflow_id") + + def __init__( + self, + *, + workflow_id: str, + location: str, + start_date_filter: Optional[datetime] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.workflow_id = workflow_id + self.location = location + self.start_date_filter = start_date_filter or datetime.now( + tz=pytz.UTC + ) - timedelta(minutes=60) + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = WorkflowsHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info("Retrieving executions for workflow %s", self.workflow_id) + execution_iter = hook.list_executions( + workflow_id=self.workflow_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + return [ + Execution.to_dict(e) + for e in execution_iter + if e.start_time > self.start_date_filter + ] + + +class WorkflowsGetExecutionOperator(BaseOperator): + """ + Returns an execution for the given ``workflow_id`` and ``execution_id``. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:WorkflowsGetExecutionOperator` + + :param workflow_id: Required. The ID of the workflow. + :type workflow_id: str + :param execution_id: Required. The ID of the execution. + :type execution_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The GCP region in which to handle the request. + :type location: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
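To widen the default 60-minute window applied above, start_date_filter can be passed explicitly; an illustrative one-day window:

from datetime import datetime, timedelta
import pytz

list_executions = WorkflowsListExecutionsOperator(
    task_id="list_executions",
    workflow_id="my-workflow",        # hypothetical
    location="us-central1",
    start_date_filter=datetime.now(tz=pytz.UTC) - timedelta(days=1),
)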
+ :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("location", "workflow_id", "execution_id") + + def __init__( + self, + *, + workflow_id: str, + execution_id: str, + location: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.workflow_id = workflow_id + self.execution_id = execution_id + self.location = location + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context): + hook = WorkflowsHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info( + "Retrieving execution %s for workflow %s", + self.execution_id, + self.workflow_id, + ) + execution = hook.get_execution( + workflow_id=self.workflow_id, + execution_id=self.execution_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return Execution.to_dict(execution) diff --git a/reference/providers/google/cloud/secrets/__init__.py b/reference/providers/google/cloud/secrets/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/cloud/secrets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/cloud/secrets/secret_manager.py b/reference/providers/google/cloud/secrets/secret_manager.py new file mode 100644 index 0000000..3ec248c --- /dev/null +++ b/reference/providers/google/cloud/secrets/secret_manager.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Objects relating to sourcing connections from Google Cloud Secrets Manager""" +from typing import Optional + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud._internal_client.secret_manager_client import ( # noqa + _SecretManagerClient, +) +from airflow.providers.google.cloud.utils.credentials_provider import ( + get_credentials_and_project_id, +) +from airflow.secrets import BaseSecretsBackend +from airflow.utils.log.logging_mixin import LoggingMixin + +SECRET_ID_PATTERN = r"^[a-zA-Z0-9-_]*$" + + +class CloudSecretManagerBackend(BaseSecretsBackend, LoggingMixin): + """ + Retrieves Connection object from Google Cloud Secrets Manager + + Configurable via ``airflow.cfg`` as follows: + + .. code-block:: ini + + [secrets] + backend = airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend + backend_kwargs = {"connections_prefix": "airflow-connections", "sep": "-"} + + For example, if the Secrets Manager secret id is ``airflow-connections-smtp_default``, this would be + accessible if you provide ``{"connections_prefix": "airflow-connections", "sep": "-"}`` and request + conn_id ``smtp_default``. + + If the Secrets Manager secret id is ``airflow-variables-hello``, this would be + accessible if you provide ``{"variables_prefix": "airflow-variables", "sep": "-"}`` and request + Variable Key ``hello``. + + The full secret id should follow the pattern "[a-zA-Z0-9-_]". + + :param connections_prefix: Specifies the prefix of the secret to read to get Connections. + If set to None (null), requests for connections will not be sent to GCP Secrets Manager + :type connections_prefix: str + :param variables_prefix: Specifies the prefix of the secret to read to get Variables. + If set to None (null), requests for variables will not be sent to GCP Secrets Manager + :type variables_prefix: str + :param config_prefix: Specifies the prefix of the secret to read to get Airflow Configurations + containing secrets. + If set to None (null), requests for configurations will not be sent to GCP Secrets Manager + :type config_prefix: str + :param gcp_key_path: Path to Google Cloud Service Account key file (JSON). Mutually exclusive with + gcp_keyfile_dict. use default credentials in the current environment if not provided. + :type gcp_key_path: str + :param gcp_keyfile_dict: Dictionary of keyfile parameters. Mutually exclusive with gcp_key_path. + :type gcp_keyfile_dict: dict + :param gcp_scopes: Comma-separated string containing OAuth2 scopes + :type gcp_scopes: str + :param project_id: Project ID to read the secrets from. If not passed, the project ID from credentials + will be used. + :type project_id: str + :param sep: Separator used to concatenate connections_prefix and conn_id. 
Default: "-" + :type sep: str + """ + + def __init__( + self, + connections_prefix: str = "airflow-connections", + variables_prefix: str = "airflow-variables", + config_prefix: str = "airflow-config", + gcp_keyfile_dict: Optional[dict] = None, + gcp_key_path: Optional[str] = None, + gcp_scopes: Optional[str] = None, + project_id: Optional[str] = None, + sep: str = "-", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.connections_prefix = connections_prefix + self.variables_prefix = variables_prefix + self.config_prefix = config_prefix + self.sep = sep + if connections_prefix is not None: + if not self._is_valid_prefix_and_sep(): + raise AirflowException( + "`connections_prefix`, `variables_prefix` and `sep` should " + f"follows that pattern {SECRET_ID_PATTERN}" + ) + self.credentials, self.project_id = get_credentials_and_project_id( + keyfile_dict=gcp_keyfile_dict, key_path=gcp_key_path, scopes=gcp_scopes + ) + # In case project id provided + if project_id: + self.project_id = project_id + + @cached_property + def client(self) -> _SecretManagerClient: + """ + Cached property returning secret client. + + :return: Secrets client + """ + return _SecretManagerClient(credentials=self.credentials) + + def _is_valid_prefix_and_sep(self) -> bool: + prefix = self.connections_prefix + self.sep + return _SecretManagerClient.is_valid_secret_name(prefix) + + def get_conn_uri(self, conn_id: str) -> Optional[str]: + """ + Get secret value from the SecretManager. + + :param conn_id: connection id + :type conn_id: str + """ + if self.connections_prefix is None: + return None + + return self._get_secret(self.connections_prefix, conn_id) + + def get_variable(self, key: str) -> Optional[str]: + """ + Get Airflow Variable from Environment Variable + + :param key: Variable Key + :return: Variable Value + """ + if self.variables_prefix is None: + return None + + return self._get_secret(self.variables_prefix, key) + + def get_config(self, key: str) -> Optional[str]: + """ + Get Airflow Configuration + + :param key: Configuration Option Key + :return: Configuration Option Value + """ + if self.config_prefix is None: + return None + + return self._get_secret(self.config_prefix, key) + + def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: + """ + Get secret value from the SecretManager based on prefix. + + :param path_prefix: Prefix for the Path to get Secret + :type path_prefix: str + :param secret_id: Secret Key + :type secret_id: str + """ + secret_id = self.build_path(path_prefix, secret_id, self.sep) + return self.client.get_secret(secret_id=secret_id, project_id=self.project_id) diff --git a/reference/providers/google/cloud/sensors/__init__.py b/reference/providers/google/cloud/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/cloud/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/cloud/sensors/bigquery.py b/reference/providers/google/cloud/sensors/bigquery.py new file mode 100644 index 0000000..4b81555 --- /dev/null +++ b/reference/providers/google/cloud/sensors/bigquery.py @@ -0,0 +1,183 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Bigquery sensor.""" +from typing import Optional, Sequence, Union + +from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class BigQueryTableExistenceSensor(BaseSensorOperator): + """ + Checks for the existence of a table in Google Bigquery. + + :param project_id: The Google cloud project in which to look for the table. + The connection supplied to the hook must provide + access to the specified project. + :type project_id: str + :param dataset_id: The name of the dataset in which to look for the table. + storage bucket. + :type dataset_id: str + :param table_id: The name of the table to check the existence of. + :type table_id: str + :param bigquery_conn_id: The connection ID to use when connecting to + Google BigQuery. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "project_id", + "dataset_id", + "table_id", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, + project_id: str, + dataset_id: str, + table_id: str, + bigquery_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.project_id = project_id + self.dataset_id = dataset_id + self.table_id = table_id + self.bigquery_conn_id = bigquery_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def poke(self, context: dict) -> bool: + table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}" + self.log.info("Sensor checks existence of table: %s", table_uri) + hook = BigQueryHook( + bigquery_conn_id=self.bigquery_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + return hook.table_exists( + project_id=self.project_id, + dataset_id=self.dataset_id, + table_id=self.table_id, + ) + + +class BigQueryTablePartitionExistenceSensor(BaseSensorOperator): + """ + Checks for the existence of a partition within a table in Google Bigquery. + + :param project_id: The Google cloud project in which to look for the table. + The connection supplied to the hook must provide + access to the specified project. + :type project_id: str + :param dataset_id: The name of the dataset in which to look for the table. + storage bucket. + :type dataset_id: str + :param table_id: The name of the table to check the existence of. + :type table_id: str + :param partition_id: The name of the partition to check the existence of. + :type partition_id: str + :param bigquery_conn_id: The connection ID to use when connecting to + Google BigQuery. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must + have domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
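A hedged usage sketch for the table-existence sensor above (the partition variant below follows the same pattern); the ids are placeholders and the extra keyword arguments come from BaseSensorOperator.

check_table = BigQueryTableExistenceSensor(
    task_id="check_table_exists",
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="my_table",
    poke_interval=60,        # seconds between pokes
    timeout=60 * 60,         # give up after one hour
)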
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "project_id", + "dataset_id", + "table_id", + "partition_id", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, + project_id: str, + dataset_id: str, + table_id: str, + partition_id: str, + bigquery_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.project_id = project_id + self.dataset_id = dataset_id + self.table_id = table_id + self.partition_id = partition_id + self.bigquery_conn_id = bigquery_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def poke(self, context: dict) -> bool: + table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}" + self.log.info( + 'Sensor checks existence of partition: "%s" in table: %s', + self.partition_id, + table_uri, + ) + hook = BigQueryHook( + bigquery_conn_id=self.bigquery_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + return hook.table_partition_exists( + project_id=self.project_id, + dataset_id=self.dataset_id, + table_id=self.table_id, + partition_id=self.partition_id, + ) diff --git a/reference/providers/google/cloud/sensors/bigquery_dts.py b/reference/providers/google/cloud/sensors/bigquery_dts.py new file mode 100644 index 0000000..6ada087 --- /dev/null +++ b/reference/providers/google/cloud/sensors/bigquery_dts.py @@ -0,0 +1,139 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google BigQuery Data Transfer Service sensor.""" +from typing import Optional, Sequence, Set, Tuple, Union + +from airflow.providers.google.cloud.hooks.bigquery_dts import ( + BiqQueryDataTransferServiceHook, +) +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from google.api_core.retry import Retry +from google.cloud.bigquery_datatransfer_v1 import TransferState + + +class BigQueryDataTransferServiceTransferRunSensor(BaseSensorOperator): + """ + Waits for Data Transfer Service run to complete. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/operator:BigQueryDataTransferServiceTransferRunSensor` + + :param expected_statuses: The expected state of the operation. + See: + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status + :type expected_statuses: Union[Set[str], str] + :param run_id: ID of the transfer run. + :type run_id: str + :param transfer_config_id: ID of transfer config to be used. 
+ :type transfer_config_id: str + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param request_timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type request_timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :return: An ``google.cloud.bigquery_datatransfer_v1.types.TransferRun`` instance. + """ + + template_fields = ( + "run_id", + "transfer_config_id", + "expected_statuses", + "project_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + run_id: str, + transfer_config_id: str, + expected_statuses: Union[ + Set[Union[str, TransferState, int]], str, TransferState, int + ] = TransferState.SUCCEEDED, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + retry: Optional[Retry] = None, + request_timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.run_id = run_id + self.transfer_config_id = transfer_config_id + self.retry = retry + self.request_timeout = request_timeout + self.metadata = metadata + self.expected_statuses = self._normalize_state_list(expected_statuses) + self.project_id = project_id + self.gcp_cloud_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def _normalize_state_list(self, states) -> Set[TransferState]: + states = {states} if isinstance(states, (str, TransferState, int)) else states + result = set() + for state in states: + if isinstance(state, str): + result.add(TransferState[state.upper()]) + elif isinstance(state, int): + result.add(TransferState(state)) + elif isinstance(state, TransferState): + result.add(state) + else: + raise TypeError( + f"Unsupported type. " + f"Expected: str, int, google.cloud.bigquery_datatransfer_v1.TransferState." 
+ f"Current type: {type(state)}" + ) + return result + + def poke(self, context: dict) -> bool: + hook = BiqQueryDataTransferServiceHook( + gcp_conn_id=self.gcp_cloud_conn_id, + impersonation_chain=self.impersonation_chain, + ) + run = hook.get_transfer_run( + run_id=self.run_id, + transfer_config_id=self.transfer_config_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.request_timeout, + metadata=self.metadata, + ) + self.log.info("Status of %s run: %s", self.run_id, str(run.state)) + return run.state in self.expected_statuses diff --git a/reference/providers/google/cloud/sensors/bigtable.py b/reference/providers/google/cloud/sensors/bigtable.py new file mode 100644 index 0000000..403126d --- /dev/null +++ b/reference/providers/google/cloud/sensors/bigtable.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Cloud Bigtable sensor.""" +from typing import Optional, Sequence, Union + +import google.api_core.exceptions +from airflow.providers.google.cloud.hooks.bigtable import BigtableHook +from airflow.providers.google.cloud.operators.bigtable import BigtableValidationMixin +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from google.cloud.bigtable.table import ClusterState +from google.cloud.bigtable_admin_v2 import enums + + +class BigtableTableReplicationCompletedSensor( + BaseSensorOperator, BigtableValidationMixin +): + """ + Sensor that waits for Cloud Bigtable table to be fully replicated to its clusters. + No exception will be raised if the instance or the table does not exist. + + For more details about cluster states for a table, have a look at the reference: + https://googleapis.github.io/google-cloud-python/latest/bigtable/table.html#google.cloud.bigtable.table.Table.get_cluster_states + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigtableTableReplicationCompletedSensor` + + :type instance_id: str + :param instance_id: The ID of the Cloud Bigtable instance. + :type table_id: str + :param table_id: The ID of the table to check replication status. + :type project_id: str + :param project_id: Optional, the ID of the Google Cloud project. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + REQUIRED_ATTRIBUTES = ("instance_id", "table_id") + template_fields = [ + "project_id", + "instance_id", + "table_id", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + instance_id: str, + table_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + self.project_id = project_id + self.instance_id = instance_id + self.table_id = table_id + self.gcp_conn_id = gcp_conn_id + self._validate_inputs() + self.impersonation_chain = impersonation_chain + super().__init__(**kwargs) + + def poke(self, context: dict) -> bool: + hook = BigtableHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + instance = hook.get_instance( + project_id=self.project_id, instance_id=self.instance_id + ) + if not instance: + self.log.info("Dependency: instance '%s' does not exist.", self.instance_id) + return False + + try: + cluster_states = hook.get_cluster_states_for_table( + instance=instance, table_id=self.table_id + ) + except google.api_core.exceptions.NotFound: + self.log.info( + "Dependency: table '%s' does not exist in instance '%s'.", + self.table_id, + self.instance_id, + ) + return False + + ready_state = ClusterState(enums.Table.ClusterState.ReplicationState.READY) + + is_table_replicated = True + for cluster_id in cluster_states.keys(): + if cluster_states[cluster_id] != ready_state: + self.log.info( + "Table '%s' is not yet replicated on cluster '%s'.", + self.table_id, + cluster_id, + ) + is_table_replicated = False + + if not is_table_replicated: + return False + + self.log.info("Table '%s' is replicated.", self.table_id) + return True diff --git a/reference/providers/google/cloud/sensors/cloud_storage_transfer_service.py b/reference/providers/google/cloud/sensors/cloud_storage_transfer_service.py new file mode 100644 index 0000000..e01a029 --- /dev/null +++ b/reference/providers/google/cloud/sensors/cloud_storage_transfer_service.py @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
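# ---------------------------------------------------------------------------
# Reviewer aside, not part of the diff above: a minimal sketch of the
# BigtableTableReplicationCompletedSensor defined above. Instance and table
# IDs are placeholders. Note the sensor keeps poking (returns False) if the
# instance or table does not exist rather than failing outright.
# ---------------------------------------------------------------------------
from airflow.providers.google.cloud.sensors.bigtable import (
    BigtableTableReplicationCompletedSensor,
)

wait_for_replication = BigtableTableReplicationCompletedSensor(
    task_id="wait_for_replication",
    instance_id="my-bigtable-instance",  # placeholder
    table_id="my-table",                 # placeholder
    project_id="my-project",             # placeholder; optional
    poke_interval=60,
    timeout=60 * 60,                     # give up after an hour of polling
)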
+"""This module contains a Google Cloud Transfer sensor.""" +from typing import Optional, Sequence, Set, Union + +from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( + COUNTERS, + METADATA, + NAME, + CloudDataTransferServiceHook, +) +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class CloudDataTransferServiceJobStatusSensor(BaseSensorOperator): + """ + Waits for at least one operation belonging to the job to have the + expected status. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudDataTransferServiceJobStatusSensor` + + :param job_name: The name of the transfer job + :type job_name: str + :param expected_statuses: The expected state of the operation. + See: + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status + :type expected_statuses: set[str] or string + :param project_id: (Optional) the ID of the project that owns the Transfer + Job. If set to None or missing, the default project_id from the Google Cloud + connection is used. + :type project_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + # [START gcp_transfer_job_sensor_template_fields] + template_fields = ( + "job_name", + "impersonation_chain", + ) + # [END gcp_transfer_job_sensor_template_fields] + + @apply_defaults + def __init__( + self, + *, + job_name: str, + expected_statuses: Union[Set[str], str], + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_name = job_name + self.expected_statuses = ( + {expected_statuses} + if isinstance(expected_statuses, str) + else expected_statuses + ) + self.project_id = project_id + self.gcp_cloud_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def poke(self, context: dict) -> bool: + hook = CloudDataTransferServiceHook( + gcp_conn_id=self.gcp_cloud_conn_id, + impersonation_chain=self.impersonation_chain, + ) + operations = hook.list_transfer_operations( + request_filter={"project_id": self.project_id, "job_names": [self.job_name]} + ) + + for operation in operations: + self.log.info( + "Progress for operation %s: %s", + operation[NAME], + operation[METADATA][COUNTERS], + ) + + check = CloudDataTransferServiceHook.operations_contain_expected_statuses( + operations=operations, expected_statuses=self.expected_statuses + ) + if check: + self.xcom_push(key="sensed_operations", value=operations, context=context) + + return check diff --git a/reference/providers/google/cloud/sensors/dataflow.py b/reference/providers/google/cloud/sensors/dataflow.py new file mode 100644 index 0000000..4c661c4 --- /dev/null +++ b/reference/providers/google/cloud/sensors/dataflow.py @@ -0,0 +1,408 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Dataflow sensor.""" +from typing import Callable, Optional, Sequence, Set, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.dataflow import ( + DEFAULT_DATAFLOW_LOCATION, + DataflowHook, + DataflowJobStatus, +) +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class DataflowJobStatusSensor(BaseSensorOperator): + """ + Checks for the status of a job in Google Cloud Dataflow. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowJobStatusSensor` + + :param job_id: ID of the job to be checked. + :type job_id: str + :param expected_statuses: The expected state of the operation. 
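# ---------------------------------------------------------------------------
# Reviewer aside, not part of the diff above: a sketch of the
# CloudDataTransferServiceJobStatusSensor defined above. The job name is a
# placeholder, and "SUCCESS" is assumed to be a valid operation status string
# per the Storage Transfer API link in the docstring. When the check passes,
# the matched operations are pushed to XCom under the "sensed_operations" key.
# ---------------------------------------------------------------------------
from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import (
    CloudDataTransferServiceJobStatusSensor,
)

wait_for_transfer_job = CloudDataTransferServiceJobStatusSensor(
    task_id="wait_for_transfer_job",
    job_name="transferJobs/1234567890",  # placeholder job name
    expected_statuses={"SUCCESS"},
    project_id="my-project",             # placeholder; optional
    poke_interval=60,
)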
+ See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs#Job.JobState + :type expected_statuses: Union[Set[str], str] + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param location: The location of the Dataflow job (for example europe-west1). See: + https://cloud.google.com/dataflow/docs/concepts/regional-endpoints + :type location: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. See: + https://developers.google.com/identity/protocols/oauth2/service-account#delegatingauthority + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ["job_id"] + + @apply_defaults + def __init__( + self, + *, + job_id: str, + expected_statuses: Union[Set[str], str], + project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_id = job_id + self.expected_statuses = ( + {expected_statuses} + if isinstance(expected_statuses, str) + else expected_statuses + ) + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook: Optional[DataflowHook] = None + + def poke(self, context: dict) -> bool: + self.log.info( + "Waiting for job %s to be in one of the states: %s.", + self.job_id, + ", ".join(self.expected_statuses), + ) + self.hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + job = self.hook.get_job( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + + job_status = job["currentState"] + self.log.debug("Current job status for job %s: %s.", self.job_id, job_status) + + if job_status in self.expected_statuses: + return True + elif job_status in DataflowJobStatus.TERMINAL_STATES: + raise AirflowException( + f"Job with id '{self.job_id}' is already in terminal state: {job_status}" + ) + + return False + + +class DataflowJobMetricsSensor(BaseSensorOperator): + """ + Checks the metrics of a job in Google Cloud Dataflow. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowJobMetricsSensor` + + :param job_id: ID of the job to be checked. 
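# ---------------------------------------------------------------------------
# Reviewer aside, not part of the diff above: a sketch of the
# DataflowJobStatusSensor defined above, waiting for a job to reach
# JOB_STATE_DONE. The job id is a placeholder; since job_id is templated, it
# is commonly pulled from the XCom of the task that launched the job.
# ---------------------------------------------------------------------------
from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor

wait_for_dataflow_job = DataflowJobStatusSensor(
    task_id="wait_for_dataflow_job",
    job_id="2022-07-28_00_00_00-1234567890123456789",  # placeholder
    expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
    project_id="my-project",                           # placeholder
    location="us-central1",
    poke_interval=60,
)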
+ :type job_id: str + :param callback: callback which is called with list of read job metrics + See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate + :type callback: callable + :param fail_on_terminal_state: If set to true sensor will raise Exception when + job is in terminal state + :type fail_on_terminal_state: bool + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param location: The location of the Dataflow job (for example europe-west1). See: + https://cloud.google.com/dataflow/docs/concepts/regional-endpoints + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ["job_id"] + + @apply_defaults + def __init__( + self, + *, + job_id: str, + callback: Callable[[dict], bool], + fail_on_terminal_state: bool = True, + project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_id = job_id + self.project_id = project_id + self.callback = callback + self.fail_on_terminal_state = fail_on_terminal_state + self.location = location + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook: Optional[DataflowHook] = None + + def poke(self, context: dict) -> bool: + self.hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + if self.fail_on_terminal_state: + job = self.hook.get_job( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + job_status = job["currentState"] + if job_status in DataflowJobStatus.TERMINAL_STATES: + raise AirflowException( + f"Job with id '{self.job_id}' is already in terminal state: {job_status}" + ) + + result = self.hook.fetch_job_metrics_by_id( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + + return self.callback(result["metrics"]) + + +class DataflowJobMessagesSensor(BaseSensorOperator): + """ + Checks for the job message in Google Cloud Dataflow. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowJobMessagesSensor` + + :param job_id: ID of the job to be checked. 
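# ---------------------------------------------------------------------------
# Reviewer aside, not part of the diff above: a sketch of the
# DataflowJobMetricsSensor defined above. The callback receives the job's
# metric list (the "metrics" field of the MetricUpdate response); the check
# below is purely illustrative, and the job id is a placeholder.
# ---------------------------------------------------------------------------
from airflow.providers.google.cloud.sensors.dataflow import DataflowJobMetricsSensor


def any_metrics_reported(metrics) -> bool:
    # Keep poking until Dataflow has reported at least one metric update.
    return len(metrics) > 0


wait_for_metrics = DataflowJobMetricsSensor(
    task_id="wait_for_metrics",
    job_id="2022-07-28_00_00_00-1234567890123456789",  # placeholder
    callback=any_metrics_reported,
    fail_on_terminal_state=True,
    project_id="my-project",                           # placeholder
    location="us-central1",
    poke_interval=60,
)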
+ :type job_id: str + :param callback: callback which is called with list of read job metrics + See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate + :type callback: callable + :param fail_on_terminal_state: If set to true sensor will raise Exception when + job is in terminal state + :type fail_on_terminal_state: bool + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param location: Job location. + :type location: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ["job_id"] + + @apply_defaults + def __init__( + self, + *, + job_id: str, + callback: Callable, + fail_on_terminal_state: bool = True, + project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_id = job_id + self.project_id = project_id + self.callback = callback + self.fail_on_terminal_state = fail_on_terminal_state + self.location = location + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook: Optional[DataflowHook] = None + + def poke(self, context: dict) -> bool: + self.hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + if self.fail_on_terminal_state: + job = self.hook.get_job( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + job_status = job["currentState"] + if job_status in DataflowJobStatus.TERMINAL_STATES: + raise AirflowException( + f"Job with id '{self.job_id}' is already in terminal state: {job_status}" + ) + + result = self.hook.fetch_job_messages_by_id( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + + return self.callback(result) + + +class DataflowJobAutoScalingEventsSensor(BaseSensorOperator): + """ + Checks for the job autoscaling event in Google Cloud Dataflow. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowJobAutoScalingEventsSensor` + + :param job_id: ID of the job to be checked. 
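# ---------------------------------------------------------------------------
# Reviewer aside, not part of the diff above: the DataflowJobMessagesSensor
# defined above follows the same pattern, but its callback receives the job's
# message list. The "messageText" key used below is an assumption about the
# Dataflow jobMessages payload shape; treat it as illustrative only.
# ---------------------------------------------------------------------------
from airflow.providers.google.cloud.sensors.dataflow import DataflowJobMessagesSensor


def saw_worker_pool_start(messages) -> bool:
    # Succeed once any job message mentions the worker pool starting.
    return any("Worker pool started" in m.get("messageText", "") for m in messages)


wait_for_job_message = DataflowJobMessagesSensor(
    task_id="wait_for_job_message",
    job_id="2022-07-28_00_00_00-1234567890123456789",  # placeholder
    callback=saw_worker_pool_start,
    project_id="my-project",                           # placeholder
    location="us-central1",
    poke_interval=60,
)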
+ :type job_id: str + :param callback: callback which is called with list of read job metrics + See: + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate + :type callback: callable + :param fail_on_terminal_state: If set to true sensor will raise Exception when + job is in terminal state + :type fail_on_terminal_state: bool + :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + :param location: Job location. + :type location: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ["job_id"] + + @apply_defaults + def __init__( + self, + *, + job_id: str, + callback: Callable, + fail_on_terminal_state: bool = True, + project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_id = job_id + self.project_id = project_id + self.callback = callback + self.fail_on_terminal_state = fail_on_terminal_state + self.location = location + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook: Optional[DataflowHook] = None + + def poke(self, context: dict) -> bool: + self.hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + if self.fail_on_terminal_state: + job = self.hook.get_job( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + job_status = job["currentState"] + if job_status in DataflowJobStatus.TERMINAL_STATES: + raise AirflowException( + f"Job with id '{self.job_id}' is already in terminal state: {job_status}" + ) + + result = self.hook.fetch_job_autoscaling_events_by_id( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + + return self.callback(result) diff --git a/reference/providers/google/cloud/sensors/dataproc.py b/reference/providers/google/cloud/sensors/dataproc.py new file mode 100644 index 0000000..e76ee46 --- /dev/null +++ b/reference/providers/google/cloud/sensors/dataproc.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Dataproc Job sensor.""" +# pylint: disable=C0302 + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.dataproc import DataprocHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from google.cloud.dataproc_v1beta2.types import JobStatus + + +class DataprocJobSensor(BaseSensorOperator): + """ + Check for the state of a previously submitted Dataproc job. + + :param project_id: The ID of the google cloud project in which + to create the cluster. (templated) + :type project_id: str + :param dataproc_job_id: The Dataproc job ID to poll. (templated) + :type dataproc_job_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. (templated) + :type location: str + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. + :type gcp_conn_id: str + """ + + template_fields = ("project_id", "location", "dataproc_job_id") + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, + project_id: str, + dataproc_job_id: str, + location: str, + gcp_conn_id: str = "google_cloud_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.dataproc_job_id = dataproc_job_id + self.location = location + + def poke(self, context: dict) -> bool: + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id) + job = hook.get_job( + job_id=self.dataproc_job_id, + location=self.location, + project_id=self.project_id, + ) + state = job.status.state + + if state == JobStatus.State.ERROR: + raise AirflowException(f"Job failed:\n{job}") + elif state in { + JobStatus.State.CANCELLED, + JobStatus.State.CANCEL_PENDING, + JobStatus.State.CANCEL_STARTED, + }: + raise AirflowException(f"Job was cancelled:\n{job}") + elif JobStatus.State.DONE == state: + self.log.debug("Job %s completed successfully.", self.dataproc_job_id) + return True + elif JobStatus.State.ATTEMPT_FAILURE == state: + self.log.debug("Job %s attempt has failed.", self.dataproc_job_id) + + self.log.info("Waiting for job %s to complete.", self.dataproc_job_id) + return False diff --git a/reference/providers/google/cloud/sensors/gcs.py b/reference/providers/google/cloud/sensors/gcs.py new file mode 100644 index 0000000..b94ff10 --- /dev/null +++ b/reference/providers/google/cloud/sensors/gcs.py @@ -0,0 +1,446 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
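# ---------------------------------------------------------------------------
# Reviewer aside, not part of the diff above: a sketch of the DataprocJobSensor
# defined above. dataproc_job_id is templated, so it is typically pulled from
# the XCom of the task that submitted the job; all values here are placeholders.
# ---------------------------------------------------------------------------
from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor

wait_for_dataproc_job = DataprocJobSensor(
    task_id="wait_for_dataproc_job",
    project_id="my-project",                                       # placeholder
    location="us-central1",
    dataproc_job_id="{{ task_instance.xcom_pull('submit_job') }}",  # placeholder XCom source
    poke_interval=30,
)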
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Cloud Storage sensors.""" + +import os +import warnings +from datetime import datetime +from typing import Callable, List, Optional, Sequence, Set, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.sensors.base import BaseSensorOperator, poke_mode_only +from airflow.utils.decorators import apply_defaults + + +class GCSObjectExistenceSensor(BaseSensorOperator): + """ + Checks for the existence of a file in Google Cloud Storage. + + :param bucket: The Google Cloud Storage bucket where the object is. + :type bucket: str + :param object: The name of the object to check in the Google cloud + storage bucket. + :type object: str + :param google_cloud_conn_id: The connection ID to use when + connecting to Google Cloud Storage. + :type google_cloud_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "bucket", + "object", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, + bucket: str, + object: str, # pylint: disable=redefined-builtin + google_cloud_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.bucket = bucket + self.object = object + self.google_cloud_conn_id = google_cloud_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def poke(self, context: dict) -> bool: + self.log.info("Sensor checks existence of : %s, %s", self.bucket, self.object) + hook = GCSHook( + gcp_conn_id=self.google_cloud_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + return hook.exists(self.bucket, self.object) + + +def ts_function(context): + """ + Default callback for the GoogleCloudStorageObjectUpdatedSensor. The default + behaviour is check for the object being updated after execution_date + + schedule_interval. + """ + return context["dag"].following_schedule(context["execution_date"]) + + +class GCSObjectUpdateSensor(BaseSensorOperator): + """ + Checks if an object is updated in Google Cloud Storage. 
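# ---------------------------------------------------------------------------
# Reviewer aside, not part of the diff above: a sketch of the
# GCSObjectExistenceSensor defined above. Bucket and object names are
# placeholders; note that in this version the connection parameter is
# google_cloud_conn_id rather than gcp_conn_id.
# ---------------------------------------------------------------------------
from airflow.providers.google.cloud.sensors.gcs import GCSObjectExistenceSensor

wait_for_object = GCSObjectExistenceSensor(
    task_id="wait_for_object",
    bucket="my-bucket",                    # placeholder
    object="data/2022-07-28/events.json",  # placeholder
    google_cloud_conn_id="google_cloud_default",
    mode="reschedule",   # free the worker slot between pokes
    poke_interval=300,
)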
+ + :param bucket: The Google Cloud Storage bucket where the object is. + :type bucket: str + :param object: The name of the object to download in the Google cloud + storage bucket. + :type object: str + :param ts_func: Callback for defining the update condition. The default callback + returns execution_date + schedule_interval. The callback takes the context + as parameter. + :type ts_func: function + :param google_cloud_conn_id: The connection ID to use when + connecting to Google Cloud Storage. + :type google_cloud_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "bucket", + "object", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + bucket: str, + object: str, # pylint: disable=redefined-builtin + ts_func: Callable = ts_function, + google_cloud_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.bucket = bucket + self.object = object + self.ts_func = ts_func + self.google_cloud_conn_id = google_cloud_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def poke(self, context: dict) -> bool: + self.log.info("Sensor checks existence of : %s, %s", self.bucket, self.object) + hook = GCSHook( + gcp_conn_id=self.google_cloud_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + return hook.is_updated_after(self.bucket, self.object, self.ts_func(context)) + + +class GCSObjectsWithPrefixExistenceSensor(BaseSensorOperator): + """ + Checks for the existence of GCS objects at a given prefix, passing matches via XCom. + + When files matching the given prefix are found, the poke method's criteria will be + fulfilled and the matching objects will be returned from the operator and passed + through XCom for downstream tasks. + + :param bucket: The Google Cloud Storage bucket where the object is. + :type bucket: str + :param prefix: The name of the prefix to check in the Google cloud + storage bucket. + :type prefix: str + :param google_cloud_conn_id: The connection ID to use when + connecting to Google Cloud Storage. + :type google_cloud_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. 
+ :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "bucket", + "prefix", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + bucket: str, + prefix: str, + google_cloud_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.bucket = bucket + self.prefix = prefix + self.google_cloud_conn_id = google_cloud_conn_id + self.delegate_to = delegate_to + self._matches: List[str] = [] + self.impersonation_chain = impersonation_chain + + def poke(self, context: dict) -> bool: + self.log.info( + "Sensor checks existence of objects: %s, %s", self.bucket, self.prefix + ) + hook = GCSHook( + gcp_conn_id=self.google_cloud_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + self._matches = hook.list(self.bucket, prefix=self.prefix) + return bool(self._matches) + + def execute(self, context: dict) -> List[str]: + """Overridden to allow matches to be passed""" + super().execute(context) + return self._matches + + +class GCSObjectsWtihPrefixExistenceSensor(GCSObjectsWithPrefixExistenceSensor): + """ + This class is deprecated. + Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor`. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor`.""", + DeprecationWarning, + stacklevel=3, + ) + super().__init__(*args, **kwargs) + + +def get_time(): + """ + This is just a wrapper of datetime.datetime.now to simplify mocking in the + unittests. + """ + return datetime.now() + + +@poke_mode_only +class GCSUploadSessionCompleteSensor(BaseSensorOperator): + """ + Checks for changes in the number of objects at prefix in Google Cloud Storage + bucket and returns True if the inactivity period has passed with no + increase in the number of objects. Note, this sensor will no behave correctly + in reschedule mode, as the state of the listed objects in the GCS bucket will + be lost between rescheduled invocations. + + :param bucket: The Google Cloud Storage bucket where the objects are. + expected. + :type bucket: str + :param prefix: The name of the prefix to check in the Google cloud + storage bucket. + :param inactivity_period: The total seconds of inactivity to designate + an upload session is over. Note, this mechanism is not real time and + this operator may not return until a poke_interval after this period + has passed with no additional objects sensed. + :type inactivity_period: float + :param min_objects: The minimum number of objects needed for upload session + to be considered valid. 
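# ---------------------------------------------------------------------------
# Reviewer aside, not part of the diff above: a sketch of the
# GCSObjectsWithPrefixExistenceSensor defined above, plus a downstream task
# pulling the matched object names the sensor returns via XCom. The bucket,
# prefix, and task ids are placeholders.
# ---------------------------------------------------------------------------
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.google.cloud.sensors.gcs import (
    GCSObjectsWithPrefixExistenceSensor,
)

with DAG(
    dag_id="example_gcs_prefix_sensor",
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    wait_for_prefix = GCSObjectsWithPrefixExistenceSensor(
        task_id="wait_for_prefix",
        bucket="my-bucket",             # placeholder
        prefix="incoming/2022-07-28/",  # placeholder
        poke_interval=120,
    )

    def report_matches(ti, **_):
        # The sensor's execute() returns the matched object names, so they are
        # available as its XCom return value.
        matches = ti.xcom_pull(task_ids="wait_for_prefix")
        print(f"Found {len(matches)} matching objects: {matches}")

    log_matches = PythonOperator(task_id="log_matches", python_callable=report_matches)

    wait_for_prefix >> log_matches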
+ :type min_objects: int + :param previous_objects: The set of object ids found during the last poke. + :type previous_objects: set[str] + :param allow_delete: Should this sensor consider objects being deleted + between pokes valid behavior. If true a warning message will be logged + when this happens. If false an error will be raised. + :type allow_delete: bool + :param google_cloud_conn_id: The connection ID to use when connecting + to Google Cloud Storage. + :type google_cloud_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "bucket", + "prefix", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + bucket: str, + prefix: str, + inactivity_period: float = 60 * 60, + min_objects: int = 1, + previous_objects: Optional[Set[str]] = None, + allow_delete: bool = True, + google_cloud_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + + self.bucket = bucket + self.prefix = prefix + if inactivity_period < 0: + raise ValueError("inactivity_period must be non-negative") + self.inactivity_period = inactivity_period + self.min_objects = min_objects + self.previous_objects = previous_objects if previous_objects else set() + self.inactivity_seconds = 0 + self.allow_delete = allow_delete + self.google_cloud_conn_id = google_cloud_conn_id + self.delegate_to = delegate_to + self.last_activity_time = None + self.impersonation_chain = impersonation_chain + self.hook: Optional[GCSHook] = None + + def _get_gcs_hook(self) -> Optional[GCSHook]: + if not self.hook: + self.hook = GCSHook( + gcp_conn_id=self.google_cloud_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + return self.hook + + def is_bucket_updated(self, current_objects: Set[str]) -> bool: + """ + Checks whether new objects have been uploaded and the inactivity_period + has passed and updates the state of the sensor accordingly. + + :param current_objects: set of object ids in bucket during last poke. + :type current_objects: set[str] + """ + current_num_objects = len(current_objects) + if current_objects > self.previous_objects: + # When new objects arrived, reset the inactivity_seconds + # and update previous_objects for the next poke. 
+ self.log.info( + "New objects found at %s resetting last_activity_time.", + os.path.join(self.bucket, self.prefix), + ) + self.log.debug( + "New objects: %s", "\n".join(current_objects - self.previous_objects) + ) + self.last_activity_time = get_time() + self.inactivity_seconds = 0 + self.previous_objects = current_objects + return False + + if self.previous_objects - current_objects: + # During the last poke interval objects were deleted. + if self.allow_delete: + self.previous_objects = current_objects + self.last_activity_time = get_time() + self.log.warning( + """ + Objects were deleted during the last + poke interval. Updating the file counter and + resetting last_activity_time. + %s + """, + self.previous_objects - current_objects, + ) + return False + + raise AirflowException( + """ + Illegal behavior: objects were deleted in {} between pokes. + """.format( + os.path.join(self.bucket, self.prefix) + ) + ) + + if self.last_activity_time: + self.inactivity_seconds = ( + get_time() - self.last_activity_time + ).total_seconds() + else: + # Handles the first poke where last inactivity time is None. + self.last_activity_time = get_time() + self.inactivity_seconds = 0 + + if self.inactivity_seconds >= self.inactivity_period: + path = os.path.join(self.bucket, self.prefix) + + if current_num_objects >= self.min_objects: + self.log.info( + """SUCCESS: + Sensor found %s objects at %s. + Waited at least %s seconds, with no new objects dropped. + """, + current_num_objects, + path, + self.inactivity_period, + ) + return True + + self.log.error( + "FAILURE: Inactivity Period passed, not enough objects found in %s", + path, + ) + + return False + return False + + def poke(self, context: dict) -> bool: + return self.is_bucket_updated( + set(self._get_gcs_hook().list(self.bucket, prefix=self.prefix)) # type: ignore[union-attr] + ) diff --git a/reference/providers/google/cloud/sensors/pubsub.py b/reference/providers/google/cloud/sensors/pubsub.py new file mode 100644 index 0000000..0c26c2b --- /dev/null +++ b/reference/providers/google/cloud/sensors/pubsub.py @@ -0,0 +1,205 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google PubSub sensor.""" +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +from airflow.providers.google.cloud.hooks.pubsub import PubSubHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from google.cloud.pubsub_v1.types import ReceivedMessage + + +class PubSubPullSensor(BaseSensorOperator): + """Pulls messages from a PubSub subscription and passes them through XCom. + Always waits for at least one message to be returned from the subscription. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:PubSubPullSensor` + + .. seealso:: + If you don't want to wait for at least one message to come, use Operator instead: + :class:`~airflow.providers.google.cloud.operators.pubsub.PubSubPullOperator` + + This sensor operator will pull up to ``max_messages`` messages from the + specified PubSub subscription. When the subscription returns messages, + the poke method's criteria will be fulfilled and the messages will be + returned from the operator and passed through XCom for downstream tasks. + + If ``ack_messages`` is set to True, messages will be immediately + acknowledged before being returned, otherwise, downstream tasks will be + responsible for acknowledging them. + + ``project`` and ``subscription`` are templated so you can use + variables in them. + + :param project: the Google Cloud project ID for the subscription (templated) + :type project: str + :param subscription: the Pub/Sub subscription name. Do not include the + full subscription path. + :type subscription: str + :param max_messages: The maximum number of messages to retrieve per + PubSub pull request + :type max_messages: int + :param return_immediately: + (Deprecated) This is an underlying PubSub API implementation detail. + It has no real effect on Sensor behaviour other than some internal wait time before retrying + on empty queue. + The Sensor task will (by definition) always wait for a message, regardless of this argument value. + + If you want a non-blocking task that does not to wait for messages, please use + :class:`~airflow.providers.google.cloud.operators.pubsub.PubSubPullOperator` + instead. + :type return_immediately: bool + :param ack_messages: If True, each message will be acknowledged + immediately rather than by any downstream tasks + :type ack_messages: bool + :param gcp_conn_id: The connection ID to use connecting to + Google Cloud. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param messages_callback: (Optional) Callback to process received messages. + It's return value will be saved to XCom. + If you are pulling large messages, you probably want to provide a custom callback. + If not provided, the default implementation will convert `ReceivedMessage` objects + into JSON-serializable dicts using `google.protobuf.json_format.MessageToDict` function. + :type messages_callback: Optional[Callable[[List[ReceivedMessage], Dict[str, Any]], Any]] + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "project_id", + "subscription", + "impersonation_chain", + ] + ui_color = "#ff7f50" + + @apply_defaults + def __init__( + self, + *, + project_id: str, + subscription: str, + max_messages: int = 5, + return_immediately: bool = True, + ack_messages: bool = False, + gcp_conn_id: str = "google_cloud_default", + messages_callback: Optional[ + Callable[[List[ReceivedMessage], Dict[str, Any]], Any] + ] = None, + delegate_to: Optional[str] = None, + project: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + # To preserve backward compatibility + # TODO: remove one day + if project: + warnings.warn( + "The project parameter has been deprecated. You should pass the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) + project_id = project + + if not return_immediately: + warnings.warn( + "The return_immediately parameter is deprecated.\n" + " It exposes what is really just an implementation detail of underlying PubSub API.\n" + " It has no effect on PubSubPullSensor behaviour.\n" + " It should be left as default value of True.\n" + " If is here only because of backwards compatibility.\n" + " If may be removed in the future.\n", + DeprecationWarning, + stacklevel=2, + ) + + super().__init__(**kwargs) + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.project_id = project_id + self.subscription = subscription + self.max_messages = max_messages + self.return_immediately = return_immediately + self.ack_messages = ack_messages + self.messages_callback = messages_callback + self.impersonation_chain = impersonation_chain + + self._return_value = None + + def execute(self, context: dict): + """Overridden to allow messages to be passed""" + super().execute(context) + return self._return_value + + def poke(self, context: dict) -> bool: + hook = PubSubHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + pulled_messages = hook.pull( + project_id=self.project_id, + subscription=self.subscription, + max_messages=self.max_messages, + return_immediately=self.return_immediately, + ) + + handle_messages = self.messages_callback or self._default_message_callback + + self._return_value = handle_messages(pulled_messages, context) + + if pulled_messages and self.ack_messages: + hook.acknowledge( + project_id=self.project_id, + subscription=self.subscription, + messages=pulled_messages, + ) + + return bool(pulled_messages) + + def _default_message_callback( + self, + pulled_messages: List[ReceivedMessage], + context: Dict[str, Any], # pylint: disable=unused-argument + ): + """ + This method can be overridden by subclasses or by `messages_callback` constructor argument. + This default implementation converts `ReceivedMessage` objects into JSON-serializable dicts. + + :param pulled_messages: messages received from the topic. + :type pulled_messages: List[ReceivedMessage] + :param context: same as in `execute` + :return: value to be saved to XCom. 
+ """ + messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages] + + return messages_json diff --git a/reference/providers/google/cloud/sensors/workflows.py b/reference/providers/google/cloud/sensors/workflows.py new file mode 100644 index 0000000..c1dd7d4 --- /dev/null +++ b/reference/providers/google/cloud/sensors/workflows.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Sequence, Set, Tuple, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook +from airflow.sensors.base import BaseSensorOperator +from google.api_core.retry import Retry +from google.cloud.workflows.executions_v1beta import Execution + + +class WorkflowExecutionSensor(BaseSensorOperator): + """ + Checks state of an execution for the given ``workflow_id`` and ``execution_id``. + + :param workflow_id: Required. The ID of the workflow. + :type workflow_id: str + :param execution_id: Required. The ID of the execution. + :type execution_id: str + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :type project_id: str + :param location: Required. The Cloud Dataproc region in which to handle the request. + :type location: str + :param success_states: Execution states to be considered as successful, by default + it's only ``SUCCEEDED`` state + :type success_states: List[Execution.State] + :param failure_states: Execution states to be considered as failures, by default + they are ``FAILED`` and ``CANCELLED`` states. + :type failure_states: List[Execution.State] + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param request_timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type request_timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + + template_fields = ("location", "workflow_id", "execution_id") + + def __init__( + self, + *, + workflow_id: str, + execution_id: str, + location: str, + project_id: str, + success_states: Optional[Set[Execution.State]] = None, + failure_states: Optional[Set[Execution.State]] = None, + retry: Optional[Retry] = None, + request_timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.success_states = success_states or {Execution.State.SUCCEEDED} + self.failure_states = failure_states or { + Execution.State.FAILED, + Execution.State.CANCELLED, + } + self.workflow_id = workflow_id + self.execution_id = execution_id + self.location = location + self.project_id = project_id + self.retry = retry + self.request_timeout = request_timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def poke(self, context): + hook = WorkflowsHook( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) + self.log.info( + "Checking state of execution %s for workflow %s", + self.execution_id, + self.workflow_id, + ) + execution: Execution = hook.get_execution( + workflow_id=self.workflow_id, + execution_id=self.execution_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.request_timeout, + metadata=self.metadata, + ) + + state = execution.state + if state in self.failure_states: + raise AirflowException( + f"Execution {self.execution_id} for workflow {self.execution_id} " + f"failed and is in `{state}` state", + ) + + if state in self.success_states: + self.log.info( + "Execution %s for workflow %s completed with state: %s", + self.execution_id, + self.workflow_id, + state, + ) + return True + + self.log.info( + "Execution %s for workflow %s does not completed yet, current state: %s", + self.execution_id, + self.workflow_id, + state, + ) + return False diff --git a/reference/providers/google/cloud/transfers/__init__.py b/reference/providers/google/cloud/transfers/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/cloud/transfers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
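# ---------------------------------------------------------------------------
# Reviewer aside, not part of the diff above: a sketch of the
# WorkflowExecutionSensor defined earlier in this diff (sensors/workflows.py).
# The workflow and execution ids are placeholders; execution_id is templated,
# so it is typically pulled from the XCom of the task that created the
# execution.
# ---------------------------------------------------------------------------
from airflow.providers.google.cloud.sensors.workflows import WorkflowExecutionSensor

wait_for_execution = WorkflowExecutionSensor(
    task_id="wait_for_execution",
    workflow_id="my-workflow",   # placeholder
    execution_id="{{ task_instance.xcom_pull('create_execution', key='execution_id') }}",  # placeholder XCom source
    location="us-central1",
    project_id="my-project",     # placeholder
    request_timeout=60,          # per-request timeout in seconds
    poke_interval=60,
)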
diff --git a/reference/providers/google/cloud/transfers/adls_to_gcs.py b/reference/providers/google/cloud/transfers/adls_to_gcs.py new file mode 100644 index 0000000..5c5f70c --- /dev/null +++ b/reference/providers/google/cloud/transfers/adls_to_gcs.py @@ -0,0 +1,194 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains Azure Data Lake Storage to +Google Cloud Storage operator. +""" +import os +import warnings +from tempfile import NamedTemporaryFile +from typing import Optional, Sequence, Union + +from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url +from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook +from airflow.providers.microsoft.azure.operators.adls_list import ( + AzureDataLakeStorageListOperator, +) +from airflow.utils.decorators import apply_defaults + + +class ADLSToGCSOperator(AzureDataLakeStorageListOperator): + """ + Synchronizes an Azure Data Lake Storage path with a GCS bucket + + :param src_adls: The Azure Data Lake path to find the objects (templated) + :type src_adls: str + :param dest_gcs: The Google Cloud Storage bucket and prefix to + store the objects. (templated) + :type dest_gcs: str + :param replace: If true, replaces same-named files in GCS + :type replace: bool + :param gzip: Option to compress file for upload + :type gzip: bool + :param azure_data_lake_conn_id: The connection ID to use when + connecting to Azure Data Lake Storage. + :type azure_data_lake_conn_id: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: Google account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type google_impersonation_chain: Union[str, Sequence[str]] + + **Examples**: + The following Operator would copy a single file named + ``hello/world.avro`` from ADLS to the GCS bucket ``mybucket``. Its full + resulting gcs path will be ``gs://mybucket/hello/world.avro`` :: + + copy_single_file = AdlsToGoogleCloudStorageOperator( + task_id='copy_single_file', + src_adls='hello/world.avro', + dest_gcs='gs://mybucket', + replace=False, + azure_data_lake_conn_id='azure_data_lake_default', + gcp_conn_id='google_cloud_default' + ) + + The following Operator would copy all parquet files from ADLS + to the GCS bucket ``mybucket``. :: + + copy_all_files = AdlsToGoogleCloudStorageOperator( + task_id='copy_all_files', + src_adls='*.parquet', + dest_gcs='gs://mybucket', + replace=False, + azure_data_lake_conn_id='azure_data_lake_default', + gcp_conn_id='google_cloud_default' + ) + + The following Operator would copy all parquet files from ADLS + path ``/hello/world``to the GCS bucket ``mybucket``. :: + + copy_world_files = AdlsToGoogleCloudStorageOperator( + task_id='copy_world_files', + src_adls='hello/world/*.parquet', + dest_gcs='gs://mybucket', + replace=False, + azure_data_lake_conn_id='azure_data_lake_default', + gcp_conn_id='google_cloud_default' + ) + """ + + template_fields: Sequence[str] = ( + "src_adls", + "dest_gcs", + "google_impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, + src_adls: str, + dest_gcs: str, + azure_data_lake_conn_id: str, + gcp_conn_id: str = "google_cloud_default", + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + replace: bool = False, + gzip: bool = False, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__( + path=src_adls, azure_data_lake_conn_id=azure_data_lake_conn_id, **kwargs + ) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + self.src_adls = src_adls + self.dest_gcs = dest_gcs + self.replace = replace + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.gzip = gzip + self.google_impersonation_chain = google_impersonation_chain + + def execute(self, context): + # use the super to list all files in an Azure Data Lake path + files = super().execute(context) + g_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.google_impersonation_chain, + ) + + if not self.replace: + # if we are not replacing -> list all files in the ADLS path + # and only keep those files which are present in + # ADLS and not in Google Cloud Storage + bucket_name, prefix = _parse_gcs_url(self.dest_gcs) + existing_files = g_hook.list(bucket_name=bucket_name, prefix=prefix) + files = set(files) - set(existing_files) + + if files: + hook = AzureDataLakeHook( + azure_data_lake_conn_id=self.azure_data_lake_conn_id + ) + + for obj in files: + with NamedTemporaryFile(mode="wb", delete=True) as f: + hook.download_file(local_path=f.name, remote_path=obj) + f.flush() + dest_gcs_bucket, dest_gcs_prefix = _parse_gcs_url(self.dest_gcs) + dest_path = os.path.join(dest_gcs_prefix, obj) + self.log.info("Saving file to %s", dest_path) + + g_hook.upload( + bucket_name=dest_gcs_bucket, + object_name=dest_path, + filename=f.name, + gzip=self.gzip, + ) + + self.log.info("All done, uploaded %d files to GCS", len(files)) + else: + self.log.info("In sync, no files needed to be uploaded to GCS") + + return files diff --git a/reference/providers/google/cloud/transfers/azure_fileshare_to_gcs.py b/reference/providers/google/cloud/transfers/azure_fileshare_to_gcs.py new file mode 100644 index 0000000..b83dfb4 --- /dev/null +++ b/reference/providers/google/cloud/transfers/azure_fileshare_to_gcs.py @@ -0,0 +1,191 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tempfile import NamedTemporaryFile +from typing import Iterable, Optional, Sequence, Union + +from airflow import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import ( + GCSHook, + _parse_gcs_url, + gcs_object_is_directory, +) +from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook +from airflow.utils.decorators import apply_defaults + + +class AzureFileShareToGCSOperator(BaseOperator): + """ + Synchronizes a Azure FileShare directory content (excluding subdirectories), + possibly filtered by a prefix, with a Google Cloud Storage destination path. + + :param share_name: The Azure FileShare share where to find the objects. 
(templated) + :type share_name: str + :param directory_name: (Optional) Path to Azure FileShare directory which content is to be transferred. + Defaults to root directory (templated) + :type directory_name: str + :param prefix: Prefix string which filters objects whose name begin with + such prefix. (templated) + :type prefix: str + :param wasb_conn_id: The source WASB connection + :type wasb_conn_id: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param dest_gcs: The destination Google Cloud Storage bucket and prefix + where you want to store the files. (templated) + :type dest_gcs: str + :param delegate_to: Google account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param replace: Whether you want to replace existing destination files + or not. + :type replace: bool + :param gzip: Option to compress file for upload + :type gzip: bool + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Optional[Union[str, Sequence[str]]] + + Note that ``share_name``, ``directory_name``, ``prefix``, ``delimiter`` and ``dest_gcs`` are + templated, so you can use variables in them if you wish. + """ + + template_fields: Iterable[str] = ( + "share_name", + "directory_name", + "prefix", + "dest_gcs", + ) + + @apply_defaults + def __init__( + self, + *, + share_name: str, + dest_gcs: str, + directory_name: Optional[str] = None, + prefix: str = "", + wasb_conn_id: str = "wasb_default", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + replace: bool = False, + gzip: bool = False, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.share_name = share_name + self.directory_name = directory_name + self.prefix = prefix + self.wasb_conn_id = wasb_conn_id + self.gcp_conn_id = gcp_conn_id + self.dest_gcs = dest_gcs + self.delegate_to = delegate_to + self.replace = replace + self.gzip = gzip + self.google_impersonation_chain = google_impersonation_chain + + if dest_gcs and not gcs_object_is_directory(self.dest_gcs): + self.log.info( + "Destination Google Cloud Storage path is not a valid " + '"directory", define a path that ends with a slash "/" or ' + "leave it empty for the root of the bucket." + ) + raise AirflowException( + 'The destination Google Cloud Storage path must end with a slash "/" or be empty.' 
+ ) + + def execute(self, context): + azure_fileshare_hook = AzureFileShareHook(self.wasb_conn_id) + files = azure_fileshare_hook.list_files( + share_name=self.share_name, directory_name=self.directory_name + ) + + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.google_impersonation_chain, + ) + + dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(self.dest_gcs) + + # pylint: disable=too-many-nested-blocks + if not self.replace: + # if we are not replacing -> list all files in the GCS bucket + # and only keep those files which are present in + # S3 and not in Google Cloud Storage + existing_files_prefixed = gcs_hook.list( + dest_gcs_bucket, prefix=dest_gcs_object_prefix + ) + + existing_files = [] + + # Remove the object prefix itself, an empty directory was found + if dest_gcs_object_prefix in existing_files_prefixed: + existing_files_prefixed.remove(dest_gcs_object_prefix) + + # Remove the object prefix from all object string paths + for file in existing_files_prefixed: + if file.startswith(dest_gcs_object_prefix): + existing_files.append(file[len(dest_gcs_object_prefix) :]) + else: + existing_files.append(file) + + files = list(set(files) - set(existing_files)) + + if files: + self.log.info("%s files are going to be synced.", len(files)) + else: + self.log.info("There are no new files to sync. Have a nice day!") + + for file in files: + with NamedTemporaryFile() as temp_file: + azure_fileshare_hook.get_file_to_stream( + stream=temp_file, + share_name=self.share_name, + directory_name=self.directory_name, + file_name=file, + ) + temp_file.flush() + + # There will always be a '/' before file because it is + # enforced at instantiation time + dest_gcs_object = dest_gcs_object_prefix + file + gcs_hook.upload( + dest_gcs_bucket, dest_gcs_object, temp_file.name, gzip=self.gzip + ) + + if files: + self.log.info( + "All done, uploaded %d files to Google Cloud Storage.", len(files) + ) + else: + self.log.info( + "In sync, no files needed to be uploaded to Google Cloud Storage" + ) + + return files diff --git a/reference/providers/google/cloud/transfers/bigquery_to_bigquery.py b/reference/providers/google/cloud/transfers/bigquery_to_bigquery.py new file mode 100644 index 0000000..ddc824b --- /dev/null +++ b/reference/providers/google/cloud/transfers/bigquery_to_bigquery.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
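Editor's note: a minimal usage sketch for the BigQueryToBigQueryOperator defined below; it is not part of the committed diff. Project, dataset, and table names are placeholders, and the import path shown is the upstream provider module path rather than the reference/ copy in this repository.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator

with DAG(
    dag_id="example_bq_to_bq_copy",  # placeholder DAG id
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Copy a single source table into a destination table, truncating the destination first.
    copy_table = BigQueryToBigQueryOperator(
        task_id="copy_table",
        source_project_dataset_tables="my-project.my_dataset.source_table",    # placeholder
        destination_project_dataset_table="my-project.my_dataset.dest_table",  # placeholder
        write_disposition="WRITE_TRUNCATE",
        create_disposition="CREATE_IF_NEEDED",
    )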
+"""This module contains Google BigQuery to BigQuery operator.""" +import warnings +from typing import Dict, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.utils.decorators import apply_defaults + + +class BigQueryToBigQueryOperator(BaseOperator): + """ + Copies data from one BigQuery table to another. + + .. seealso:: + For more details about these parameters: + https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.copy + + :param source_project_dataset_tables: One or more + dotted ``(project:|project.).
`` BigQuery tables to use as the + source data. If ```` is not included, project will be the + project defined in the connection json. Use a list if there are multiple + source tables. (templated) + :type source_project_dataset_tables: list|string + :param destination_project_dataset_table: The destination BigQuery + table. Format is: ``(project:|project.).
`` (templated) + :type destination_project_dataset_table: str + :param write_disposition: The write disposition if the table already exists. + :type write_disposition: str + :param create_disposition: The create disposition if the table doesn't exist. + :type create_disposition: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + :param location: The location used for the operation. + :type location: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "source_project_dataset_tables", + "destination_project_dataset_table", + "labels", + "impersonation_chain", + ) + template_ext = (".sql",) + ui_color = "#e6f0e4" + + @apply_defaults + def __init__( + self, + *, # pylint: disable=too-many-arguments + source_project_dataset_tables: Union[List[str], str], + destination_project_dataset_table: str, + write_disposition: str = "WRITE_EMPTY", + create_disposition: str = "CREATE_IF_NEEDED", + gcp_conn_id: str = "google_cloud_default", + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + labels: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. 
You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = bigquery_conn_id + + self.source_project_dataset_tables = source_project_dataset_tables + self.destination_project_dataset_table = destination_project_dataset_table + self.write_disposition = write_disposition + self.create_disposition = create_disposition + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.labels = labels + self.encryption_configuration = encryption_configuration + self.location = location + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + self.log.info( + "Executing copy of %s into: %s", + self.source_project_dataset_tables, + self.destination_project_dataset_table, + ) + hook = BigQueryHook( + bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + conn = hook.get_conn() + cursor = conn.cursor() + cursor.run_copy( + source_project_dataset_tables=self.source_project_dataset_tables, + destination_project_dataset_table=self.destination_project_dataset_table, + write_disposition=self.write_disposition, + create_disposition=self.create_disposition, + labels=self.labels, + encryption_configuration=self.encryption_configuration, + ) diff --git a/reference/providers/google/cloud/transfers/bigquery_to_gcs.py b/reference/providers/google/cloud/transfers/bigquery_to_gcs.py new file mode 100644 index 0000000..c866945 --- /dev/null +++ b/reference/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -0,0 +1,165 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google BigQuery to Google Cloud Storage operator.""" +import warnings +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.utils.decorators import apply_defaults +from google.cloud.bigquery.table import TableReference + + +class BigQueryToGCSOperator(BaseOperator): + """ + Transfers a BigQuery table to a Google Cloud Storage bucket. + + .. seealso:: + For more details about these parameters: + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + :param source_project_dataset_table: The dotted + ``(.|:).
`` BigQuery table to use as the + source data. If ```` is not included, project will be the project + defined in the connection json. (templated) + :type source_project_dataset_table: str + :param destination_cloud_storage_uris: The destination Google Cloud + Storage URI (e.g. gs://some-bucket/some-file.txt). (templated) Follows + convention defined here: + https://cloud.google.com/bigquery/exporting-data-from-bigquery#exportingmultiple + :type destination_cloud_storage_uris: List[str] + :param compression: Type of compression to use. + :type compression: str + :param export_format: File format to export. + :type export_format: str + :param field_delimiter: The delimiter to use when extracting to a CSV. + :type field_delimiter: str + :param print_header: Whether to print a header for a CSV file extract. + :type print_header: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + :param location: The location used for the operation. + :type location: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "source_project_dataset_table", + "destination_cloud_storage_uris", + "labels", + "impersonation_chain", + ) + template_ext = () + ui_color = "#e4e6f0" + + @apply_defaults + def __init__( + self, + *, # pylint: disable=too-many-arguments + source_project_dataset_table: str, + destination_cloud_storage_uris: List[str], + compression: str = "NONE", + export_format: str = "CSV", + field_delimiter: str = ",", + print_header: bool = True, + gcp_conn_id: str = "google_cloud_default", + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + labels: Optional[Dict] = None, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. 
You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = bigquery_conn_id + + self.source_project_dataset_table = source_project_dataset_table + self.destination_cloud_storage_uris = destination_cloud_storage_uris + self.compression = compression + self.export_format = export_format + self.field_delimiter = field_delimiter + self.print_header = print_header + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.labels = labels + self.location = location + self.impersonation_chain = impersonation_chain + + def execute(self, context): + self.log.info( + "Executing extract of %s into: %s", + self.source_project_dataset_table, + self.destination_cloud_storage_uris, + ) + hook = BigQueryHook( + bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + table_ref = TableReference.from_string( + self.source_project_dataset_table, hook.project_id + ) + + configuration: Dict[str, Any] = { + "extract": { + "sourceTable": table_ref.to_api_repr(), + "compression": self.compression, + "destinationUris": self.destination_cloud_storage_uris, + "destinationFormat": self.export_format, + } + } + + if self.labels: + configuration["labels"] = self.labels + + if self.export_format == "CSV": + # Only set fieldDelimiter and printHeader fields if using CSV. + # Google does not like it if you set these fields for other export + # formats. + configuration["extract"]["fieldDelimiter"] = self.field_delimiter + configuration["extract"]["printHeader"] = self.print_header + + hook.insert_job(configuration=configuration) diff --git a/reference/providers/google/cloud/transfers/bigquery_to_mysql.py b/reference/providers/google/cloud/transfers/bigquery_to_mysql.py new file mode 100644 index 0000000..f0ea47b --- /dev/null +++ b/reference/providers/google/cloud/transfers/bigquery_to_mysql.py @@ -0,0 +1,167 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google BigQuery to MySQL operator.""" +from typing import Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.utils.decorators import apply_defaults + + +class BigQueryToMySqlOperator(BaseOperator): + """ + Fetches the data from a BigQuery table (alternatively fetch data for selected columns) + and insert that data into a MySQL table. + + + .. note:: + If you pass fields to ``selected_fields`` which are in different order than the + order of columns already in + BQ table, the data will still be in the order of BQ table. 
+ For example, if the BQ table has 3 columns as + ``[A,B,C]`` and you pass 'B,A' in the ``selected_fields``, + the data would still be of the form ``'A,B'`` and passed through this form + to MySQL. + + **Example**: :: + + transfer_data = BigQueryToMySqlOperator( + task_id='task_id', + dataset_table='origin_bq_table', + mysql_table='dest_table_name', + replace=True, + ) + + :param dataset_table: A dotted ``<dataset>.<table>
``: the big query table of origin + :type dataset_table: str + :param selected_fields: List of fields to return (comma-separated). If + unspecified, all fields are returned. + :type selected_fields: str + :param gcp_conn_id: reference to a specific Google Cloud hook. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param mysql_conn_id: reference to a specific mysql hook + :type mysql_conn_id: str + :param database: name of database which overwrite defined one in connection + :type database: str + :param replace: Whether to replace instead of insert + :type replace: bool + :param batch_size: The number of rows to take in each batch + :type batch_size: int + :param location: The location used for the operation. + :type location: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "dataset_id", + "table_id", + "mysql_table", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, # pylint: disable=too-many-arguments + dataset_table: str, + mysql_table: str, + selected_fields: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + mysql_conn_id: str = "mysql_default", + database: Optional[str] = None, + delegate_to: Optional[str] = None, + replace: bool = False, + batch_size: int = 1000, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.selected_fields = selected_fields + self.gcp_conn_id = gcp_conn_id + self.mysql_conn_id = mysql_conn_id + self.database = database + self.mysql_table = mysql_table + self.replace = replace + self.delegate_to = delegate_to + self.batch_size = batch_size + self.location = location + self.impersonation_chain = impersonation_chain + try: + self.dataset_id, self.table_id = dataset_table.split(".") + except ValueError: + raise ValueError(f"Could not parse {dataset_table} as .
") + + def _bq_get_data(self): + self.log.info("Fetching Data from:") + self.log.info("Dataset: %s ; Table: %s", self.dataset_id, self.table_id) + + hook = BigQueryHook( + bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + conn = hook.get_conn() + cursor = conn.cursor() + i = 0 + while True: + response = cursor.get_tabledata( + dataset_id=self.dataset_id, + table_id=self.table_id, + max_results=self.batch_size, + selected_fields=self.selected_fields, + start_index=i * self.batch_size, + ) + + if "rows" in response: + rows = response["rows"] + else: + self.log.info("Job Finished") + return + + self.log.info("Total Extracted rows: %s", len(rows) + i * self.batch_size) + + table_data = [] + for dict_row in rows: + single_row = [] + for fields in dict_row["f"]: + single_row.append(fields["v"]) + table_data.append(single_row) + + yield table_data + i += 1 + + def execute(self, context): + mysql_hook = MySqlHook(schema=self.database, mysql_conn_id=self.mysql_conn_id) + for rows in self._bq_get_data(): + mysql_hook.insert_rows(self.mysql_table, rows, replace=self.replace) diff --git a/reference/providers/google/cloud/transfers/cassandra_to_gcs.py b/reference/providers/google/cloud/transfers/cassandra_to_gcs.py new file mode 100644 index 0000000..8117b17 --- /dev/null +++ b/reference/providers/google/cloud/transfers/cassandra_to_gcs.py @@ -0,0 +1,391 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains operator for copying +data from Cassandra to Google Cloud Storage in JSON format. +""" + +import json +import warnings +from base64 import b64encode +from datetime import datetime +from decimal import Decimal +from tempfile import NamedTemporaryFile +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from uuid import UUID + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults +from cassandra.util import Date, OrderedMapSerializedKey, SortedSet, Time + + +class CassandraToGCSOperator(BaseOperator): + """ + Copy data from Cassandra to Google Cloud Storage in JSON format + + Note: Arrays of arrays are not supported. + + :param cql: The CQL to execute on the Cassandra table. + :type cql: str + :param bucket: The bucket to upload to. + :type bucket: str + :param filename: The filename to use as the object name when uploading + to Google Cloud Storage. A {} should be specified in the filename + to allow the operator to inject file numbers in cases where the + file is split due to size. 
+ :type filename: str + :param schema_filename: If set, the filename to use as the object name + when uploading a .json file containing the BigQuery schema fields + for the table that was dumped from MySQL. + :type schema_filename: str + :param approx_max_file_size_bytes: This operator supports the ability + to split large table dumps into multiple files (see notes in the + filename param docs above). This param allows developers to specify the + file size of the splits. Check https://cloud.google.com/storage/quotas + to see the maximum allowed file size for a single object. + :type approx_max_file_size_bytes: long + :param cassandra_conn_id: Reference to a specific Cassandra hook. + :type cassandra_conn_id: str + :param gzip: Option to compress file for upload + :type gzip: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "cql", + "bucket", + "filename", + "schema_filename", + "impersonation_chain", + ) + template_ext = (".cql",) + ui_color = "#a0e08c" + + @apply_defaults + def __init__( + self, + *, # pylint: disable=too-many-arguments + cql: str, + bucket: str, + filename: str, + schema_filename: Optional[str] = None, + approx_max_file_size_bytes: int = 1900000000, + gzip: bool = False, + cassandra_conn_id: str = "cassandra_default", + gcp_conn_id: str = "google_cloud_default", + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + self.cql = cql + self.bucket = bucket + self.filename = filename + self.schema_filename = schema_filename + self.approx_max_file_size_bytes = approx_max_file_size_bytes + self.cassandra_conn_id = cassandra_conn_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.gzip = gzip + self.impersonation_chain = impersonation_chain + + # Default Cassandra to BigQuery type mapping + CQL_TYPE_MAP = { + "BytesType": "BYTES", + "DecimalType": "FLOAT", + "UUIDType": "BYTES", + "BooleanType": "BOOL", + "ByteType": "INTEGER", + "AsciiType": "STRING", + "FloatType": "FLOAT", + "DoubleType": "FLOAT", + "LongType": "INTEGER", + "Int32Type": "INTEGER", + "IntegerType": "INTEGER", + "InetAddressType": "STRING", + "CounterColumnType": "INTEGER", + "DateType": "TIMESTAMP", + "SimpleDateType": "DATE", + "TimestampType": "TIMESTAMP", + "TimeUUIDType": "BYTES", + "ShortType": "INTEGER", + "TimeType": "TIME", + "DurationType": "INTEGER", + "UTF8Type": "STRING", + "VarcharType": "STRING", + } + + def execute(self, context: Dict[str, str]): + hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id) + cursor = hook.get_conn().execute(self.cql) + + files_to_upload = self._write_local_data_files(cursor) + + # If a schema is set, create a BQ schema JSON file. + if self.schema_filename: + files_to_upload.update(self._write_local_schema_file(cursor)) + + # Flush all files before uploading + for file_handle in files_to_upload.values(): + file_handle.flush() + + self._upload_to_gcs(files_to_upload) + + # Close all temp file handles. + for file_handle in files_to_upload.values(): + file_handle.close() + + # Close all sessions and connection associated with this Cassandra cluster + hook.shutdown_cluster() + + def _write_local_data_files(self, cursor): + """ + Takes a cursor, and writes results to a local file. + + :return: A dictionary where keys are filenames to be used as object + names in GCS, and values are file handles to local files that + contain the data for the GCS objects. + """ + file_no = 0 + tmp_file_handle = NamedTemporaryFile(delete=True) + tmp_file_handles = {self.filename.format(file_no): tmp_file_handle} + for row in cursor: + row_dict = self.generate_data_dict(row._fields, row) + content = json.dumps(row_dict).encode("utf-8") + tmp_file_handle.write(content) + + # Append newline to make dumps BigQuery compatible. + tmp_file_handle.write(b"\n") + + if tmp_file_handle.tell() >= self.approx_max_file_size_bytes: + file_no += 1 + tmp_file_handle = NamedTemporaryFile(delete=True) + tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle + + return tmp_file_handles + + def _write_local_schema_file(self, cursor): + """ + Takes a cursor, and writes the BigQuery schema for the results to a + local file system. + + :return: A dictionary where key is a filename to be used as an object + name in GCS, and values are file handles to local files that + contains the BigQuery schema fields in .json format. 
+ """ + schema = [] + tmp_schema_file_handle = NamedTemporaryFile(delete=True) + + for name, type_ in zip(cursor.column_names, cursor.column_types): + schema.append(self.generate_schema_dict(name, type_)) + json_serialized_schema = json.dumps(schema).encode("utf-8") + + tmp_schema_file_handle.write(json_serialized_schema) + return {self.schema_filename: tmp_schema_file_handle} + + def _upload_to_gcs(self, files_to_upload: Dict[str, Any]): + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + for obj, tmp_file_handle in files_to_upload.items(): + hook.upload( + bucket_name=self.bucket, + object_name=obj, + filename=tmp_file_handle.name, + mime_type="application/json", + gzip=self.gzip, + ) + + @classmethod + def generate_data_dict(cls, names: Iterable[str], values: Any) -> Dict[str, Any]: + """Generates data structure that will be stored as file in GCS.""" + return {n: cls.convert_value(v) for n, v in zip(names, values)} + + @classmethod + def convert_value( # pylint: disable=too-many-return-statements + cls, value: Optional[Any] + ) -> Optional[Any]: + """Convert value to BQ type.""" + if not value: + return value + elif isinstance(value, (str, int, float, bool, dict)): + return value + elif isinstance(value, bytes): + return b64encode(value).decode("ascii") + elif isinstance(value, UUID): + return b64encode(value.bytes).decode("ascii") + elif isinstance(value, (datetime, Date)): + return str(value) + elif isinstance(value, Decimal): + return float(value) + elif isinstance(value, Time): + return str(value).split(".")[0] + elif isinstance(value, (list, SortedSet)): + return cls.convert_array_types(value) + elif hasattr(value, "_fields"): + return cls.convert_user_type(value) + elif isinstance(value, tuple): + return cls.convert_tuple_type(value) + elif isinstance(value, OrderedMapSerializedKey): + return cls.convert_map_type(value) + else: + raise AirflowException("Unexpected value: " + str(value)) + + @classmethod + def convert_array_types(cls, value: Union[List[Any], SortedSet]) -> List[Any]: + """Maps convert_value over array.""" + return [cls.convert_value(nested_value) for nested_value in value] + + @classmethod + def convert_user_type(cls, value: Any) -> Dict[str, Any]: + """ + Converts a user type to RECORD that contains n fields, where n is the + number of attributes. Each element in the user type class will be converted to its + corresponding data type in BQ. + """ + names = value._fields + values = [cls.convert_value(getattr(value, name)) for name in names] + return cls.generate_data_dict(names, values) + + @classmethod + def convert_tuple_type(cls, values: Tuple[Any]) -> Dict[str, Any]: + """ + Converts a tuple to RECORD that contains n fields, each will be converted + to its corresponding data type in bq and will be named 'field_', where + index is determined by the order of the tuple elements defined in cassandra. + """ + names = ["field_" + str(i) for i in range(len(values))] + return cls.generate_data_dict(names, values) + + @classmethod + def convert_map_type(cls, value: OrderedMapSerializedKey) -> List[Dict[str, Any]]: + """ + Converts a map to a repeated RECORD that contains two fields: 'key' and 'value', + each will be converted to its corresponding data type in BQ. 
+ """ + converted_map = [] + for k, v in zip(value.keys(), value.values()): + converted_map.append( + {"key": cls.convert_value(k), "value": cls.convert_value(v)} + ) + return converted_map + + @classmethod + def generate_schema_dict(cls, name: str, type_: Any) -> Dict[str, Any]: + """Generates BQ schema.""" + field_schema: Dict[str, Any] = {} + field_schema.update({"name": name}) + field_schema.update({"type_": cls.get_bq_type(type_)}) + field_schema.update({"mode": cls.get_bq_mode(type_)}) + fields = cls.get_bq_fields(type_) + if fields: + field_schema.update({"fields": fields}) + return field_schema + + @classmethod + def get_bq_fields(cls, type_: Any) -> List[Dict[str, Any]]: + """Converts non simple type value to BQ representation.""" + if cls.is_simple_type(type_): + return [] + + # In case of not simple type + names: List[str] = [] + types: List[Any] = [] + if cls.is_array_type(type_) and cls.is_record_type(type_.subtypes[0]): + names = type_.subtypes[0].fieldnames + types = type_.subtypes[0].subtypes + elif cls.is_record_type(type_): + names = type_.fieldnames + types = type_.subtypes + + if types and not names and type_.cassname == "TupleType": + names = ["field_" + str(i) for i in range(len(types))] + elif types and not names and type_.cassname == "MapType": + names = ["key", "value"] + + return [cls.generate_schema_dict(n, t) for n, t in zip(names, types)] + + @staticmethod + def is_simple_type(type_: Any) -> bool: + """Check if type is a simple type.""" + return type_.cassname in CassandraToGCSOperator.CQL_TYPE_MAP + + @staticmethod + def is_array_type(type_: Any) -> bool: + """Check if type is an array type.""" + return type_.cassname in ["ListType", "SetType"] + + @staticmethod + def is_record_type(type_: Any) -> bool: + """Checks the record type.""" + return type_.cassname in ["UserType", "TupleType", "MapType"] + + @classmethod + def get_bq_type(cls, type_: Any) -> str: + """Converts type to equivalent BQ type.""" + if cls.is_simple_type(type_): + return CassandraToGCSOperator.CQL_TYPE_MAP[type_.cassname] + elif cls.is_record_type(type_): + return "RECORD" + elif cls.is_array_type(type_): + return cls.get_bq_type(type_.subtypes[0]) + else: + raise AirflowException("Not a supported type_: " + type_.cassname) + + @classmethod + def get_bq_mode(cls, type_: Any) -> str: + """Converts type to equivalent BQ mode.""" + if cls.is_array_type(type_) or type_.cassname == "MapType": + return "REPEATED" + elif cls.is_record_type(type_) or cls.is_simple_type(type_): + return "NULLABLE" + else: + raise AirflowException("Not a supported type_: " + type_.cassname) diff --git a/reference/providers/google/cloud/transfers/facebook_ads_to_gcs.py b/reference/providers/google/cloud/transfers/facebook_ads_to_gcs.py new file mode 100644 index 0000000..f4fec21 --- /dev/null +++ b/reference/providers/google/cloud/transfers/facebook_ads_to_gcs.py @@ -0,0 +1,135 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Facebook Ad Reporting to GCS operators.""" +import csv +import tempfile +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.facebook.ads.hooks.ads import FacebookAdsReportingHook +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults + + +class FacebookAdsReportToGcsOperator(BaseOperator): + """ + Fetches the results from the Facebook Ads API as desired in the params + Converts and saves the data as a temporary JSON file + Uploads the JSON to Google Cloud Storage + + .. seealso:: + For more information on the Facebook Ads API, take a look at the API docs: + https://developers.facebook.com/docs/marketing-apis/ + + .. seealso:: + For more information on the Facebook Ads Python SDK, take a look at the docs: + https://github.com/facebook/facebook-python-business-sdk + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:FacebookAdsReportToGcsOperator` + + :param bucket_name: The GCS bucket to upload to + :type bucket_name: str + :param object_name: GCS path to save the object. Must be the full file path (ex. `path/to/file.txt`) + :type object_name: str + :param gcp_conn_id: Airflow Google Cloud connection ID + :type gcp_conn_id: str + :param facebook_conn_id: Airflow Facebook Ads connection ID + :type facebook_conn_id: str + :param api_version: The version of Facebook API. Default to v6.0 + :type api_version: str + :param fields: List of fields that is obtained from Facebook. Found in AdsInsights.Field class. + https://developers.facebook.com/docs/marketing-api/insights/parameters/v6.0 + :type fields: List[str] + :param params: Parameters that determine the query for Facebook + https://developers.facebook.com/docs/marketing-api/insights/parameters/v6.0 + :type params: Dict[str, Any] + :param gzip: Option to compress local file or file data for upload + :type gzip: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "facebook_conn_id", + "bucket_name", + "object_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + bucket_name: str, + object_name: str, + fields: List[str], + params: Dict[str, Any], + gzip: bool = False, + api_version: str = "v6.0", + gcp_conn_id: str = "google_cloud_default", + facebook_conn_id: str = "facebook_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.bucket_name = bucket_name + self.object_name = object_name + self.gcp_conn_id = gcp_conn_id + self.facebook_conn_id = facebook_conn_id + self.api_version = api_version + self.fields = fields + self.params = params + self.gzip = gzip + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + service = FacebookAdsReportingHook( + facebook_conn_id=self.facebook_conn_id, api_version=self.api_version + ) + rows = service.bulk_facebook_report(params=self.params, fields=self.fields) + + converted_rows = [dict(row) for row in rows] + self.log.info("Facebook Returned %s data points", len(converted_rows)) + + if converted_rows: + headers = converted_rows[0].keys() + with tempfile.NamedTemporaryFile("w", suffix=".csv") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=headers) + writer.writeheader() + writer.writerows(converted_rows) + csvfile.flush() + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + hook.upload( + bucket_name=self.bucket_name, + object_name=self.object_name, + filename=csvfile.name, + gzip=self.gzip, + ) + self.log.info("%s uploaded to GCS", csvfile.name) diff --git a/reference/providers/google/cloud/transfers/gcs_to_bigquery.py b/reference/providers/google/cloud/transfers/gcs_to_bigquery.py new file mode 100644 index 0000000..b7de65a --- /dev/null +++ b/reference/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -0,0 +1,358 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Storage to BigQuery operator.""" + +import json +from typing import Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults + + +# pylint: disable=too-many-instance-attributes +class GCSToBigQueryOperator(BaseOperator): + """ + Loads files from Google Cloud Storage into BigQuery. + + The schema to be used for the BigQuery table may be specified in one of + two ways. 
You may either directly pass the schema fields in, or you may + point the operator to a Google Cloud Storage object name. The object in + Google Cloud Storage must be a JSON file with the schema fields in it. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSToBigQueryOperator` + + :param bucket: The bucket to load from. (templated) + :type bucket: str + :param source_objects: List of Google Cloud Storage URIs to load from. (templated) + If source_format is 'DATASTORE_BACKUP', the list must only contain a single URI. + :type source_objects: list[str] + :param destination_project_dataset_table: The dotted + ``(<project>.|<project>:)<dataset>.<table>
`` BigQuery table to load data into. + If ```` is not included, project will be the project defined in + the connection json. (templated) + :type destination_project_dataset_table: str + :param schema_fields: If set, the schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.load + Should not be set when source_format is 'DATASTORE_BACKUP'. + Parameter must be defined if 'schema_object' is null and autodetect is False. + :type schema_fields: list + :param schema_object: If set, a GCS object path pointing to a .json file that + contains the schema for the table. (templated) + Parameter must be defined if 'schema_fields' is null and autodetect is False. + :type schema_object: str + :param source_format: File format to export. + :type source_format: str + :param compression: [Optional] The compression type of the data source. + Possible values include GZIP and NONE. + The default value is NONE. + This setting is ignored for Google Cloud Bigtable, + Google Cloud Datastore backups and Avro formats. + :type compression: str + :param create_disposition: The create disposition if the table doesn't exist. + :type create_disposition: str + :param skip_leading_rows: Number of rows to skip when loading from a CSV. + :type skip_leading_rows: int + :param write_disposition: The write disposition if the table already exists. + :type write_disposition: str + :param field_delimiter: The delimiter to use when loading from a CSV. + :type field_delimiter: str + :param max_bad_records: The maximum number of bad records that BigQuery can + ignore when running the job. + :type max_bad_records: int + :param quote_character: The value that is used to quote data sections in a CSV file. + :type quote_character: str + :param ignore_unknown_values: [Optional] Indicates if BigQuery should allow + extra values that are not represented in the table schema. + If true, the extra values are ignored. If false, records with extra columns + are treated as bad records, and if there are too many bad records, an + invalid error is returned in the job result. + :type ignore_unknown_values: bool + :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not (false). + :type allow_quoted_newlines: bool + :param allow_jagged_rows: Accept rows that are missing trailing optional columns. + The missing values are treated as nulls. If false, records with missing trailing + columns are treated as bad records, and if there are too many bad records, an + invalid error is returned in the job result. Only applicable to CSV, ignored + for other formats. + :type allow_jagged_rows: bool + :param encoding: The character encoding of the data. See: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).csvOptions.encoding + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#externalDataConfiguration.csvOptions.encoding + :param max_id_key: If set, the name of a column in the BigQuery table + that's to be loaded. This will be used to select the MAX value from + BigQuery after the load occurs. The results will be returned by the + execute() command, which in turn gets stored in XCom for future + operators to use. This can be helpful with incremental loads--during + future executions, you can pick up from the max ID. + :type max_id_key: str + :param bigquery_conn_id: (Optional) The connection ID used to connect to Google Cloud and + interact with the BigQuery service. 
+ :type bigquery_conn_id: str + :param google_cloud_storage_conn_id: (Optional) The connection ID used to connect to Google Cloud + and interact with the Google Cloud Storage service. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param schema_update_options: Allows the schema of the destination + table to be updated as a side effect of the load job. + :type schema_update_options: list + :param src_fmt_configs: configure optional fields specific to the source format + :type src_fmt_configs: dict + :param external_table: Flag to specify if the destination table should be + a BigQuery external table. Default Value is False. + :type external_table: bool + :param time_partitioning: configure optional time partitioning fields i.e. + partition by field, type and expiration as per API specifications. + Note that 'field' is not available in concurrency with + dataset.table$partition. + :type time_partitioning: dict + :param cluster_fields: Request that the result of this load be stored sorted + by one or more columns. BigQuery supports clustering for both partitioned and + non-partitioned tables. The order of columns given determines the sort order. + Not applicable for external tables. + :type cluster_fields: list[str] + :param autodetect: [Optional] Indicates if we should automatically infer the + options and schema for CSV and JSON sources. (Default: ``True``). + Parameter must be setted to True if 'schema_fields' and 'schema_object' are undefined. + It is suggested to set to True if table are create outside of Airflow. + :type autodetect: bool + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + :param location: [Optional] The geographic location of the job. Required except for US and EU. + See details at https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :type location: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + :param labels: [Optional] Labels for the BiqQuery table. + :type labels: dict + :param description: [Optional] Description for the BigQuery table. 
+ :type description: str + """ + + template_fields = ( + "bucket", + "source_objects", + "schema_object", + "destination_project_dataset_table", + "impersonation_chain", + ) + template_ext = (".sql",) + ui_color = "#f0eee4" + + # pylint: disable=too-many-locals,too-many-arguments + @apply_defaults + def __init__( + self, + *, + bucket, + source_objects, + destination_project_dataset_table, + schema_fields=None, + schema_object=None, + source_format="CSV", + compression="NONE", + create_disposition="CREATE_IF_NEEDED", + skip_leading_rows=0, + write_disposition="WRITE_EMPTY", + field_delimiter=",", + max_bad_records=0, + quote_character=None, + ignore_unknown_values=False, + allow_quoted_newlines=False, + allow_jagged_rows=False, + encoding="UTF-8", + max_id_key=None, + bigquery_conn_id="google_cloud_default", + google_cloud_storage_conn_id="google_cloud_default", + delegate_to=None, + schema_update_options=(), + src_fmt_configs=None, + external_table=False, + time_partitioning=None, + cluster_fields=None, + autodetect=True, + encryption_configuration=None, + location=None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + labels=None, + description=None, + **kwargs, + ): + + super().__init__(**kwargs) + + # GCS config + if src_fmt_configs is None: + src_fmt_configs = {} + if time_partitioning is None: + time_partitioning = {} + self.bucket = bucket + self.source_objects = source_objects + self.schema_object = schema_object + + # BQ config + self.destination_project_dataset_table = destination_project_dataset_table + self.schema_fields = schema_fields + self.source_format = source_format + self.compression = compression + self.create_disposition = create_disposition + self.skip_leading_rows = skip_leading_rows + self.write_disposition = write_disposition + self.field_delimiter = field_delimiter + self.max_bad_records = max_bad_records + self.quote_character = quote_character + self.ignore_unknown_values = ignore_unknown_values + self.allow_quoted_newlines = allow_quoted_newlines + self.allow_jagged_rows = allow_jagged_rows + self.external_table = external_table + self.encoding = encoding + + self.max_id_key = max_id_key + self.bigquery_conn_id = bigquery_conn_id + self.google_cloud_storage_conn_id = google_cloud_storage_conn_id + self.delegate_to = delegate_to + + self.schema_update_options = schema_update_options + self.src_fmt_configs = src_fmt_configs + self.time_partitioning = time_partitioning + self.cluster_fields = cluster_fields + self.autodetect = autodetect + self.encryption_configuration = encryption_configuration + self.location = location + self.impersonation_chain = impersonation_chain + + self.labels = labels + self.description = description + + def execute(self, context): + bq_hook = BigQueryHook( + bigquery_conn_id=self.bigquery_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + if not self.schema_fields: + if self.schema_object and self.source_format != "DATASTORE_BACKUP": + gcs_hook = GCSHook( + gcp_conn_id=self.google_cloud_storage_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + blob = gcs_hook.download( + bucket_name=self.bucket, + object_name=self.schema_object, + ) + schema_fields = json.loads(blob.decode("utf-8")) + elif self.schema_object is None and self.autodetect is False: + raise AirflowException( + "At least one of `schema_fields`, `schema_object`, or `autodetect` must be passed." 
+ ) + else: + schema_fields = None + + else: + schema_fields = self.schema_fields + + source_uris = [ + f"gs://{self.bucket}/{source_object}" + for source_object in self.source_objects + ] + conn = bq_hook.get_conn() + cursor = conn.cursor() + + if self.external_table: + cursor.create_external_table( + external_project_dataset_table=self.destination_project_dataset_table, + schema_fields=schema_fields, + source_uris=source_uris, + source_format=self.source_format, + compression=self.compression, + skip_leading_rows=self.skip_leading_rows, + field_delimiter=self.field_delimiter, + max_bad_records=self.max_bad_records, + quote_character=self.quote_character, + ignore_unknown_values=self.ignore_unknown_values, + allow_quoted_newlines=self.allow_quoted_newlines, + allow_jagged_rows=self.allow_jagged_rows, + encoding=self.encoding, + src_fmt_configs=self.src_fmt_configs, + encryption_configuration=self.encryption_configuration, + labels=self.labels, + description=self.description, + ) + else: + cursor.run_load( + destination_project_dataset_table=self.destination_project_dataset_table, + schema_fields=schema_fields, + source_uris=source_uris, + source_format=self.source_format, + autodetect=self.autodetect, + create_disposition=self.create_disposition, + skip_leading_rows=self.skip_leading_rows, + write_disposition=self.write_disposition, + field_delimiter=self.field_delimiter, + max_bad_records=self.max_bad_records, + quote_character=self.quote_character, + ignore_unknown_values=self.ignore_unknown_values, + allow_quoted_newlines=self.allow_quoted_newlines, + allow_jagged_rows=self.allow_jagged_rows, + encoding=self.encoding, + schema_update_options=self.schema_update_options, + src_fmt_configs=self.src_fmt_configs, + time_partitioning=self.time_partitioning, + cluster_fields=self.cluster_fields, + encryption_configuration=self.encryption_configuration, + labels=self.labels, + description=self.description, + ) + + if cursor.use_legacy_sql: + escaped_table_name = f"[{self.destination_project_dataset_table}]" + else: + escaped_table_name = f"`{self.destination_project_dataset_table}`" + + if self.max_id_key: + cursor.execute(f"SELECT MAX({self.max_id_key}) FROM {escaped_table_name}") + row = cursor.fetchone() + max_id = row[0] if row[0] else 0 + self.log.info( + "Loaded BQ data with max %s.%s=%s", + self.destination_project_dataset_table, + self.max_id_key, + max_id, + ) diff --git a/reference/providers/google/cloud/transfers/gcs_to_gcs.py b/reference/providers/google/cloud/transfers/gcs_to_gcs.py new file mode 100644 index 0000000..3d8bc66 --- /dev/null +++ b/reference/providers/google/cloud/transfers/gcs_to_gcs.py @@ -0,0 +1,475 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
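# Illustrative usage sketch (not part of the original commit): a minimal task
# built on the GCSToBigQueryOperator defined above. The bucket, object pattern
# and destination table below are hypothetical placeholders.
load_sales_to_bq = GCSToBigQueryOperator(
    task_id="load_sales_to_bq",
    bucket="example-bucket",                               # hypothetical bucket
    source_objects=["sales/2022/*.csv"],                   # hypothetical objects
    destination_project_dataset_table="example_ds.sales",  # hypothetical table
    source_format="CSV",
    skip_leading_rows=1,
    write_disposition="WRITE_TRUNCATE",
    autodetect=True,  # no schema_fields/schema_object supplied, so autodetect stays True
)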
+"""This module contains a Google Cloud Storage operator.""" +import warnings +from typing import Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults + +WILDCARD = "*" + + +class GCSToGCSOperator(BaseOperator): + """ + Copies objects from a bucket to another, with renaming if requested. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSToGCSOperator` + + :param source_bucket: The source Google Cloud Storage bucket where the + object is. (templated) + :type source_bucket: str + :param source_object: The source name of the object to copy in the Google cloud + storage bucket. (templated) + You can use only one wildcard for objects (filenames) within your + bucket. The wildcard can appear inside the object name or at the + end of the object name. Appending a wildcard to the bucket name is + unsupported. + :type source_object: str + :param source_objects: A list of source name of the objects to copy in the Google cloud + storage bucket. (templated) + :type source_objects: List[str] + :param destination_bucket: The destination Google Cloud Storage bucket + where the object should be. If the destination_bucket is None, it defaults + to source_bucket. (templated) + :type destination_bucket: str + :param destination_object: The destination name of the object in the + destination Google Cloud Storage bucket. (templated) + If a wildcard is supplied in the source_object argument, this is the + prefix that will be prepended to the final destination objects' paths. + Note that the source path's part before the wildcard will be removed; + if it needs to be retained it should be appended to destination_object. + For example, with prefix ``foo/*`` and destination_object ``blah/``, the + file ``foo/baz`` will be copied to ``blah/baz``; to retain the prefix write + the destination_object as e.g. ``blah/foo``, in which case the copied file + will be named ``blah/foo/baz``. + The same thing applies to source objects inside source_objects. + :type destination_object: str + :param move_object: When move object is True, the object is moved instead + of copied to the new location. This is the equivalent of a mv command + as opposed to a cp command. + :type move_object: bool + :param replace: Whether you want to replace existing destination files or not. + :type replace: bool + :param delimiter: This is used to restrict the result to only the 'files' in a given 'folder'. + If source_objects = ['foo/bah/'] and delimiter = '.avro', then only the 'files' in the + folder 'foo/bah/' with '.avro' delimiter will be copied to the destination object. + :type delimiter: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param last_modified_time: When specified, the objects will be copied or moved, + only if they were modified after last_modified_time. 
+ If tzinfo has not been set, UTC will be assumed. + :type last_modified_time: datetime.datetime + :param maximum_modified_time: When specified, the objects will be copied or moved, + only if they were modified before maximum_modified_time. + If tzinfo has not been set, UTC will be assumed. + :type maximum_modified_time: datetime.datetime + :param is_older_than: When specified, the objects will be copied if they are older + than the specified time in seconds. + :type is_older_than: int + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + :Example: + + The following Operator would copy a single file named + ``sales/sales-2017/january.avro`` in the ``data`` bucket to the file named + ``copied_sales/2017/january-backup.avro`` in the ``data_backup`` bucket :: + + copy_single_file = GCSToGCSOperator( + task_id='copy_single_file', + source_bucket='data', + source_objects=['sales/sales-2017/january.avro'], + destination_bucket='data_backup', + destination_object='copied_sales/2017/january-backup.avro', + gcp_conn_id=google_cloud_conn_id + ) + + The following Operator would copy all the Avro files from ``sales/sales-2017`` + folder (i.e. with names starting with that prefix) in ``data`` bucket to the + ``copied_sales/2017`` folder in the ``data_backup`` bucket. :: + + copy_files = GCSToGCSOperator( + task_id='copy_files', + source_bucket='data', + source_objects=['sales/sales-2017'], + destination_bucket='data_backup', + destination_object='copied_sales/2017/', + delimiter='.avro' + gcp_conn_id=google_cloud_conn_id + ) + + Or :: + + copy_files = GCSToGCSOperator( + task_id='copy_files', + source_bucket='data', + source_object='sales/sales-2017/*.avro', + destination_bucket='data_backup', + destination_object='copied_sales/2017/', + gcp_conn_id=google_cloud_conn_id + ) + + The following Operator would move all the Avro files from ``sales/sales-2017`` + folder (i.e. with names starting with that prefix) in ``data`` bucket to the + same folder in the ``data_backup`` bucket, deleting the original files in the + process. :: + + move_files = GCSToGCSOperator( + task_id='move_files', + source_bucket='data', + source_object='sales/sales-2017/*.avro', + destination_bucket='data_backup', + move_object=True, + gcp_conn_id=google_cloud_conn_id + ) + + The following Operator would move all the Avro files from ``sales/sales-2019`` + and ``sales/sales-2020` folder in ``data`` bucket to the same folder in the + ``data_backup`` bucket, deleting the original files in the process. 
:: + + move_files = GCSToGCSOperator( + task_id='move_files', + source_bucket='data', + source_objects=['sales/sales-2019/*.avro', 'sales/sales-2020'], + destination_bucket='data_backup', + delimiter='.avro', + move_object=True, + gcp_conn_id=google_cloud_conn_id + ) + + """ + + template_fields = ( + "source_bucket", + "source_object", + "source_objects", + "destination_bucket", + "destination_object", + "delimiter", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, # pylint: disable=too-many-arguments + source_bucket, + source_object=None, + source_objects=None, + destination_bucket=None, + destination_object=None, + delimiter=None, + move_object=False, + replace=True, + gcp_conn_id="google_cloud_default", + google_cloud_storage_conn_id=None, + delegate_to=None, + last_modified_time=None, + maximum_modified_time=None, + is_older_than=None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + self.source_bucket = source_bucket + self.source_object = source_object + self.source_objects = source_objects + self.destination_bucket = destination_bucket + self.destination_object = destination_object + self.delimiter = delimiter + self.move_object = move_object + self.replace = replace + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.last_modified_time = last_modified_time + self.maximum_modified_time = maximum_modified_time + self.is_older_than = is_older_than + self.impersonation_chain = impersonation_chain + + def execute(self, context): + + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + if self.source_objects and self.source_object: + error_msg = ( + "You can either set source_object parameter or source_objects " + "parameter but not both. Found source_object={} and" + " source_objects={}".format(self.source_object, self.source_objects) + ) + raise AirflowException(error_msg) + + if not self.source_object and not self.source_objects: + error_msg = "You must set source_object parameter or source_objects parameter. None set" + raise AirflowException(error_msg) + + if self.source_objects and not all( + isinstance(item, str) for item in self.source_objects + ): + raise AirflowException( + "At least, one of the `objects` in the `source_objects` is not a string" + ) + + # If source_object is set, default it to source_objects + if self.source_object: + self.source_objects = [self.source_object] + + if self.destination_bucket is None: + self.log.warning( + "destination_bucket is None. 
Defaulting it to source_bucket (%s)", + self.source_bucket, + ) + self.destination_bucket = self.source_bucket + + # An empty source_object means to copy all files + if len(self.source_objects) == 0: + self.source_objects = [""] + # Raise exception if empty string `''` is used twice in source_object, this is to avoid double copy + if self.source_objects.count("") > 1: + raise AirflowException( + "You can't have two empty strings inside source_object" + ) + + # Iterate over the source_objects and do the copy + for prefix in self.source_objects: + # Check if prefix contains wildcard + if WILDCARD in prefix: + self._copy_source_with_wildcard(hook=hook, prefix=prefix) + # Now search with prefix using provided delimiter if any + else: + self._copy_source_without_wildcard(hook=hook, prefix=prefix) + + def _copy_source_without_wildcard(self, hook, prefix): + """ + + + For source_objects with no wildcard, this operator would first list + all files in source_objects, using provided delimiter if any. Then copy + files from source_objects to destination_object and rename each source + file. + + Example 1: + + + The following Operator would copy all the files from ``a/``folder + (i.e a/a.csv, a/b.csv, a/c.csv)in ``data`` bucket to the ``b/`` folder in + the ``data_backup`` bucket (b/a.csv, b/b.csv, b/c.csv) :: + + copy_files = GCSToGCSOperator( + task_id='copy_files_without_wildcard', + source_bucket='data', + source_objects=['a/'], + destination_bucket='data_backup', + destination_object='b/', + gcp_conn_id=google_cloud_conn_id + ) + + Example 2: + + + The following Operator would copy all avro files from ``a/``folder + (i.e a/a.avro, a/b.avro, a/c.avro)in ``data`` bucket to the ``b/`` folder in + the ``data_backup`` bucket (b/a.avro, b/b.avro, b/c.avro) :: + + copy_files = GCSToGCSOperator( + task_id='copy_files_without_wildcard', + source_bucket='data', + source_objects=['a/'], + destination_bucket='data_backup', + destination_object='b/', + delimiter='.avro', + gcp_conn_id=google_cloud_conn_id + ) + """ + objects = hook.list(self.source_bucket, prefix=prefix, delimiter=self.delimiter) + + # If objects is empty and we have prefix, let's check if prefix is a blob + # and copy directly + if len(objects) == 0 and prefix: + if hook.exists(self.source_bucket, prefix): + self._copy_single_object( + hook=hook, + source_object=prefix, + destination_object=self.destination_object, + ) + for source_obj in objects: + if self.destination_object is None: + destination_object = source_obj + else: + destination_object = source_obj.replace( + prefix, self.destination_object, 1 + ) + self._copy_single_object( + hook=hook, + source_object=source_obj, + destination_object=destination_object, + ) + + def _copy_source_with_wildcard(self, hook, prefix): + total_wildcards = prefix.count(WILDCARD) + if total_wildcards > 1: + error_msg = ( + "Only one wildcard '*' is allowed in source_object parameter. 
" + "Found {} in {}.".format(total_wildcards, prefix) + ) + + raise AirflowException(error_msg) + self.log.info("Delimiter ignored because wildcard is in prefix") + prefix_, delimiter = prefix.split(WILDCARD, 1) + objects = hook.list(self.source_bucket, prefix=prefix_, delimiter=delimiter) + if not self.replace: + # If we are not replacing, list all files in the Destination GCS bucket + # and only keep those files which are present in + # Source GCS bucket and not in Destination GCS bucket + + existing_objects = hook.list( + self.destination_bucket, prefix=prefix_, delimiter=delimiter + ) + + objects = set(objects) - set(existing_objects) + if len(objects) > 0: + self.log.info( + "%s files are going to be synced: %s.", len(objects), objects + ) + else: + self.log.info("There are no new files to sync. Have a nice day!") + for source_object in objects: + if self.destination_object is None: + destination_object = source_object + else: + destination_object = source_object.replace( + prefix_, self.destination_object, 1 + ) + + self._copy_single_object( + hook=hook, + source_object=source_object, + destination_object=destination_object, + ) + + def _copy_single_object(self, hook, source_object, destination_object): + if self.is_older_than: + # Here we check if the given object is older than the given time + # If given, last_modified_time and maximum_modified_time is ignored + if hook.is_older_than( + self.source_bucket, source_object, self.is_older_than + ): + self.log.info("Object is older than %s seconds ago", self.is_older_than) + else: + self.log.debug( + "Object is not older than %s seconds ago", self.is_older_than + ) + return + elif self.last_modified_time and self.maximum_modified_time: + # check to see if object was modified between last_modified_time and + # maximum_modified_time + if hook.is_updated_between( + self.source_bucket, + source_object, + self.last_modified_time, + self.maximum_modified_time, + ): + self.log.info( + "Object has been modified between %s and %s", + self.last_modified_time, + self.maximum_modified_time, + ) + else: + self.log.debug( + "Object was not modified between %s and %s", + self.last_modified_time, + self.maximum_modified_time, + ) + return + + elif self.last_modified_time is not None: + # Check to see if object was modified after last_modified_time + if hook.is_updated_after( + self.source_bucket, source_object, self.last_modified_time + ): + self.log.info( + "Object has been modified after %s ", self.last_modified_time + ) + else: + self.log.debug( + "Object was not modified after %s ", self.last_modified_time + ) + return + elif self.maximum_modified_time is not None: + # Check to see if object was modified before maximum_modified_time + if hook.is_updated_before( + self.source_bucket, source_object, self.maximum_modified_time + ): + self.log.info( + "Object has been modified before %s ", self.maximum_modified_time + ) + else: + self.log.debug( + "Object was not modified before %s ", self.maximum_modified_time + ) + return + + self.log.info( + "Executing copy of gs://%s/%s to gs://%s/%s", + self.source_bucket, + source_object, + self.destination_bucket, + destination_object, + ) + + hook.rewrite( + self.source_bucket, + source_object, + self.destination_bucket, + destination_object, + ) + + if self.move_object: + hook.delete(self.source_bucket, source_object) diff --git a/reference/providers/google/cloud/transfers/gcs_to_local.py b/reference/providers/google/cloud/transfers/gcs_to_local.py new file mode 100644 index 0000000..becc4a4 --- /dev/null +++ 
b/reference/providers/google/cloud/transfers/gcs_to_local.py @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +import warnings +from typing import Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.models.xcom import MAX_XCOM_SIZE +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.sensors.base import apply_defaults + + +class GCSToLocalFilesystemOperator(BaseOperator): + """ + Downloads a file from Google Cloud Storage. + + If a filename is supplied, it writes the file to the specified location, alternatively one can + set the ``store_to_xcom_key`` parameter to True push the file content into xcom. When the file size + exceeds the maximum size for xcom it is recommended to write to a file. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSToLocalFilesystemOperator` + + :param bucket: The Google Cloud Storage bucket where the object is. + Must not contain 'gs://' prefix. (templated) + :type bucket: str + :param object_name: The name of the object to download in the Google cloud + storage bucket. (templated) + :type object_name: str + :param filename: The file path, including filename, on the local file system (where the + operator is being executed) that the file should be downloaded to. (templated) + If no filename passed, the downloaded data will not be stored on the local file + system. + :type filename: str + :param store_to_xcom_key: If this param is set, the operator will push + the contents of the downloaded file to XCom with the key set in this + parameter. If not set, the downloaded data will not be pushed to XCom. (templated) + :type store_to_xcom_key: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "bucket", + "object_name", + "filename", + "store_to_xcom_key", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, + bucket: str, + object_name: Optional[str] = None, + filename: Optional[str] = None, + store_to_xcom_key: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + # To preserve backward compatibility + # TODO: Remove one day + if object_name is None: + if "object" in kwargs: + object_name = kwargs["object"] + DeprecationWarning("Use 'object_name' instead of 'object'.") + else: + TypeError( + "__init__() missing 1 required positional argument: 'object_name'" + ) + + if filename is not None and store_to_xcom_key is not None: + raise ValueError("Either filename or store_to_xcom_key can be set") + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + super().__init__(**kwargs) + self.bucket = bucket + self.object = object_name + self.filename = filename # noqa + self.store_to_xcom_key = store_to_xcom_key # noqa + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context): + self.log.info( + "Executing download: %s, %s, %s", self.bucket, self.object, self.filename + ) + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + if self.store_to_xcom_key: + file_bytes = hook.download(bucket_name=self.bucket, object_name=self.object) + if sys.getsizeof(file_bytes) < MAX_XCOM_SIZE: + context["ti"].xcom_push( + key=self.store_to_xcom_key, value=str(file_bytes) + ) + else: + raise AirflowException( + "The size of the downloaded file is too large to push to XCom!" + ) + else: + hook.download( + bucket_name=self.bucket, object_name=self.object, filename=self.filename + ) diff --git a/reference/providers/google/cloud/transfers/gcs_to_sftp.py b/reference/providers/google/cloud/transfers/gcs_to_sftp.py new file mode 100644 index 0000000..39aaa43 --- /dev/null +++ b/reference/providers/google/cloud/transfers/gcs_to_sftp.py @@ -0,0 +1,226 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
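# Illustrative usage sketch (not part of the original commit): downloading one
# object to the worker's filesystem with the GCSToLocalFilesystemOperator
# defined above. Bucket, object and local path are hypothetical placeholders.
download_report = GCSToLocalFilesystemOperator(
    task_id="download_report",
    bucket="example-bucket",                # hypothetical bucket
    object_name="reports/2022/report.csv",  # hypothetical object
    filename="/tmp/report.csv",             # local path on the Airflow worker
)
# To push small files to XCom instead, drop `filename` and pass
# store_to_xcom_key="report_bytes" (the two options are mutually exclusive).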
See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Cloud Storage to SFTP operator.""" +import os +from tempfile import NamedTemporaryFile +from typing import Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.sftp.hooks.sftp import SFTPHook +from airflow.utils.decorators import apply_defaults + +WILDCARD = "*" + + +class GCSToSFTPOperator(BaseOperator): + """ + Transfer files from a Google Cloud Storage bucket to SFTP server. + + **Example**: :: + + with models.DAG( + "example_gcs_to_sftp", + start_date=datetime(2020, 6, 19), + schedule_interval=None, + ) as dag: + # downloads file to /tmp/sftp/folder/subfolder/file.txt + copy_file_from_gcs_to_sftp = GCSToSFTPOperator( + task_id="file-copy-gsc-to-sftp", + source_bucket="test-gcs-sftp-bucket-name", + source_object="folder/subfolder/file.txt", + destination_path="/tmp/sftp", + ) + + # moves file to /tmp/data.txt + move_file_from_gcs_to_sftp = GCSToSFTPOperator( + task_id="file-move-gsc-to-sftp", + source_bucket="test-gcs-sftp-bucket-name", + source_object="folder/subfolder/data.txt", + destination_path="/tmp", + move_object=True, + keep_directory_structure=False, + ) + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSToSFTPOperator` + + :param source_bucket: The source Google Cloud Storage bucket where the + object is. (templated) + :type source_bucket: str + :param source_object: The source name of the object to copy in the Google cloud + storage bucket. (templated) + You can use only one wildcard for objects (filenames) within your + bucket. The wildcard can appear inside the object name or at the + end of the object name. Appending a wildcard to the bucket name is + unsupported. + :type source_object: str + :param destination_path: The sftp remote path. This is the specified directory path for + uploading to the SFTP server. + :type destination_path: str + :param keep_directory_structure: (Optional) When set to False the path of the file + on the bucket is recreated within path passed in destination_path. + :type keep_directory_structure: bool + :param move_object: When move object is True, the object is moved instead + of copied to the new location. This is the equivalent of a mv command + as opposed to a cp command. + :type move_object: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param sftp_conn_id: The sftp connection id. The name or identifier for + establishing a connection to the SFTP server. + :type sftp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "source_bucket", + "source_object", + "destination_path", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + source_bucket: str, + source_object: str, + destination_path: str, + keep_directory_structure: bool = True, + move_object: bool = False, + gcp_conn_id: str = "google_cloud_default", + sftp_conn_id: str = "ssh_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.source_bucket = source_bucket + self.source_object = source_object + self.destination_path = destination_path + self.keep_directory_structure = keep_directory_structure + self.move_object = move_object + self.gcp_conn_id = gcp_conn_id + self.sftp_conn_id = sftp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.sftp_dirs = None + + def execute(self, context): + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + sftp_hook = SFTPHook(self.sftp_conn_id) + + if WILDCARD in self.source_object: + total_wildcards = self.source_object.count(WILDCARD) + if total_wildcards > 1: + raise AirflowException( + "Only one wildcard '*' is allowed in source_object parameter. " + "Found {} in {}.".format(total_wildcards, self.source_object) + ) + + prefix, delimiter = self.source_object.split(WILDCARD, 1) + prefix_dirname = os.path.dirname(prefix) + + objects = gcs_hook.list( + self.source_bucket, prefix=prefix, delimiter=delimiter + ) + + for source_object in objects: + destination_path = self._resolve_destination_path( + source_object, prefix=prefix_dirname + ) + self._copy_single_object( + gcs_hook, sftp_hook, source_object, destination_path + ) + + self.log.info( + "Done. Uploaded '%d' files to %s", len(objects), self.destination_path + ) + else: + destination_path = self._resolve_destination_path(self.source_object) + self._copy_single_object( + gcs_hook, sftp_hook, self.source_object, destination_path + ) + self.log.info( + "Done. 
Uploaded '%s' file to %s", self.source_object, destination_path + ) + + def _resolve_destination_path( + self, source_object: str, prefix: Optional[str] = None + ) -> str: + if not self.keep_directory_structure: + if prefix: + source_object = os.path.relpath(source_object, start=prefix) + else: + source_object = os.path.basename(source_object) + return os.path.join(self.destination_path, source_object) + + def _copy_single_object( + self, + gcs_hook: GCSHook, + sftp_hook: SFTPHook, + source_object: str, + destination_path: str, + ) -> None: + """Helper function to copy single object.""" + self.log.info( + "Executing copy of gs://%s/%s to %s", + self.source_bucket, + source_object, + destination_path, + ) + + dir_path = os.path.dirname(destination_path) + sftp_hook.create_directory(dir_path) + + with NamedTemporaryFile("w") as tmp: + gcs_hook.download( + bucket_name=self.source_bucket, + object_name=source_object, + filename=tmp.name, + ) + sftp_hook.store_file(destination_path, tmp.name) + + if self.move_object: + self.log.info( + "Executing delete of gs://%s/%s", self.source_bucket, source_object + ) + gcs_hook.delete(self.source_bucket, source_object) diff --git a/reference/providers/google/cloud/transfers/gdrive_to_gcs.py b/reference/providers/google/cloud/transfers/gdrive_to_gcs.py new file mode 100644 index 0000000..48d3b51 --- /dev/null +++ b/reference/providers/google/cloud/transfers/gdrive_to_gcs.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import warnings +from typing import Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.suite.hooks.drive import GoogleDriveHook +from airflow.utils.decorators import apply_defaults + + +class GoogleDriveToGCSOperator(BaseOperator): + """ + Writes a Google Drive file into Google Cloud Storage. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDriveToGCSOperator` + + :param bucket_name: The destination Google cloud storage bucket where the + file should be written to + :type bucket_name: str + :param object_name: The Google Cloud Storage object name for the object created by the operator. + For example: ``path/to/my/file/file.txt``. + :type object_name: str + :param destination_bucket: Same as bucket_name, but for backward compatibly + :type destination_bucket: str + :param destination_object: Same as object_name, but for backward compatibly + :type destination_object: str + :param folder_id: The folder id of the folder in which the Google Drive file resides + :type folder_id: str + :param file_name: The name of the file residing in Google Drive + :type file_name: str + :param drive_id: Optional. 
The id of the shared Google Drive in which the file resides. + :type drive_id: str + :param gcp_conn_id: The GCP connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "bucket_name", + "object_name", + "destination_bucket", + "destination_object", + "folder_id", + "file_name", + "drive_id", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + bucket_name: Optional[str] = None, + object_name: Optional[str] = None, + destination_bucket: Optional[str] = None, # deprecated + destination_object: Optional[str] = None, # deprecated + file_name: str, + folder_id: str, + drive_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.bucket_name = destination_bucket or bucket_name + if destination_bucket: + warnings.warn( + "`destination_bucket` is deprecated please use `bucket_name`", + DeprecationWarning, + ) + self.object_name = destination_object or object_name + if destination_object: + warnings.warn( + "`destination_object` is deprecated please use `object_name`", + DeprecationWarning, + ) + self.folder_id = folder_id + self.drive_id = drive_id + self.file_name = file_name + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context): + gdrive_hook = GoogleDriveHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + file_metadata = gdrive_hook.get_file_id( + folder_id=self.folder_id, file_name=self.file_name, drive_id=self.drive_id + ) + with gcs_hook.provide_file_and_upload( + bucket_name=self.bucket_name, object_name=self.object_name + ) as file: + gdrive_hook.download_file(file_id=file_metadata["id"], file_handle=file) diff --git a/reference/providers/google/cloud/transfers/gdrive_to_local.py b/reference/providers/google/cloud/transfers/gdrive_to_local.py new file mode 100644 index 0000000..658cd1f --- /dev/null +++ b/reference/providers/google/cloud/transfers/gdrive_to_local.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
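# Illustrative usage sketch (not part of the original commit): copying a Drive
# file into GCS with the GoogleDriveToGCSOperator defined above. The folder id,
# file name, bucket and object name are hypothetical placeholders.
drive_to_gcs = GoogleDriveToGCSOperator(
    task_id="drive_to_gcs",
    folder_id="1AbCdEfGhIjKlMnOpQrStUv",         # hypothetical Drive folder id
    file_name="quarterly_numbers.xlsx",          # hypothetical Drive file name
    bucket_name="example-bucket",                # hypothetical destination bucket
    object_name="drive/quarterly_numbers.xlsx",  # destination object path
)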
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.suite.hooks.drive import GoogleDriveHook +from airflow.utils.decorators import apply_defaults + + +class GoogleDriveToLocalOperator(BaseOperator): + """ + Writes a Google Drive file into local Storage. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDriveToLocalOperator` + + :param output_file: Path to downloaded file + :type output_file: str + :param folder_id: The folder id of the folder in which the Google Drive file resides + :type folder_id: str + :param file_name: The name of the file residing in Google Drive + :type file_name: str + :param drive_id: Optional. The id of the shared Google Drive in which the file resides. + :type drive_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "output_file", + "folder_id", + "file_name", + "drive_id", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + output_file: str, + file_name: str, + folder_id: str, + drive_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.output_file = output_file + self.folder_id = folder_id + self.drive_id = drive_id + self.file_name = file_name + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context): + self.log.info( + "Executing download: %s into %s", self.file_name, self.output_file + ) + gdrive_hook = GoogleDriveHook( + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + file_metadata = gdrive_hook.get_file_id( + folder_id=self.folder_id, file_name=self.file_name, drive_id=self.drive_id + ) + + with open(self.output_file, "wb") as file: + gdrive_hook.download_file(file_id=file_metadata["id"], file_handle=file) diff --git a/reference/providers/google/cloud/transfers/local_to_gcs.py b/reference/providers/google/cloud/transfers/local_to_gcs.py new file mode 100644 index 0000000..e33f2fd --- /dev/null +++ b/reference/providers/google/cloud/transfers/local_to_gcs.py @@ -0,0 +1,141 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains operator for uploading local file(s) to GCS.""" +import os +import warnings +from glob import glob +from typing import Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults + + +class LocalFilesystemToGCSOperator(BaseOperator): + """ + Uploads a file or list of files to Google Cloud Storage. + Optionally can compress the file for upload. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:LocalFilesystemToGCSOperator` + + :param src: Path to the local file, or list of local files. Path can be either absolute + (e.g. /path/to/file.ext) or relative (e.g. ../../foo/*/*.csv). (templated) + :type src: str or list + :param dst: Destination path within the specified bucket on GCS (e.g. /path/to/file.ext). + If multiple files are being uploaded, specify object prefix with trailing backslash + (e.g. /path/to/directory/) (templated) + :type dst: str + :param bucket: The bucket to upload to. (templated) + :type bucket: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. 
+ :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param mime_type: The mime-type string + :type mime_type: str + :param delegate_to: The account to impersonate, if any + :type delegate_to: str + :param gzip: Allows for file to be compressed and uploaded as gzip + :type gzip: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "src", + "dst", + "bucket", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + src, + dst, + bucket, + gcp_conn_id="google_cloud_default", + google_cloud_storage_conn_id=None, + mime_type="application/octet-stream", + delegate_to=None, + gzip=False, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + self.src = src + self.dst = dst + self.bucket = bucket + self.gcp_conn_id = gcp_conn_id + self.mime_type = mime_type + self.delegate_to = delegate_to + self.gzip = gzip + self.impersonation_chain = impersonation_chain + + def execute(self, context): + """Uploads a file or list of files to Google Cloud Storage""" + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + filepaths = self.src if isinstance(self.src, list) else glob(self.src) + if os.path.basename(self.dst): # path to a file + if len(filepaths) > 1: # multiple file upload + raise ValueError( + "'dst' parameter references filepath. Please specify " + "directory (with trailing backslash) to upload multiple " + "files. e.g. /path/to/directory/" + ) + object_paths = [self.dst] + else: # directory is provided + object_paths = [ + os.path.join(self.dst, os.path.basename(filepath)) + for filepath in filepaths + ] + + for filepath, object_path in zip(filepaths, object_paths): + hook.upload( + bucket_name=self.bucket, + object_name=object_path, + mime_type=self.mime_type, + filename=filepath, + gzip=self.gzip, + ) diff --git a/reference/providers/google/cloud/transfers/mssql_to_gcs.py b/reference/providers/google/cloud/transfers/mssql_to_gcs.py new file mode 100644 index 0000000..90ff7aa --- /dev/null +++ b/reference/providers/google/cloud/transfers/mssql_to_gcs.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
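# Illustrative usage sketch (not part of the original commit): uploading local
# files with the LocalFilesystemToGCSOperator defined above. Because `dst` ends
# with a slash, each file matched by the glob keeps its basename under that
# prefix. Paths and bucket name are hypothetical placeholders.
upload_exports = LocalFilesystemToGCSOperator(
    task_id="upload_exports",
    src="/tmp/exports/*.csv",   # glob pattern; may match several files
    dst="exports/2022/",        # trailing slash -> treated as a directory prefix
    bucket="example-bucket",    # hypothetical bucket
    gzip=True,                  # compress each file before upload
)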
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""MsSQL to GCS operator.""" + +import decimal +from typing import Dict + +from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator +from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook +from airflow.utils.decorators import apply_defaults + + +class MSSQLToGCSOperator(BaseSQLToGCSOperator): + """Copy data from Microsoft SQL Server to Google Cloud Storage + in JSON or CSV format. + + :param mssql_conn_id: Reference to a specific MSSQL hook. + :type mssql_conn_id: str + + **Example**: + The following operator will export data from the Customers table + within the given MSSQL Database and then upload it to the + 'mssql-export' GCS bucket (along with a schema file). :: + + export_customers = MsSqlToGoogleCloudStorageOperator( + task_id='export_customers', + sql='SELECT * FROM dbo.Customers;', + bucket='mssql-export', + filename='data/customers/export.json', + schema_filename='schemas/export.json', + mssql_conn_id='mssql_default', + google_cloud_storage_conn_id='google_cloud_default', + dag=dag + ) + """ + + ui_color = "#e0a98c" + + type_map = {3: "INTEGER", 4: "TIMESTAMP", 5: "NUMERIC"} + + @apply_defaults + def __init__(self, *, mssql_conn_id="mssql_default", **kwargs): + super().__init__(**kwargs) + self.mssql_conn_id = mssql_conn_id + + def query(self): + """ + Queries MSSQL and returns a cursor of results. + + :return: mssql cursor + """ + mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id) + conn = mssql.get_conn() + cursor = conn.cursor() + cursor.execute(self.sql) + return cursor + + def field_to_bigquery(self, field) -> Dict[str, str]: + return { + "name": field[0].replace(" ", "_"), + "type": self.type_map.get(field[1], "STRING"), + "mode": "NULLABLE", + } + + @classmethod + def convert_type(cls, value, schema_type): + """ + Takes a value from MSSQL, and converts it to a value that's safe for + JSON/Google Cloud Storage/BigQuery. + """ + if isinstance(value, decimal.Decimal): + return float(value) + return value diff --git a/reference/providers/google/cloud/transfers/mysql_to_gcs.py b/reference/providers/google/cloud/transfers/mysql_to_gcs.py new file mode 100644 index 0000000..c98e7cd --- /dev/null +++ b/reference/providers/google/cloud/transfers/mysql_to_gcs.py @@ -0,0 +1,136 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
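# Illustrative usage sketch (not part of the original commit): the export from
# the docstring above rewritten against the current class name
# MSSQLToGCSOperator (the docstring example still uses the legacy
# MsSqlToGoogleCloudStorageOperator name). Bucket and paths are hypothetical.
export_customers = MSSQLToGCSOperator(
    task_id="export_customers",
    sql="SELECT * FROM dbo.Customers;",
    bucket="mssql-export",                  # hypothetical bucket
    filename="data/customers/export.json",  # destination object for the data
    schema_filename="schemas/export.json",  # destination object for the schema
    mssql_conn_id="mssql_default",
)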
See the License for the +# specific language governing permissions and limitations +# under the License. +"""MySQL to GCS operator.""" + +import base64 +import calendar +from datetime import date, datetime, timedelta +from decimal import Decimal +from typing import Dict + +from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.utils.decorators import apply_defaults +from MySQLdb.constants import FIELD_TYPE + + +class MySQLToGCSOperator(BaseSQLToGCSOperator): + """Copy data from MySQL to Google Cloud Storage in JSON or CSV format. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MySQLToGCSOperator` + + :param mysql_conn_id: Reference to a specific MySQL hook. + :type mysql_conn_id: str + :param ensure_utc: Ensure TIMESTAMP columns exported as UTC. If set to + `False`, TIMESTAMP columns will be exported using the MySQL server's + default timezone. + :type ensure_utc: bool + """ + + ui_color = "#a0e08c" + + type_map = { + FIELD_TYPE.BIT: "INTEGER", + FIELD_TYPE.DATETIME: "TIMESTAMP", + FIELD_TYPE.DATE: "TIMESTAMP", + FIELD_TYPE.DECIMAL: "FLOAT", + FIELD_TYPE.NEWDECIMAL: "FLOAT", + FIELD_TYPE.DOUBLE: "FLOAT", + FIELD_TYPE.FLOAT: "FLOAT", + FIELD_TYPE.INT24: "INTEGER", + FIELD_TYPE.LONG: "INTEGER", + FIELD_TYPE.LONGLONG: "INTEGER", + FIELD_TYPE.SHORT: "INTEGER", + FIELD_TYPE.TIME: "TIME", + FIELD_TYPE.TIMESTAMP: "TIMESTAMP", + FIELD_TYPE.TINY: "INTEGER", + FIELD_TYPE.YEAR: "INTEGER", + } + + @apply_defaults + def __init__(self, *, mysql_conn_id="mysql_default", ensure_utc=False, **kwargs): + super().__init__(**kwargs) + self.mysql_conn_id = mysql_conn_id + self.ensure_utc = ensure_utc + + def query(self): + """Queries mysql and returns a cursor to the results.""" + mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) + conn = mysql.get_conn() + cursor = conn.cursor() + if self.ensure_utc: + # Ensure TIMESTAMP results are in UTC + tz_query = "SET time_zone = '+00:00'" + self.log.info("Executing: %s", tz_query) + cursor.execute(tz_query) + self.log.info("Executing: %s", self.sql) + cursor.execute(self.sql) + return cursor + + def field_to_bigquery(self, field) -> Dict[str, str]: + field_type = self.type_map.get(field[1], "STRING") + # Always allow TIMESTAMP to be nullable. MySQLdb returns None types + # for required fields because some MySQL timestamps can't be + # represented by Python's datetime (e.g. 0000-00-00 00:00:00). + field_mode = "NULLABLE" if field[6] or field_type == "TIMESTAMP" else "REQUIRED" + return { + "name": field[0], + "type": field_type, + "mode": field_mode, + } + + def convert_type(self, value, schema_type: str): + """ + Takes a value from MySQLdb, and converts it to a value that's safe for + JSON/Google Cloud Storage/BigQuery. + + * Datetimes are converted to UTC seconds. + * Decimals are converted to floats. + * Dates are converted to ISO formatted string if given schema_type is + DATE, or UTC seconds otherwise. + * Binary type fields are converted to integer if given schema_type is + INTEGER, or encoded with base64 otherwise. 
Imported BYTES data must + be base64-encoded according to BigQuery documentation: + https://cloud.google.com/bigquery/data-types + + :param value: MySQLdb column value + :type value: Any + :param schema_type: BigQuery data type + :type schema_type: str + """ + if value is None: + return value + if isinstance(value, datetime): + value = calendar.timegm(value.timetuple()) + elif isinstance(value, timedelta): + value = value.total_seconds() + elif isinstance(value, Decimal): + value = float(value) + elif isinstance(value, date): + if schema_type == "DATE": + value = value.isoformat() + else: + value = calendar.timegm(value.timetuple()) + elif isinstance(value, bytes): + if schema_type == "INTEGER": + value = int.from_bytes(value, "big") + else: + value = base64.standard_b64encode(value).decode("ascii") + return value diff --git a/reference/providers/google/cloud/transfers/oracle_to_gcs.py b/reference/providers/google/cloud/transfers/oracle_to_gcs.py new file mode 100644 index 0000000..80f4395 --- /dev/null +++ b/reference/providers/google/cloud/transfers/oracle_to_gcs.py @@ -0,0 +1,129 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=c-extension-no-member +import base64 +import calendar +from datetime import date, datetime, timedelta +from decimal import Decimal +from typing import Dict + +import cx_Oracle +from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator +from airflow.providers.oracle.hooks.oracle import OracleHook +from airflow.utils.decorators import apply_defaults + + +class OracleToGCSOperator(BaseSQLToGCSOperator): + """Copy data from Oracle to Google Cloud Storage in JSON or CSV format. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:OracleToGCSOperator` + + :param oracle_conn_id: Reference to a specific Oracle hook. + :type oracle_conn_id: str + :param ensure_utc: Ensure TIMESTAMP columns exported as UTC. If set to + `False`, TIMESTAMP columns will be exported using the Oracle server's + default timezone. 
+ :type ensure_utc: bool + """ + + ui_color = "#a0e08c" + + type_map = { + cx_Oracle.DB_TYPE_BINARY_DOUBLE: "DECIMAL", + cx_Oracle.DB_TYPE_BINARY_FLOAT: "DECIMAL", + cx_Oracle.DB_TYPE_BINARY_INTEGER: "INTEGER", + cx_Oracle.DB_TYPE_BOOLEAN: "BOOLEAN", + cx_Oracle.DB_TYPE_DATE: "TIMESTAMP", + cx_Oracle.DB_TYPE_NUMBER: "NUMERIC", + cx_Oracle.DB_TYPE_TIMESTAMP: "TIMESTAMP", + cx_Oracle.DB_TYPE_TIMESTAMP_LTZ: "TIMESTAMP", + cx_Oracle.DB_TYPE_TIMESTAMP_TZ: "TIMESTAMP", + } + + @apply_defaults + def __init__(self, *, oracle_conn_id="oracle_default", ensure_utc=False, **kwargs): + super().__init__(**kwargs) + self.ensure_utc = ensure_utc + self.oracle_conn_id = oracle_conn_id + + def query(self): + """Queries Oracle and returns a cursor to the results.""" + oracle = OracleHook(oracle_conn_id=self.oracle_conn_id) + conn = oracle.get_conn() + cursor = conn.cursor() + if self.ensure_utc: + # Ensure TIMESTAMP results are in UTC + tz_query = "SET time_zone = '+00:00'" + self.log.info("Executing: %s", tz_query) + cursor.execute(tz_query) + self.log.info("Executing: %s", self.sql) + cursor.execute(self.sql) + return cursor + + def field_to_bigquery(self, field) -> Dict[str, str]: + field_type = self.type_map.get(field[1], "STRING") + + field_mode = ( + "NULLABLE" if not field[6] or field_type == "TIMESTAMP" else "REQUIRED" + ) + return { + "name": field[0], + "type": field_type, + "mode": field_mode, + } + + def convert_type(self, value, schema_type): + """ + Takes a value from Oracle db, and converts it to a value that's safe for + JSON/Google Cloud Storage/BigQuery. + + * Datetimes are converted to UTC seconds. + * Decimals are converted to floats. + * Dates are converted to ISO formatted string if given schema_type is + DATE, or UTC seconds otherwise. + * Binary type fields are converted to integer if given schema_type is + INTEGER, or encoded with base64 otherwise. Imported BYTES data must + be base64-encoded according to BigQuery documentation: + https://cloud.google.com/bigquery/data-types + + :param value: Oracle db column value + :type value: Any + :param schema_type: BigQuery data type + :type schema_type: str + """ + if value is None: + return value + if isinstance(value, datetime): + value = calendar.timegm(value.timetuple()) + elif isinstance(value, timedelta): + value = value.total_seconds() + elif isinstance(value, Decimal): + value = float(value) + elif isinstance(value, date): + if schema_type == "DATE": + value = value.isoformat() + else: + value = calendar.timegm(value.timetuple()) + elif isinstance(value, bytes): + if schema_type == "INTEGER": + value = int.from_bytes(value, "big") + else: + value = base64.standard_b64encode(value).decode("ascii") + return value diff --git a/reference/providers/google/cloud/transfers/postgres_to_gcs.py b/reference/providers/google/cloud/transfers/postgres_to_gcs.py new file mode 100644 index 0000000..4dffd2a --- /dev/null +++ b/reference/providers/google/cloud/transfers/postgres_to_gcs.py @@ -0,0 +1,159 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""PostgreSQL to GCS operator.""" + +import datetime +import json +import time +import uuid +from decimal import Decimal +from typing import Dict + +import pendulum +from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator +from airflow.providers.postgres.hooks.postgres import PostgresHook +from airflow.utils.decorators import apply_defaults + + +class _PostgresServerSideCursorDecorator: + """ + Inspired by `_PrestoToGCSPrestoCursorAdapter` to keep this consistent. + + Decorator for allowing description to be available for postgres cursor in case server side + cursor is used. It doesn't provide other methods except those needed in BaseSQLToGCSOperator, + which is more of a safety feature. + """ + + def __init__(self, cursor): + self.cursor = cursor + self.rows = [] + self.initialized = False + + def __iter__(self): + return self + + def __next__(self): + if self.rows: + return self.rows.pop() + else: + self.initialized = True + return next(self.cursor) + + @property + def description(self): + """Fetch first row to initialize cursor description when using server side cursor.""" + if not self.initialized: + element = self.cursor.fetchone() + self.rows.append(element) + self.initialized = True + return self.cursor.description + + +class PostgresToGCSOperator(BaseSQLToGCSOperator): + """ + Copy data from Postgres to Google Cloud Storage in JSON or CSV format. + + :param postgres_conn_id: Reference to a specific Postgres hook. + :type postgres_conn_id: str + :param use_server_side_cursor: If server-side cursor should be used for querying postgres. + For detailed info, check https://www.psycopg.org/docs/usage.html#server-side-cursors + :type use_server_side_cursor: bool + :param cursor_itersize: How many records are fetched at a time in case of server-side cursor. 
+ :type cursor_itersize: int + """ + + ui_color = "#a0e08c" + + type_map = { + 1114: "TIMESTAMP", + 1184: "TIMESTAMP", + 1082: "TIMESTAMP", + 1083: "TIMESTAMP", + 1005: "INTEGER", + 1007: "INTEGER", + 1016: "INTEGER", + 20: "INTEGER", + 21: "INTEGER", + 23: "INTEGER", + 16: "BOOLEAN", + 700: "FLOAT", + 701: "FLOAT", + 1700: "FLOAT", + } + + @apply_defaults + def __init__( + self, + *, + postgres_conn_id="postgres_default", + use_server_side_cursor=False, + cursor_itersize=2000, + **kwargs, + ): + super().__init__(**kwargs) + self.postgres_conn_id = postgres_conn_id + self.use_server_side_cursor = use_server_side_cursor + self.cursor_itersize = cursor_itersize + + def _unique_name(self): + return ( + f"{self.dag_id}__{self.task_id}__{uuid.uuid4()}" + if self.use_server_side_cursor + else None + ) + + def query(self): + """Queries Postgres and returns a cursor to the results.""" + hook = PostgresHook(postgres_conn_id=self.postgres_conn_id) + conn = hook.get_conn() + cursor = conn.cursor(name=self._unique_name()) + cursor.execute(self.sql, self.parameters) + if self.use_server_side_cursor: + cursor.itersize = self.cursor_itersize + return _PostgresServerSideCursorDecorator(cursor) + return cursor + + def field_to_bigquery(self, field) -> Dict[str, str]: + return { + "name": field[0], + "type": self.type_map.get(field[1], "STRING"), + "mode": "REPEATED" if field[1] in (1009, 1005, 1007, 1016) else "NULLABLE", + } + + def convert_type(self, value, schema_type): + """ + Takes a value from Postgres, and converts it to a value that's safe for + JSON/Google Cloud Storage/BigQuery. Dates are converted to UTC seconds. + Decimals are converted to floats. Times are converted to seconds. + """ + if isinstance(value, (datetime.datetime, datetime.date)): + return pendulum.parse(value.isoformat()).float_timestamp + if isinstance(value, datetime.time): + formatted_time = time.strptime(str(value), "%H:%M:%S") + return int( + datetime.timedelta( + hours=formatted_time.tm_hour, + minutes=formatted_time.tm_min, + seconds=formatted_time.tm_sec, + ).total_seconds() + ) + if isinstance(value, dict): + return json.dumps(value) + if isinstance(value, Decimal): + return float(value) + return value diff --git a/reference/providers/google/cloud/transfers/presto_to_gcs.py b/reference/providers/google/cloud/transfers/presto_to_gcs.py new file mode 100644 index 0000000..dc662ff --- /dev/null +++ b/reference/providers/google/cloud/transfers/presto_to_gcs.py @@ -0,0 +1,209 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
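Before the Presto transfer below, a brief usage aside: the PostgresToGCSOperator defined above only needs a SQL statement, a destination bucket, and a templated filename; everything else falls back to the defaults inherited from BaseSQLToGCSOperator. The following is a minimal, hypothetical sketch (the DAG id, table, bucket, and schedule are illustrative assumptions, not values taken from this repository), assuming the upstream google and postgres provider packages are installed:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator

with DAG(
    dag_id="postgres_export_example",            # hypothetical DAG id
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    export_orders = PostgresToGCSOperator(
        task_id="export_orders",
        sql="SELECT * FROM public.orders;",      # assumed source table
        bucket="my-export-bucket",               # assumed destination bucket
        filename="data/orders/export_{}.json",   # {} receives the split index for large dumps
        schema_filename="schemas/orders.json",   # optional BigQuery schema sidecar file
        postgres_conn_id="postgres_default",
        gcp_conn_id="google_cloud_default",
        use_server_side_cursor=True,             # stream large result sets via a named cursor
    )

Setting use_server_side_cursor=True wraps the psycopg2 cursor in the _PostgresServerSideCursorDecorator shown above, so cursor.description is still available before the first row is consumed.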
+from typing import Any, Dict, List, Tuple + +from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator +from airflow.providers.presto.hooks.presto import PrestoHook +from airflow.utils.decorators import apply_defaults +from prestodb.client import PrestoResult +from prestodb.dbapi import Cursor as PrestoCursor + + +class _PrestoToGCSPrestoCursorAdapter: + """ + An adapter that adds additional feature to the Presto cursor. + + The implementation of cursor in the prestodb library is not sufficient. + The following changes have been made: + + * The poke mechanism for row. You can look at the next row without consuming it. + * The description attribute is available before reading the first row. Thanks to the poke mechanism. + * the iterator interface has been implemented. + + A detailed description of the class methods is available in + `PEP-249 `__. + """ + + def __init__(self, cursor: PrestoCursor): + self.cursor: PrestoCursor = cursor + self.rows: List[Any] = [] + self.initialized: bool = False + + @property + def description(self) -> List[Tuple]: + """ + This read-only attribute is a sequence of 7-item sequences. + + Each of these sequences contains information describing one result column: + + * ``name`` + * ``type_code`` + * ``display_size`` + * ``internal_size`` + * ``precision`` + * ``scale`` + * ``null_ok`` + + The first two items (``name`` and ``type_code``) are mandatory, the other + five are optional and are set to None if no meaningful values can be provided. + """ + if not self.initialized: + # Peek for first row to load description. + self.peekone() + return self.cursor.description + + @property + def rowcount(self) -> int: + """The read-only attribute specifies the number of rows""" + return self.cursor.rowcount + + def close(self) -> None: + """Close the cursor now""" + self.cursor.close() + + def execute(self, *args, **kwargs) -> PrestoResult: + """Prepare and execute a database operation (query or command).""" + self.initialized = False + self.rows = [] + return self.cursor.execute(*args, **kwargs) + + def executemany(self, *args, **kwargs): + """ + Prepare a database operation (query or command) and then execute it against all parameter + sequences or mappings found in the sequence seq_of_parameters. + """ + self.initialized = False + self.rows = [] + return self.cursor.executemany(*args, **kwargs) + + def peekone(self) -> Any: + """Return the next row without consuming it.""" + self.initialized = True + element = self.cursor.fetchone() + self.rows.insert(0, element) + return element + + def fetchone(self) -> Any: + """ + Fetch the next row of a query result set, returning a single sequence, or + ``None`` when no more data is available. + """ + if self.rows: + return self.rows.pop(0) + return self.cursor.fetchone() + + def fetchmany(self, size=None) -> list: + """ + Fetch the next set of rows of a query result, returning a sequence of sequences + (e.g. a list of tuples). An empty sequence is returned when no more rows are available. + """ + if size is None: + size = self.cursor.arraysize + + result = [] + for _ in range(size): + row = self.fetchone() + if row is None: + break + result.append(row) + + return result + + def __next__(self) -> Any: + """ + Return the next row from the currently executing SQL statement using the same semantics as + ``.fetchone()``. A ``StopIteration`` exception is raised when the result set is exhausted. 
+ :return: + """ + result = self.fetchone() + if result is None: + raise StopIteration() + return result + + def __iter__(self) -> "_PrestoToGCSPrestoCursorAdapter": + """Return self to make cursors compatible to the iteration protocol""" + return self + + +class PrestoToGCSOperator(BaseSQLToGCSOperator): + """Copy data from PrestoDB to Google Cloud Storage in JSON or CSV format. + + :param presto_conn_id: Reference to a specific Presto hook. + :type presto_conn_id: str + """ + + ui_color = "#a0e08c" + + type_map = { + "BOOLEAN": "BOOL", + "TINYINT": "INT64", + "SMALLINT": "INT64", + "INTEGER": "INT64", + "BIGINT": "INT64", + "REAL": "FLOAT64", + "DOUBLE": "FLOAT64", + "DECIMAL": "NUMERIC", + "VARCHAR": "STRING", + "CHAR": "STRING", + "VARBINARY": "BYTES", + "JSON": "STRING", + "DATE": "DATE", + "TIME": "TIME", + # BigQuery don't time with timezone native. + "TIME WITH TIME ZONE": "STRING", + "TIMESTAMP": "TIMESTAMP", + # BigQuery supports a narrow range of time zones during import. + # You should use TIMESTAMP function, if you want have TIMESTAMP type + "TIMESTAMP WITH TIME ZONE": "STRING", + "IPADDRESS": "STRING", + "UUID": "STRING", + } + + @apply_defaults + def __init__(self, *, presto_conn_id: str = "presto_default", **kwargs): + super().__init__(**kwargs) + self.presto_conn_id = presto_conn_id + + def query(self): + """Queries presto and returns a cursor to the results.""" + presto = PrestoHook(presto_conn_id=self.presto_conn_id) + conn = presto.get_conn() + cursor = conn.cursor() + self.log.info("Executing: %s", self.sql) + cursor.execute(self.sql) + return _PrestoToGCSPrestoCursorAdapter(cursor) + + def field_to_bigquery(self, field) -> Dict[str, str]: + """Convert presto field type to BigQuery field type.""" + clear_field_type = field[1].upper() + # remove type argument e.g. DECIMAL(2, 10) => DECIMAL + clear_field_type, _, _ = clear_field_type.partition("(") + new_field_type = self.type_map.get(clear_field_type, "STRING") + + return {"name": field[0], "type": new_field_type} + + def convert_type(self, value, schema_type): + """ + Do nothing. Presto uses JSON on the transport layer, so types are simple. + + :param value: Presto column value + :type value: Any + :param schema_type: BigQuery data type + :type schema_type: str + """ + return value diff --git a/reference/providers/google/cloud/transfers/s3_to_gcs.py b/reference/providers/google/cloud/transfers/s3_to_gcs.py new file mode 100644 index 0000000..78c4bf5 --- /dev/null +++ b/reference/providers/google/cloud/transfers/s3_to_gcs.py @@ -0,0 +1,248 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
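The Presto operator above follows the same calling convention as the other SQL exporters; the only Presto-specific piece is _PrestoToGCSPrestoCursorAdapter, which peeks at the first row so that cursor.description is populated before iteration starts, and convert_type is deliberately a pass-through because Presto already returns simple JSON-friendly types on the transport layer. A minimal sketch, where the connection IDs, catalog/table, and bucket name are assumptions for illustration only:

from airflow.providers.google.cloud.transfers.presto_to_gcs import PrestoToGCSOperator

# Instantiate inside a DAG context (e.g. `with DAG(...) as dag:`).
export_events = PrestoToGCSOperator(
    task_id="export_events",
    sql="SELECT * FROM hive.web.events",     # assumed catalog.schema.table
    bucket="presto-export",                  # assumed destination bucket
    filename="data/events/export_{}.csv",    # {} receives the split index for large dumps
    export_format="csv",
    presto_conn_id="presto_default",
    gcp_conn_id="google_cloud_default",
)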
+import warnings +from tempfile import NamedTemporaryFile +from typing import Iterable, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator +from airflow.providers.google.cloud.hooks.gcs import ( + GCSHook, + _parse_gcs_url, + gcs_object_is_directory, +) +from airflow.utils.decorators import apply_defaults + + +class S3ToGCSOperator(S3ListOperator): + """ + Synchronizes an S3 key, possibly a prefix, with a Google Cloud Storage + destination path. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:S3ToGCSOperator` + + :param bucket: The S3 bucket where to find the objects. (templated) + :type bucket: str + :param prefix: Prefix string which filters objects whose name begin with + such prefix. (templated) + :type prefix: str + :param delimiter: the delimiter marks key hierarchy. (templated) + :type delimiter: str + :param aws_conn_id: The source S3 connection + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param dest_gcs_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type dest_gcs_conn_id: str + :param dest_gcs: The destination Google Cloud Storage bucket and prefix + where you want to store the files. (templated) + :type dest_gcs: str + :param delegate_to: Google account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param replace: Whether you want to replace existing destination files + or not. + :type replace: bool + :param gzip: Option to compress file for upload + :type gzip: bool + :param google_impersonation_chain: Optional Google service account to impersonate using + short-term credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type google_impersonation_chain: Union[str, Sequence[str]] + + + **Example**: + + .. 
code-block:: python + + s3_to_gcs_op = S3ToGCSOperator( + task_id='s3_to_gcs_example', + bucket='my-s3-bucket', + prefix='data/customers-201804', + dest_gcs_conn_id='google_cloud_default', + dest_gcs='gs://my.gcs.bucket/some/customers/', + replace=False, + gzip=True, + dag=my-dag) + + Note that ``bucket``, ``prefix``, ``delimiter`` and ``dest_gcs`` are + templated, so you can use variables in them if you wish. + """ + + template_fields: Iterable[str] = ( + "bucket", + "prefix", + "delimiter", + "dest_gcs", + "google_impersonation_chain", + ) + ui_color = "#e09411" + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + bucket, + prefix="", + delimiter="", + aws_conn_id="aws_default", + verify=None, + gcp_conn_id="google_cloud_default", + dest_gcs_conn_id=None, + dest_gcs=None, + delegate_to=None, + replace=False, + gzip=False, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + + super().__init__( + bucket=bucket, + prefix=prefix, + delimiter=delimiter, + aws_conn_id=aws_conn_id, + **kwargs, + ) + + if dest_gcs_conn_id: + warnings.warn( + "The dest_gcs_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = dest_gcs_conn_id + + self.gcp_conn_id = gcp_conn_id + self.dest_gcs = dest_gcs + self.delegate_to = delegate_to + self.replace = replace + self.verify = verify + self.gzip = gzip + self.google_impersonation_chain = google_impersonation_chain + + if dest_gcs and not gcs_object_is_directory(self.dest_gcs): + self.log.info( + "Destination Google Cloud Storage path is not a valid " + '"directory", define a path that ends with a slash "/" or ' + "leave it empty for the root of the bucket." + ) + raise AirflowException( + 'The destination Google Cloud Storage path must end with a slash "/" or be empty.' + ) + + def execute(self, context): + # use the super method to list all the files in an S3 bucket/key + files = super().execute(context) + + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.google_impersonation_chain, + ) + + # pylint: disable=too-many-nested-blocks + if not self.replace: + # if we are not replacing -> list all files in the GCS bucket + # and only keep those files which are present in + # S3 and not in Google Cloud Storage + bucket_name, object_prefix = _parse_gcs_url(self.dest_gcs) + existing_files_prefixed = gcs_hook.list(bucket_name, prefix=object_prefix) + + existing_files = [] + + if existing_files_prefixed: + # Remove the object prefix itself, an empty directory was found + if object_prefix in existing_files_prefixed: + existing_files_prefixed.remove(object_prefix) + + # Remove the object prefix from all object string paths + for f in existing_files_prefixed: + if f.startswith(object_prefix): + existing_files.append(f[len(object_prefix) :]) + else: + existing_files.append(f) + + files = list(set(files) - set(existing_files)) + if len(files) > 0: + self.log.info("%s files are going to be synced: %s.", len(files), files) + else: + self.log.info("There are no new files to sync. 
Have a nice day!") + + if files: + hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + + for file in files: + # GCS hook builds its own in-memory file so we have to create + # and pass the path + file_object = hook.get_key(file, self.bucket) + with NamedTemporaryFile(mode="wb", delete=True) as f: + file_object.download_fileobj(f) + f.flush() + + dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url( + self.dest_gcs + ) + # There will always be a '/' before file because it is + # enforced at instantiation time + dest_gcs_object = dest_gcs_object_prefix + file + + # Sync is sequential and the hook already logs too much + # so skip this for now + # self.log.info( + # 'Saving file {0} from S3 bucket {1} in GCS bucket {2}' + # ' as object {3}'.format(file, self.bucket, + # dest_gcs_bucket, + # dest_gcs_object)) + + gcs_hook.upload( + dest_gcs_bucket, dest_gcs_object, f.name, gzip=self.gzip + ) + + self.log.info( + "All done, uploaded %d files to Google Cloud Storage", len(files) + ) + else: + self.log.info( + "In sync, no files needed to be uploaded to Google Cloud Storage" + ) + + return files diff --git a/reference/providers/google/cloud/transfers/salesforce_to_gcs.py b/reference/providers/google/cloud/transfers/salesforce_to_gcs.py new file mode 100644 index 0000000..cc4d6eb --- /dev/null +++ b/reference/providers/google/cloud/transfers/salesforce_to_gcs.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import tempfile +from typing import Dict, Optional + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.salesforce.hooks.salesforce import SalesforceHook + + +class SalesforceToGcsOperator(BaseOperator): + """ + Submits Salesforce query and uploads results to Google Cloud Storage + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SalesforceToGcsOperator` + + :param query: The query to make to Salesforce. + :type query: str + :param bucket_name: The bucket to upload to. + :type bucket_name: str + :param object_name: The object name to set when uploading the file. + :type object_name: str + :param salesforce_conn_id: the name of the connection that has the parameters + we need to connect to Salesforce. + :type salesforce_conn_id: str + :param include_deleted: True if the query should include deleted records. + :type include_deleted: bool + :param query_params: Additional optional arguments + :type query_params: dict + :param export_format: Desired format of files to be exported. + :type export_format: str + :param coerce_to_timestamp: True if you want all datetime fields to be converted into Unix timestamps. 
+ False if you want them to be left in the same format as they were in Salesforce. + Leaving the value as False will result in datetimes being strings. Default: False + :type coerce_to_timestamp: bool + :param record_time_added: True if you want to add a Unix timestamp field + to the resulting data that marks when the data was fetched from Salesforce. Default: False + :type record_time_added: bool + :param gzip: Option to compress local file or file data for upload + :type gzip: bool + :param gcp_conn_id: the name of the connection that has the parameters we need to connect to GCS. + :type gcp_conn_id: str + """ + + template_fields = ( + "query", + "bucket_name", + "object_name", + ) + template_ext = (".sql",) + + def __init__( + self, + *, + query: str, + bucket_name: str, + object_name: str, + salesforce_conn_id: str, + include_deleted: bool = False, + query_params: Optional[dict] = None, + export_format: str = "csv", + coerce_to_timestamp: bool = False, + record_time_added: bool = False, + gzip: bool = False, + gcp_conn_id: str = "google_cloud_default", + **kwargs, + ): + super().__init__(**kwargs) + self.query = query + self.bucket_name = bucket_name + self.object_name = object_name + self.salesforce_conn_id = salesforce_conn_id + self.export_format = export_format + self.coerce_to_timestamp = coerce_to_timestamp + self.record_time_added = record_time_added + self.gzip = gzip + self.gcp_conn_id = gcp_conn_id + self.include_deleted = include_deleted + self.query_params = query_params + + def execute(self, context: Dict): + salesforce = SalesforceHook(conn_id=self.salesforce_conn_id) + response = salesforce.make_query( + query=self.query, + include_deleted=self.include_deleted, + query_params=self.query_params, + ) + + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "salesforce_temp_file") + salesforce.write_object_to_file( + query_results=response["records"], + filename=path, + fmt=self.export_format, + coerce_to_timestamp=self.coerce_to_timestamp, + record_time_added=self.record_time_added, + ) + + hook = GCSHook(gcp_conn_id=self.gcp_conn_id) + hook.upload( + bucket_name=self.bucket_name, + object_name=self.object_name, + filename=path, + gzip=self.gzip, + ) + + gcs_uri = f"gs://{self.bucket_name}/{self.object_name}" + self.log.info("%s uploaded to GCS", gcs_uri) + return gcs_uri diff --git a/reference/providers/google/cloud/transfers/sftp_to_gcs.py b/reference/providers/google/cloud/transfers/sftp_to_gcs.py new file mode 100644 index 0000000..48f6421 --- /dev/null +++ b/reference/providers/google/cloud/transfers/sftp_to_gcs.py @@ -0,0 +1,194 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
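As a usage note for the SalesforceToGcsOperator defined above: it runs a single SOQL query, writes the returned records to a temporary file in the chosen export format, uploads that one object, and logs and returns the resulting gs:// URI. A minimal, hypothetical sketch; the query, bucket, and object name below are placeholders:

from airflow.providers.google.cloud.transfers.salesforce_to_gcs import SalesforceToGcsOperator

# Instantiate inside a DAG context.
upload_leads = SalesforceToGcsOperator(
    task_id="upload_leads",
    query="SELECT Id, Name, Email FROM Lead",   # placeholder SOQL query
    bucket_name="salesforce-export",            # placeholder bucket
    object_name="leads/leads.csv",              # placeholder object name
    salesforce_conn_id="salesforce_default",
    export_format="csv",
    coerce_to_timestamp=True,    # datetimes become Unix timestamps instead of strings
    gcp_conn_id="google_cloud_default",
)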
+"""This module contains SFTP to Google Cloud Storage operator.""" +import os +from tempfile import NamedTemporaryFile +from typing import Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.sftp.hooks.sftp import SFTPHook +from airflow.utils.decorators import apply_defaults + +WILDCARD = "*" + + +class SFTPToGCSOperator(BaseOperator): + """ + Transfer files to Google Cloud Storage from SFTP server. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SFTPToGCSOperator` + + :param source_path: The sftp remote path. This is the specified file path + for downloading the single file or multiple files from the SFTP server. + You can use only one wildcard within your path. The wildcard can appear + inside the path or at the end of the path. + :type source_path: str + :param destination_bucket: The bucket to upload to. + :type destination_bucket: str + :param destination_path: The destination name of the object in the + destination Google Cloud Storage bucket. + If destination_path is not provided file/files will be placed in the + main bucket path. + If a wildcard is supplied in the destination_path argument, this is the + prefix that will be prepended to the final destination objects' paths. + :type destination_path: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param sftp_conn_id: The sftp connection id. The name or identifier for + establishing a connection to the SFTP server. + :type sftp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param mime_type: The mime-type string + :type mime_type: str + :param gzip: Allows for file to be compressed and uploaded as gzip + :type gzip: bool + :param move_object: When move object is True, the object is moved instead + of copied to the new location. This is the equivalent of a mv command + as opposed to a cp command. + :type move_object: bool + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "source_path", + "destination_path", + "destination_bucket", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + source_path: str, + destination_bucket: str, + destination_path: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + sftp_conn_id: str = "ssh_default", + delegate_to: Optional[str] = None, + mime_type: str = "application/octet-stream", + gzip: bool = False, + move_object: bool = False, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.source_path = source_path + self.destination_path = self._set_destination_path(destination_path) + self.destination_bucket = self._set_bucket_name(destination_bucket) + self.gcp_conn_id = gcp_conn_id + self.mime_type = mime_type + self.delegate_to = delegate_to + self.gzip = gzip + self.sftp_conn_id = sftp_conn_id + self.move_object = move_object + self.impersonation_chain = impersonation_chain + + def execute(self, context): + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + sftp_hook = SFTPHook(self.sftp_conn_id) + + if WILDCARD in self.source_path: + total_wildcards = self.source_path.count(WILDCARD) + if total_wildcards > 1: + raise AirflowException( + "Only one wildcard '*' is allowed in source_path parameter. " + "Found {} in {}.".format(total_wildcards, self.source_path) + ) + + prefix, delimiter = self.source_path.split(WILDCARD, 1) + base_path = os.path.dirname(prefix) + + files, _, _ = sftp_hook.get_tree_map( + base_path, prefix=prefix, delimiter=delimiter + ) + + for file in files: + destination_path = file.replace(base_path, self.destination_path, 1) + self._copy_single_object(gcs_hook, sftp_hook, file, destination_path) + + else: + destination_object = ( + self.destination_path + if self.destination_path + else self.source_path.rsplit("/", 1)[1] + ) + self._copy_single_object( + gcs_hook, sftp_hook, self.source_path, destination_object + ) + + def _copy_single_object( + self, + gcs_hook: GCSHook, + sftp_hook: SFTPHook, + source_path: str, + destination_object: str, + ) -> None: + """Helper function to copy single object.""" + self.log.info( + "Executing copy of %s to gs://%s/%s", + source_path, + self.destination_bucket, + destination_object, + ) + + with NamedTemporaryFile("w") as tmp: + sftp_hook.retrieve_file(source_path, tmp.name) + + gcs_hook.upload( + bucket_name=self.destination_bucket, + object_name=destination_object, + filename=tmp.name, + mime_type=self.mime_type, + ) + + if self.move_object: + self.log.info("Executing delete of %s", source_path) + sftp_hook.delete_file(source_path) + + @staticmethod + def _set_destination_path(path: Union[str, None]) -> str: + if path is not None: + return path.lstrip("/") if path.startswith("/") else path + return "" + + @staticmethod + def _set_bucket_name(name: str) -> str: + bucket = name if not name.startswith("gs://") else name[5:] + return bucket.strip("/") diff --git a/reference/providers/google/cloud/transfers/sheets_to_gcs.py b/reference/providers/google/cloud/transfers/sheets_to_gcs.py new file mode 100644 index 0000000..c109a78 --- /dev/null +++ b/reference/providers/google/cloud/transfers/sheets_to_gcs.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import csv +from tempfile import NamedTemporaryFile +from typing import Any, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.suite.hooks.sheets import GSheetsHook +from airflow.utils.decorators import apply_defaults + + +class GoogleSheetsToGCSOperator(BaseOperator): + """ + Writes Google Sheet data into Google Cloud Storage. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleSheetsToGCSOperator` + + :param spreadsheet_id: The Google Sheet ID to interact with. + :type spreadsheet_id: str + :param sheet_filter: Default to None, if provided, Should be an array of the sheet + titles to pull from. + :type sheet_filter: List[str] + :param destination_bucket: The destination Google cloud storage bucket where the + report should be written to. (templated) + :type destination_bucket: str + :param destination_path: The Google cloud storage URI array for the object created by the operator. + For example: ``path/to/my/files``. + :type destination_path: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "spreadsheet_id", + "destination_bucket", + "destination_path", + "sheet_filter", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + spreadsheet_id: str, + destination_bucket: str, + sheet_filter: Optional[List[str]] = None, + destination_path: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.gcp_conn_id = gcp_conn_id + self.spreadsheet_id = spreadsheet_id + self.sheet_filter = sheet_filter + self.destination_bucket = destination_bucket + self.destination_path = destination_path + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def _upload_data( + self, + gcs_hook: GCSHook, + hook: GSheetsHook, + sheet_range: str, + sheet_values: List[Any], + ) -> str: + # Construct destination file path + sheet = hook.get_spreadsheet(self.spreadsheet_id) + file_name = f"{sheet['properties']['title']}_{sheet_range}.csv".replace( + " ", "_" + ) + dest_file_name = ( + f"{self.destination_path.strip('/')}/{file_name}" + if self.destination_path + else file_name + ) + + with NamedTemporaryFile("w+") as temp_file: + # Write data + writer = csv.writer(temp_file) + writer.writerows(sheet_values) + temp_file.flush() + + # Upload to GCS + gcs_hook.upload( + bucket_name=self.destination_bucket, + object_name=dest_file_name, + filename=temp_file.name, + ) + return dest_file_name + + def execute(self, context): + sheet_hook = GSheetsHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + # Pull data and upload + destination_array: List[str] = [] + sheet_titles = sheet_hook.get_sheet_titles( + spreadsheet_id=self.spreadsheet_id, sheet_filter=self.sheet_filter + ) + for sheet_range in sheet_titles: + data = sheet_hook.get_values( + spreadsheet_id=self.spreadsheet_id, range_=sheet_range + ) + gcs_path_to_file = self._upload_data( + gcs_hook, sheet_hook, sheet_range, data + ) + destination_array.append(gcs_path_to_file) + + self.xcom_push(context, "destination_objects", destination_array) + return destination_array diff --git a/reference/providers/google/cloud/transfers/sql_to_gcs.py b/reference/providers/google/cloud/transfers/sql_to_gcs.py new file mode 100644 index 0000000..2869120 --- /dev/null +++ b/reference/providers/google/cloud/transfers/sql_to_gcs.py @@ -0,0 +1,386 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
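Before the shared BaseSQLToGCSOperator that the SQL exporters above build on, a short sketch of the GoogleSheetsToGCSOperator just defined: it exports each matching sheet tab to its own CSV object and pushes the list of created object names to XCom under the destination_objects key. The spreadsheet ID, bucket, and path below are placeholders, not values from this repository:

from airflow.providers.google.cloud.transfers.sheets_to_gcs import GoogleSheetsToGCSOperator

# Instantiate inside a DAG context.
sheets_to_gcs = GoogleSheetsToGCSOperator(
    task_id="sheets_to_gcs",
    spreadsheet_id="1A2b3C4d5E6f-placeholder",   # placeholder Google Sheet ID
    sheet_filter=["Sheet1"],                     # export only this tab; omit to export all tabs
    destination_bucket="sheets-export",          # placeholder bucket
    destination_path="exports/marketing",        # optional prefix inside the bucket
    gcp_conn_id="google_cloud_default",
)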
+"""Base operator for SQL to GCS operators.""" +import abc +import json +import warnings +from tempfile import NamedTemporaryFile +from typing import Optional, Sequence, Union + +import pyarrow as pa +import pyarrow.parquet as pq +import unicodecsv as csv +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.decorators import apply_defaults + + +class BaseSQLToGCSOperator(BaseOperator): + """ + :param sql: The SQL to execute. + :type sql: str + :param bucket: The bucket to upload to. + :type bucket: str + :param filename: The filename to use as the object name when uploading + to Google Cloud Storage. A ``{}`` should be specified in the filename + to allow the operator to inject file numbers in cases where the + file is split due to size. + :type filename: str + :param schema_filename: If set, the filename to use as the object name + when uploading a .json file containing the BigQuery schema fields + for the table that was dumped from the database. + :type schema_filename: str + :param approx_max_file_size_bytes: This operator supports the ability + to split large table dumps into multiple files (see notes in the + filename param docs above). This param allows developers to specify the + file size of the splits. Check https://cloud.google.com/storage/quotas + to see the maximum allowed file size for a single object. + :type approx_max_file_size_bytes: long + :param export_format: Desired format of files to be exported. + :type export_format: str + :param field_delimiter: The delimiter to be used for CSV files. + :type field_delimiter: str + :param null_marker: The null marker to be used for CSV files. + :type null_marker: str + :param gzip: Option to compress file for upload (does not apply to schemas). + :type gzip: bool + :param schema: The schema to use, if any. Should be a list of dict or + a str. Pass a string if using Jinja template, otherwise, pass a list of + dict. Examples could be seen: https://cloud.google.com/bigquery/docs + /schemas#specifying_a_json_schema_file + :type schema: str or list + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param parameters: a parameters dict that is substituted at query runtime. + :type parameters: dict + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "sql", + "bucket", + "filename", + "schema_filename", + "schema", + "parameters", + "impersonation_chain", + ) + template_ext = (".sql",) + ui_color = "#a0e08c" + + @apply_defaults + def __init__( + self, + *, # pylint: disable=too-many-arguments + sql: str, + bucket: str, + filename: str, + schema_filename: Optional[str] = None, + approx_max_file_size_bytes: int = 1900000000, + export_format: str = "json", + field_delimiter: str = ",", + null_marker: Optional[str] = None, + gzip: bool = False, + schema: Optional[Union[str, list]] = None, + parameters: Optional[dict] = None, + gcp_conn_id: str = "google_cloud_default", + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) + gcp_conn_id = google_cloud_storage_conn_id + + self.sql = sql + self.bucket = bucket + self.filename = filename + self.schema_filename = schema_filename + self.approx_max_file_size_bytes = approx_max_file_size_bytes + self.export_format = export_format.lower() + self.field_delimiter = field_delimiter + self.null_marker = null_marker + self.gzip = gzip + self.schema = schema + self.parameters = parameters + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context): + self.log.info("Executing query") + cursor = self.query() + + self.log.info("Writing local data files") + files_to_upload = self._write_local_data_files(cursor) + # If a schema is set, create a BQ schema JSON file. + if self.schema_filename: + self.log.info("Writing local schema file") + files_to_upload.append(self._write_local_schema_file(cursor)) + + # Flush all files before uploading + for tmp_file in files_to_upload: + tmp_file["file_handle"].flush() + + self.log.info("Uploading %d files to GCS.", len(files_to_upload)) + self._upload_to_gcs(files_to_upload) + + self.log.info("Removing local files") + # Close all temp file handles. + for tmp_file in files_to_upload: + tmp_file["file_handle"].close() + + def convert_types(self, schema, col_type_dict, row) -> list: + """Convert values from DBAPI to output-friendly formats.""" + return [ + self.convert_type(value, col_type_dict.get(name)) + for name, value in zip(schema, row) + ] + + def _write_local_data_files(self, cursor): + """ + Takes a cursor, and writes results to a local file. + + :return: A dictionary where keys are filenames to be used as object + names in GCS, and values are file handles to local files that + contain the data for the GCS objects. 
+ """ + schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description)) + col_type_dict = self._get_col_type_dict() + file_no = 0 + tmp_file_handle = NamedTemporaryFile(delete=True) + if self.export_format == "csv": + file_mime_type = "text/csv" + elif self.export_format == "parquet": + file_mime_type = "application/octet-stream" + else: + file_mime_type = "application/json" + files_to_upload = [ + { + "file_name": self.filename.format(file_no), + "file_handle": tmp_file_handle, + "file_mime_type": file_mime_type, + } + ] + self.log.info("Current file count: %d", len(files_to_upload)) + + if self.export_format == "csv": + csv_writer = self._configure_csv_file(tmp_file_handle, schema) + if self.export_format == "parquet": + parquet_schema = self._convert_parquet_schema(cursor) + parquet_writer = self._configure_parquet_file( + tmp_file_handle, parquet_schema + ) + + for row in cursor: + # Convert datetime objects to utc seconds, and decimals to floats. + # Convert binary type object to string encoded with base64. + row = self.convert_types(schema, col_type_dict, row) + + if self.export_format == "csv": + if self.null_marker is not None: + row = [ + value if value is not None else self.null_marker + for value in row + ] + csv_writer.writerow(row) + elif self.export_format == "parquet": + if self.null_marker is not None: + row = [ + value if value is not None else self.null_marker + for value in row + ] + row_pydic = {col: [value] for col, value in zip(schema, row)} + tbl = pa.Table.from_pydict(row_pydic, parquet_schema) + parquet_writer.write_table(tbl) + else: + row_dict = dict(zip(schema, row)) + + tmp_file_handle.write( + json.dumps(row_dict, sort_keys=True, ensure_ascii=False).encode( + "utf-8" + ) + ) + + # Append newline to make dumps BigQuery compatible. + tmp_file_handle.write(b"\n") + + # Stop if the file exceeds the file size limit. + if tmp_file_handle.tell() >= self.approx_max_file_size_bytes: + file_no += 1 + tmp_file_handle = NamedTemporaryFile(delete=True) + files_to_upload.append( + { + "file_name": self.filename.format(file_no), + "file_handle": tmp_file_handle, + "file_mime_type": file_mime_type, + } + ) + self.log.info("Current file count: %d", len(files_to_upload)) + if self.export_format == "csv": + csv_writer = self._configure_csv_file(tmp_file_handle, schema) + if self.export_format == "parquet": + parquet_writer = self._configure_parquet_file( + tmp_file_handle, parquet_schema + ) + return files_to_upload + + def _configure_csv_file(self, file_handle, schema): + """Configure a csv writer with the file_handle and write schema + as headers for the new file. 
+ """ + csv_writer = csv.writer( + file_handle, encoding="utf-8", delimiter=self.field_delimiter + ) + csv_writer.writerow(schema) + return csv_writer + + def _configure_parquet_file(self, file_handle, parquet_schema): + parquet_writer = pq.ParquetWriter(file_handle.name, parquet_schema) + return parquet_writer + + def _convert_parquet_schema(self, cursor): + type_map = { + "INTEGER": pa.int64(), + "FLOAT": pa.float64(), + "NUMERIC": pa.float64(), + "BIGNUMERIC": pa.float64(), + "BOOL": pa.bool_(), + "STRING": pa.string(), + "BYTES": pa.binary(), + "DATE": pa.date32(), + "DATETIME": pa.date64(), + "TIMESTAMP": pa.timestamp("s"), + } + + columns = [field[0] for field in cursor.description] + bq_types = [self.field_to_bigquery(field) for field in cursor.description] + pq_types = [type_map.get(bq_type, pa.string()) for bq_type in bq_types] + parquet_schema = pa.schema(zip(columns, pq_types)) + return parquet_schema + + @abc.abstractmethod + def query(self): + """Execute DBAPI query.""" + + @abc.abstractmethod + def field_to_bigquery(self, field): + """Convert a DBAPI field to BigQuery schema format.""" + + @abc.abstractmethod + def convert_type(self, value, schema_type): + """Convert a value from DBAPI to output-friendly formats.""" + + def _get_col_type_dict(self): + """Return a dict of column name and column type based on self.schema if not None.""" + schema = [] + if isinstance(self.schema, str): + schema = json.loads(self.schema) + elif isinstance(self.schema, list): + schema = self.schema + elif self.schema is not None: + self.log.warning( + "Using default schema due to unexpected type. Should be a string or list." + ) + + col_type_dict = {} + try: + col_type_dict = {col["name"]: col["type"] for col in schema} + except KeyError: + self.log.warning( + "Using default schema due to missing name or type. Please " + "refer to: https://cloud.google.com/bigquery/docs/schemas" + "#specifying_a_json_schema_file" + ) + return col_type_dict + + def _write_local_schema_file(self, cursor): + """ + Takes a cursor, and writes the BigQuery schema for the results to a + local file system. Schema for database will be read from cursor if + not specified. + + :return: A dictionary where key is a filename to be used as an object + name in GCS, and values are file handles to local files that + contains the BigQuery schema fields in .json format. + """ + if self.schema: + self.log.info("Using user schema") + schema = self.schema + else: + self.log.info("Starts generating schema") + schema = [self.field_to_bigquery(field) for field in cursor.description] + + if isinstance(schema, list): + schema = json.dumps(schema, sort_keys=True) + + self.log.info("Using schema for %s", self.schema_filename) + self.log.debug("Current schema: %s", schema) + + tmp_schema_file_handle = NamedTemporaryFile(delete=True) + tmp_schema_file_handle.write(schema.encode("utf-8")) + schema_file_to_upload = { + "file_name": self.schema_filename, + "file_handle": tmp_schema_file_handle, + "file_mime_type": "application/json", + } + return schema_file_to_upload + + def _upload_to_gcs(self, files_to_upload): + """ + Upload all of the file splits (and optionally the schema .json file) to + Google Cloud Storage. 
+ """ + hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + for tmp_file in files_to_upload: + hook.upload( + self.bucket, + tmp_file.get("file_name"), + tmp_file.get("file_handle").name, + mime_type=tmp_file.get("file_mime_type"), + gzip=self.gzip + if tmp_file.get("file_name") != self.schema_filename + else False, + ) diff --git a/reference/providers/google/cloud/utils/__init__.py b/reference/providers/google/cloud/utils/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/google/cloud/utils/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/cloud/utils/credentials_provider.py b/reference/providers/google/cloud/utils/credentials_provider.py new file mode 100644 index 0000000..1be00b5 --- /dev/null +++ b/reference/providers/google/cloud/utils/credentials_provider.py @@ -0,0 +1,388 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains a mechanism for providing temporary +Google Cloud authentication. 
+""" +import json +import logging +import tempfile +from contextlib import ExitStack, contextmanager +from typing import Collection, Dict, Generator, Optional, Sequence, Tuple, Union +from urllib.parse import urlencode + +import google.auth +import google.auth.credentials +import google.oauth2.service_account +from airflow.exceptions import AirflowException +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.process_utils import patch_environ +from google.auth import impersonated_credentials +from google.auth.environment_vars import CREDENTIALS, LEGACY_PROJECT, PROJECT + +log = logging.getLogger(__name__) + +AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT = "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT" +_DEFAULT_SCOPES: Sequence[str] = ("https://www.googleapis.com/auth/cloud-platform",) + + +def build_gcp_conn( + key_file_path: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + project_id: Optional[str] = None, +) -> str: + """ + Builds a uri that can be used as :envvar:`AIRFLOW_CONN_{CONN_ID}` with provided service key, + scopes and project id. + + :param key_file_path: Path to service key. + :type key_file_path: Optional[str] + :param scopes: Required OAuth scopes. + :type scopes: Optional[List[str]] + :param project_id: The Google Cloud project id to be used for the connection. + :type project_id: Optional[str] + :return: String representing Airflow connection. + """ + conn = "google-cloud-platform://?{}" + extras = "extra__google_cloud_platform" + + query_params = {} + if key_file_path: + query_params[f"{extras}__key_path"] = key_file_path + if scopes: + scopes_string = ",".join(scopes) + query_params[f"{extras}__scope"] = scopes_string + if project_id: + query_params[f"{extras}__projects"] = project_id + + query = urlencode(query_params) + return conn.format(query) + + +@contextmanager +def provide_gcp_credentials( + key_file_path: Optional[str] = None, key_file_dict: Optional[Dict] = None +): + """ + Context manager that provides a Google Cloud credentials for application supporting `Application + Default Credentials (ADC) strategy `__. + + It can be used to provide credentials for external programs (e.g. gcloud) that expect authorization + file in ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable. + + :param key_file_path: Path to file with Google Cloud Service Account .json file. + :type key_file_path: str + :param key_file_dict: Dictionary with credentials. + :type key_file_dict: Dict + """ + if not key_file_path and not key_file_dict: + raise ValueError("Please provide `key_file_path` or `key_file_dict`.") + + if key_file_path and key_file_path.endswith(".p12"): + raise AirflowException( + "Legacy P12 key file are not supported, use a JSON key file." + ) + + with tempfile.NamedTemporaryFile(mode="w+t") as conf_file: + if not key_file_path and key_file_dict: + conf_file.write(json.dumps(key_file_dict)) + conf_file.flush() + key_file_path = conf_file.name + if key_file_path: + with patch_environ({CREDENTIALS: key_file_path}): + yield + else: + # We will use the default service account credentials. + yield + + +@contextmanager +def provide_gcp_connection( + key_file_path: Optional[str] = None, + scopes: Optional[Sequence] = None, + project_id: Optional[str] = None, +) -> Generator: + """ + Context manager that provides a temporary value of :envvar:`AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT` + connection. It build a new connection that includes path to provided service json, + required scopes and project id. 
+ + :param key_file_path: Path to file with Google Cloud Service Account .json file. + :type key_file_path: str + :param scopes: OAuth scopes for the connection + :type scopes: Sequence + :param project_id: The id of Google Cloud project for the connection. + :type project_id: str + """ + if key_file_path and key_file_path.endswith(".p12"): + raise AirflowException( + "Legacy P12 key file are not supported, use a JSON key file." + ) + + conn = build_gcp_conn( + scopes=scopes, key_file_path=key_file_path, project_id=project_id + ) + + with patch_environ({AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT: conn}): + yield + + +@contextmanager +def provide_gcp_conn_and_credentials( + key_file_path: Optional[str] = None, + scopes: Optional[Sequence] = None, + project_id: Optional[str] = None, +) -> Generator: + """ + Context manager that provides both: + + - Google Cloud credentials for application supporting `Application Default Credentials (ADC) + strategy `__. + - temporary value of :envvar:`AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT` connection + + :param key_file_path: Path to file with Google Cloud Service Account .json file. + :type key_file_path: str + :param scopes: OAuth scopes for the connection + :type scopes: Sequence + :param project_id: The id of Google Cloud project for the connection. + :type project_id: str + """ + with ExitStack() as stack: + if key_file_path: + stack.enter_context( # type; ignore # pylint: disable=no-member + provide_gcp_credentials(key_file_path) + ) + if project_id: + stack.enter_context( # type; ignore # pylint: disable=no-member + patch_environ({PROJECT: project_id, LEGACY_PROJECT: project_id}) + ) + + stack.enter_context( # type; ignore # pylint: disable=no-member + provide_gcp_connection(key_file_path, scopes, project_id) + ) + yield + + +class _CredentialProvider(LoggingMixin): + """ + Prepare the Credentials object for Google API and the associated project_id + + Only either `key_path` or `keyfile_dict` should be provided, or an exception will + occur. If neither of them are provided, return default credentials for the current environment + + :param key_path: Path to Google Cloud Service Account key file (JSON). + :type key_path: str + :param keyfile_dict: A dict representing Cloud Service Account as in the Credential JSON file + :type keyfile_dict: Dict[str, str] + :param scopes: OAuth scopes for the connection + :type scopes: Collection[str] + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param disable_logging: If true, disable all log messages, which allows you to use this + class to configure Logger. + :param target_principal: The service account to directly impersonate using short-term + credentials, if any. For this to work, the target_principal account must grant + the originating account the Service Account Token Creator IAM role. + :type target_principal: str + :param delegates: optional chained list of accounts required to get the access_token of + target_principal. If set, the sequence of identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account and target_principal + granting the role to the last account from the list. 
+ :type delegates: Sequence[str] + """ + + def __init__( + self, + key_path: Optional[str] = None, + keyfile_dict: Optional[Dict[str, str]] = None, + # See: https://github.com/PyCQA/pylint/issues/2377 + scopes: Optional[ + Collection[str] + ] = None, # pylint: disable=unsubscriptable-object + delegate_to: Optional[str] = None, + disable_logging: bool = False, + target_principal: Optional[str] = None, + delegates: Optional[Sequence[str]] = None, + ) -> None: + super().__init__() + if key_path and keyfile_dict: + raise AirflowException( + "The `keyfile_dict` and `key_path` fields are mutually exclusive. " + "Please provide only one value." + ) + self.key_path = key_path + self.keyfile_dict = keyfile_dict + self.scopes = scopes + self.delegate_to = delegate_to + self.disable_logging = disable_logging + self.target_principal = target_principal + self.delegates = delegates + + def get_credentials_and_project( + self, + ) -> Tuple[google.auth.credentials.Credentials, str]: + """ + Get current credentials and project ID. + + :return: Google Auth Credentials + :type: Tuple[google.auth.credentials.Credentials, str] + """ + if self.key_path: + credentials, project_id = self._get_credentials_using_key_path() + elif self.keyfile_dict: + credentials, project_id = self._get_credentials_using_keyfile_dict() + else: + credentials, project_id = self._get_credentials_using_adc() + + if self.delegate_to: + if hasattr(credentials, "with_subject"): + credentials = credentials.with_subject(self.delegate_to) + else: + raise AirflowException( + "The `delegate_to` parameter cannot be used here as the current " + "authentication method does not support account impersonate. " + "Please use service-account for authorization." + ) + + if self.target_principal: + credentials = impersonated_credentials.Credentials( + source_credentials=credentials, + target_principal=self.target_principal, + delegates=self.delegates, + target_scopes=self.scopes, + ) + + project_id = _get_project_id_from_service_account_email( + self.target_principal + ) + + return credentials, project_id + + def _get_credentials_using_keyfile_dict(self): + self._log_debug("Getting connection using JSON Dict") + # Depending on how the JSON was formatted, it may contain + # escaped newlines. Convert those to actual newlines. + self.keyfile_dict["private_key"] = self.keyfile_dict["private_key"].replace( + "\\n", "\n" + ) + credentials = ( + google.oauth2.service_account.Credentials.from_service_account_info( + self.keyfile_dict, scopes=self.scopes + ) + ) + project_id = credentials.project_id + return credentials, project_id + + def _get_credentials_using_key_path(self): + if self.key_path.endswith(".p12"): + raise AirflowException( + "Legacy P12 key file are not supported, use a JSON key file." + ) + + if not self.key_path.endswith(".json"): + raise AirflowException("Unrecognised extension for key file.") + + self._log_debug("Getting connection using JSON key file %s", self.key_path) + credentials = ( + google.oauth2.service_account.Credentials.from_service_account_file( + self.key_path, scopes=self.scopes + ) + ) + project_id = credentials.project_id + return credentials, project_id + + def _get_credentials_using_adc(self): + self._log_info( + "Getting connection using `google.auth.default()` since no key file is defined for hook." 
+ ) + credentials, project_id = google.auth.default(scopes=self.scopes) + return credentials, project_id + + def _log_info(self, *args, **kwargs) -> None: + if not self.disable_logging: + self.log.info(*args, **kwargs) + + def _log_debug(self, *args, **kwargs) -> None: + if not self.disable_logging: + self.log.debug(*args, **kwargs) + + +def get_credentials_and_project_id( + *args, **kwargs +) -> Tuple[google.auth.credentials.Credentials, str]: + """Returns the Credentials object for Google API and the associated project_id.""" + return _CredentialProvider(*args, **kwargs).get_credentials_and_project() + + +def _get_scopes(scopes: Optional[str] = None) -> Sequence[str]: + """ + Parse a comma-separated string containing OAuth2 scopes if `scopes` is provided. + Otherwise, default scope will be returned. + + :param scopes: A comma-separated string containing OAuth2 scopes + :type scopes: Optional[str] + :return: Returns the scope defined in the connection configuration, or the default scope + :rtype: Sequence[str] + """ + return [s.strip() for s in scopes.split(",")] if scopes else _DEFAULT_SCOPES + + +def _get_target_principal_and_delegates( + impersonation_chain: Optional[Union[str, Sequence[str]]] = None +) -> Tuple[Optional[str], Optional[Sequence[str]]]: + """ + Analyze contents of impersonation_chain and return target_principal (the service account + to directly impersonate using short-term credentials, if any) and optional list of delegates + required to get the access_token of target_principal. + + :param impersonation_chain: the service account to impersonate or a chained list leading to this + account + :type impersonation_chain: Optional[Union[str, Sequence[str]]] + + :return: Returns the tuple of target_principal and delegates + :rtype: Tuple[Optional[str], Optional[Sequence[str]]] + """ + if not impersonation_chain: + return None, None + + if isinstance(impersonation_chain, str): + return impersonation_chain, None + + return impersonation_chain[-1], impersonation_chain[:-1] + + +def _get_project_id_from_service_account_email(service_account_email: str) -> str: + """ + Extracts project_id from service account's email address. + + :param service_account_email: email of the service account. + :type service_account_email: str + + :return: Returns the project_id of the provided service account. + :rtype: str + """ + try: + return service_account_email.split("@")[1].split(".")[0] + except IndexError: + raise AirflowException( + f"Could not extract project_id from service account's email: " + f"{service_account_email}." + ) diff --git a/reference/providers/google/cloud/utils/field_sanitizer.py b/reference/providers/google/cloud/utils/field_sanitizer.py new file mode 100644 index 0000000..74892fd --- /dev/null +++ b/reference/providers/google/cloud/utils/field_sanitizer.py @@ -0,0 +1,174 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Sanitizer for body fields sent via Google Cloud API. + +The sanitizer removes fields specified from the body. + +Context +------- +In some cases where Google Cloud operation requires modification of existing resources (such +as instances or instance templates) we need to sanitize body of the resources returned +via Google Cloud APIs. This is in the case when we retrieve information from Google Cloud first, +modify the body and either update the existing resource or create a new one with the +modified body. Usually when you retrieve resource from Google Cloud you get some extra fields which +are Output-only, and we need to delete those fields if we want to use +the body as input for subsequent create/insert type operation. + + +Field specification +------------------- + +Specification of fields is an array of strings which denote names of fields to be removed. +The field can be either direct field name to remove from the body or the full +specification of the path you should delete - separated with '.' + + +>>> FIELDS_TO_SANITIZE = [ +>>> "kind", +>>> "properties.disks.kind", +>>> "properties.metadata.kind", +>>>] +>>> body = { +>>> "kind": "compute#instanceTemplate", +>>> "name": "instance", +>>> "properties": { +>>> "disks": [ +>>> { +>>> "name": "a", +>>> "kind": "compute#attachedDisk", +>>> "type": "PERSISTENT", +>>> "mode": "READ_WRITE", +>>> }, +>>> { +>>> "name": "b", +>>> "kind": "compute#attachedDisk", +>>> "type": "PERSISTENT", +>>> "mode": "READ_WRITE", +>>> } +>>> ], +>>> "metadata": { +>>> "kind": "compute#metadata", +>>> "fingerprint": "GDPUYxlwHe4=" +>>> }, +>>> } +>>> } +>>> sanitizer=GcpBodyFieldSanitizer(FIELDS_TO_SANITIZE) +>>> sanitizer.sanitize(body) +>>> json.dumps(body, indent=2) +{ + "name": "instance", + "properties": { + "disks": [ + { + "name": "a", + "type": "PERSISTENT", + "mode": "READ_WRITE", + }, + { + "name": "b", + "type": "PERSISTENT", + "mode": "READ_WRITE", + } + ], + "metadata": { + "fingerprint": "GDPUYxlwHe4=" + }, + } +} + +Note that the components of the path can be either dictionaries or arrays of dictionaries. +In case they are dictionaries, subsequent component names key of the field, in case of +arrays - the sanitizer iterates through all dictionaries in the array and searches +components in all elements of the array. +""" + +from typing import List + +from airflow.exceptions import AirflowException +from airflow.utils.log.logging_mixin import LoggingMixin + + +class GcpFieldSanitizerException(AirflowException): + """Thrown when sanitizer finds unexpected field type in the path + (other than dict or array). + """ + + +class GcpBodyFieldSanitizer(LoggingMixin): + """Sanitizes the body according to specification. 
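+
+    A minimal usage sketch (``body`` stands for any resource dict retrieved from the
+    API, as in the module-level example above)::
+
+        sanitizer = GcpBodyFieldSanitizer(["kind", "properties.disks.kind"])
+        sanitizer.sanitize(body)  # the listed fields are removed from ``body`` in place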
+ + :param sanitize_specs: array of strings that specifies which fields to remove + :type sanitize_specs: list[str] + + """ + + def __init__(self, sanitize_specs: List[str]) -> None: + super().__init__() + self._sanitize_specs = sanitize_specs + + def _sanitize(self, dictionary, remaining_field_spec, current_path): + field_split = remaining_field_spec.split(".", 1) + if len(field_split) == 1: # pylint: disable=too-many-nested-blocks + field_name = field_split[0] + if field_name in dictionary: + self.log.info("Deleted %s [%s]", field_name, current_path) + del dictionary[field_name] + else: + self.log.debug( + "The field %s is missing in %s at the path %s.", + field_name, + dictionary, + current_path, + ) + else: + field_name = field_split[0] + remaining_path = field_split[1] + child = dictionary.get(field_name) + if child is None: + self.log.debug( + "The field %s is missing in %s at the path %s. ", + field_name, + dictionary, + current_path, + ) + elif isinstance(child, dict): + self._sanitize(child, remaining_path, f"{current_path}.{field_name}") + elif isinstance(child, list): + for index, elem in enumerate(child): + if not isinstance(elem, dict): + self.log.warning( + "The field %s element at index %s is of wrong type. " + "It should be dict and is %s. Skipping it.", + current_path, + index, + elem, + ) + self._sanitize( + elem, remaining_path, f"{current_path}.{field_name}[{index}]" + ) + else: + self.log.warning( + "The field %s is of wrong type. It should be dict or list and it is %s. Skipping it.", + current_path, + child, + ) + + def sanitize(self, body) -> None: + """Sanitizes the body according to specification.""" + for elem in self._sanitize_specs: + self._sanitize(body, elem, "") diff --git a/reference/providers/google/cloud/utils/field_validator.py b/reference/providers/google/cloud/utils/field_validator.py new file mode 100644 index 0000000..6ffda08 --- /dev/null +++ b/reference/providers/google/cloud/utils/field_validator.py @@ -0,0 +1,506 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Validator for body fields sent via Google Cloud API. + +The validator performs validation of the body (being dictionary of fields) that +is sent in the API request to Google Cloud (via ``googleclient`` API usually). + +Context +------- +The specification mostly focuses on helping Airflow DAG developers in the development +phase. You can build your own Google Cloud operator (such as GcfDeployOperator for example) which +can have built-in validation specification for the particular API. It's super helpful +when developer plays with different fields and their values at the initial phase of +DAG development. 
Most of the Google Cloud APIs perform their own validation on the
+server side, but most of the requests are asynchronous and you need to wait for the result
+of the operation. This takes precious time and slows
+down iteration over the API. BodyFieldValidator is meant to be used on the client side
+and it should therefore provide instant feedback to the developer on misspelled or
+wrongly typed parameters.
+
+The validation should be performed in the "execute()" method call in order to allow
+template parameters to be expanded before validation is performed.
+
+Types of fields
+---------------
+
+Specification is an array of dictionaries - each dictionary describes a field, its type,
+validation, optionality, supported api_version and nested fields (for unions and dicts).
+
+Typically (for clarity and in order to aid syntax highlighting) the array of
+dicts should be defined as a series of dict() executions. A fragment of an example
+specification might look as follows::
+
+    SPECIFICATION = [
+        dict(name="an_union", type="union", optional=True, fields=[
+            dict(name="variant_1", type="dict"),
+            dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'),
+        ]),
+        dict(name="an_union", type="dict", fields=[
+            dict(name="field_1", type="dict"),
+            dict(name="field_2", regexp=r'^.+$'),
+        ]),
+        ...
+    ]
+
+
+Each field should have key = "name" indicating the field name. The field can be of one of the
+following types:
+
+* Dict fields: (key = "type", value="dict"):
+  A field of this type should contain nested fields in form of an array of dicts.
+  Each of the fields in the array is then expected (unless marked as optional)
+  and validated recursively. If an extra field is present in the dictionary, a warning is
+  printed in the log file (but the validation succeeds - see the Forward-compatibility notes).
+* List fields: (key = "type", value="list"):
+  A field of this type should be a list. Only the type correctness is validated.
+  The contents of a list are not subject to validation.
+* Union fields (key = "type", value="union"): a field of this type should contain nested
+  fields in form of an array of dicts. One of the fields (and only one) should be
+  present (unless the union is marked as optional). If more than one union field is
+  present, FieldValidationException is raised. If none of the union fields is
+  present - a warning is printed in the log (see the Forward-compatibility notes below).
+* Fields validated for non-emptiness: (key = "allow_empty") - this applies only to
+  fields the value of which is a string, and it allows checking for non-emptiness of
+  the field (allow_empty=False).
+* Regexp-validated fields: (key = "regexp") - fields of this type are assumed to be
+  strings and they are validated with the regexp specified. Remember that the regexps
+  should ideally contain ^ at the beginning and $ at the end to make sure that
+  the whole field content is validated. Typically such regexp
+  validations should be used carefully and sparingly (see the Forward-compatibility
+  notes below).
+* Custom-validated fields: (key = "custom_validation") - fields of this type are validated
+  using the method specified via the custom_validation field. Any exception thrown in the custom
+  validation will be turned into FieldValidationException and will cause validation to
+  fail. Such custom validations might be used to check numeric fields (including
+  ranges of values), booleans or any other types of fields.
+* API version: (key="api_version") if an API version is specified, then the field will only
+  be validated when the api_version used at field validator initialization exactly matches the
+  version specified. If you want to declare fields that are available in several
+  versions of the APIs, you should specify the field once per API version that
+  should be supported (each time with a different api_version).
+* If none of the keys ("type", "regexp", "custom_validation") is present, the field is not
+  validated.
+
+You can see some of the field examples in EXAMPLE_VALIDATION_SPECIFICATION.
+
+
+Forward-compatibility notes
+---------------------------
+Certain decisions are crucial to allow the client APIs to work also with future API
+versions. Since the attached body is passed through to the API call, it is entirely
+possible to pass any new fields in the body (for future API versions) -
+albeit without validation on the client side - they can and usually will still be validated
+on the server side.
+
+Here are the guidelines that you should follow to make validation forward-compatible:
+
+* Most of the fields are not validated for their content. It's possible to use regexp
+  in some specific cases that are guaranteed not to change in the future, but for most
+  fields regexp validation should be r'^.+$', indicating a check for non-emptiness.
+* api_version is not validated - the user can pass any future version of the api here. The API
+  version is only used to filter parameters that are marked as present in this api version.
+  Any new fields in the body (not present in the specification) are allowed (not verified).
+  For dictionaries, new fields can be added to dictionaries by future calls. However, if an
+  unknown field in a dictionary is added, a warning is logged by the client (but validation
+  remains successful). This is a very nice feature to protect against typos in names.
+* For unions, newly added union variants can be added by future calls and they will
+  pass validation, however the content or presence of those fields will not be validated.
+  This means that it is possible to send a new non-validated union field together with an
+  old validated field and this problem will not be detected by the client. In such a case
+  a warning will be printed.
+* When you add a validator to an operator, you should also add a ``validate_body`` parameter
+  (default = True) to the __init__ of such operators - when it is set to False,
+  no validation should be performed. This is a safeguard for totally unpredicted and
+  backwards-incompatible changes that might sometimes occur in the APIs.
+
+"""
+
+import re
+from typing import Callable, Dict, Sequence
+
+from airflow.exceptions import AirflowException
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+COMPOSITE_FIELD_TYPES = ["union", "dict", "list"]
+
+
+class GcpFieldValidationException(AirflowException):
+    """Thrown when validation finds a dictionary field not valid according to the specification."""
+
+
+class GcpValidationSpecificationException(AirflowException):
+    """Thrown when the validation specification is wrong.
+
+    This should only happen during development as ideally the
+    specification itself should not be invalid ;) .
+ """ + + +def _int_greater_than_zero(value): + if int(value) <= 0: + raise GcpFieldValidationException( + "The available memory has to be greater than 0" + ) + + +EXAMPLE_VALIDATION_SPECIFICATION = [ + dict(name="name", allow_empty=False), + dict(name="description", allow_empty=False, optional=True), + dict( + name="availableMemoryMb", + custom_validation=_int_greater_than_zero, + optional=True, + ), + dict(name="labels", optional=True, type="dict"), + dict( + name="an_union", + type="union", + fields=[ + dict(name="variant_1", regexp=r"^.+$"), + dict(name="variant_2", regexp=r"^.+$", api_version="v1beta2"), + dict( + name="variant_3", type="dict", fields=[dict(name="url", regexp=r"^.+$")] + ), + dict(name="variant_4"), + ], + ), +] + + +class GcpBodyFieldValidator(LoggingMixin): + """Validates correctness of request body according to specification. + + The specification can describe various type of + fields including custom validation, and union of fields. This validator is + to be reusable by various operators. See the EXAMPLE_VALIDATION_SPECIFICATION + for some examples and explanations of how to create specification. + + :param validation_specs: dictionary describing validation specification + :type validation_specs: list[dict] + :param api_version: Version of the api used (for example v1) + :type api_version: str + + """ + + def __init__(self, validation_specs: Sequence[Dict], api_version: str) -> None: + super().__init__() + self._validation_specs = validation_specs + self._api_version = api_version + + @staticmethod + def _get_field_name_with_parent(field_name, parent): + if parent: + return parent + "." + field_name + return field_name + + @staticmethod + def _sanity_checks( + children_validation_specs: Dict, + field_type: str, + full_field_path: str, + regexp: str, + allow_empty: bool, + custom_validation: Callable, + value, + ) -> None: + if value is None and field_type != "union": + raise GcpFieldValidationException( + f"The required body field '{full_field_path}' is missing. Please add it." + ) + if regexp and field_type: + raise GcpValidationSpecificationException( + "The validation specification entry '{}' has both type and regexp. " + "The regexp is only allowed without type (i.e. assume type is 'str' " + "that can be validated with regexp)".format(full_field_path) + ) + if allow_empty is not None and field_type: + raise GcpValidationSpecificationException( + "The validation specification entry '{}' has both type and allow_empty. " + "The allow_empty is only allowed without type (i.e. assume type is 'str' " + "that can be validated with allow_empty)".format(full_field_path) + ) + if children_validation_specs and field_type not in COMPOSITE_FIELD_TYPES: + raise GcpValidationSpecificationException( + "Nested fields are specified in field '{}' of type '{}'. " + "Nested fields are only allowed for fields of those types: ('{}').".format( + full_field_path, field_type, COMPOSITE_FIELD_TYPES + ) + ) + if custom_validation and field_type: + raise GcpValidationSpecificationException( + "The validation specification field '{}' has both type and " + "custom_validation. 
Custom validation is only allowed without type.".format( + full_field_path + ) + ) + + @staticmethod + def _validate_regexp(full_field_path: str, regexp: str, value: str) -> None: + if not re.match(regexp, value): + # Note matching of only the beginning as we assume the regexps all-or-nothing + raise GcpFieldValidationException( + "The body field '{}' of value '{}' does not match the field " + "specification regexp: '{}'.".format(full_field_path, value, regexp) + ) + + @staticmethod + def _validate_is_empty(full_field_path: str, value: str) -> None: + if not value: + raise GcpFieldValidationException( + f"The body field '{full_field_path}' can't be empty. Please provide a value." + ) + + def _validate_dict( + self, children_validation_specs: Dict, full_field_path: str, value: Dict + ) -> None: + for child_validation_spec in children_validation_specs: + self._validate_field( + validation_spec=child_validation_spec, + dictionary_to_validate=value, + parent=full_field_path, + ) + all_dict_keys = [spec["name"] for spec in children_validation_specs] + for field_name in value.keys(): + if field_name not in all_dict_keys: + self.log.warning( + "The field '%s' is in the body, but is not specified in the " + "validation specification '%s'. " + "This might be because you are using newer API version and " + "new field names defined for that version. Then the warning " + "can be safely ignored, or you might want to upgrade the operator" + "to the version that supports the new API version.", + self._get_field_name_with_parent(field_name, full_field_path), + children_validation_specs, + ) + + def _validate_union( + self, + children_validation_specs: Dict, + full_field_path: str, + dictionary_to_validate: Dict, + ) -> None: + field_found = False + found_field_name = None + for child_validation_spec in children_validation_specs: + # Forcing optional so that we do not have to type optional = True + # in specification for all union fields + new_field_found = self._validate_field( + validation_spec=child_validation_spec, + dictionary_to_validate=dictionary_to_validate, + parent=full_field_path, + force_optional=True, + ) + field_name = child_validation_spec["name"] + if new_field_found and field_found: + raise GcpFieldValidationException( + "The mutually exclusive fields '{}' and '{}' belonging to the " + "union '{}' are both present. Please remove one".format( + field_name, found_field_name, full_field_path + ) + ) + if new_field_found: + field_found = True + found_field_name = field_name + if not field_found: + self.log.warning( + "There is no '%s' union defined in the body %s. " + "Validation expected one of '%s' but could not find any. It's possible " + "that you are using newer API version and there is another union variant " + "defined for that version. Then the warning can be safely ignored, " + "or you might want to upgrade the operator to the version that " + "supports the new API version.", + full_field_path, + dictionary_to_validate, + [field["name"] for field in children_validation_specs], + ) + + def _validate_field( + self, validation_spec, dictionary_to_validate, parent=None, force_optional=False + ): + """ + Validates if field is OK. 
+ + :param validation_spec: specification of the field + :type validation_spec: dict + :param dictionary_to_validate: dictionary where the field should be present + :type dictionary_to_validate: dict + :param parent: full path of parent field + :type parent: str + :param force_optional: forces the field to be optional + (all union fields have force_optional set to True) + :type force_optional: bool + :return: True if the field is present + """ + field_name = validation_spec["name"] + field_type = validation_spec.get("type") + optional = validation_spec.get("optional") + regexp = validation_spec.get("regexp") + allow_empty = validation_spec.get("allow_empty") + children_validation_specs = validation_spec.get("fields") + required_api_version = validation_spec.get("api_version") + custom_validation = validation_spec.get("custom_validation") + + full_field_path = self._get_field_name_with_parent( + field_name=field_name, parent=parent + ) + if required_api_version and required_api_version != self._api_version: + self.log.debug( + "Skipping validation of the field '%s' for API version '%s' " + "as it is only valid for API version '%s'", + field_name, + self._api_version, + required_api_version, + ) + return False + value = dictionary_to_validate.get(field_name) + + if (optional or force_optional) and value is None: + self.log.debug( + "The optional field '%s' is missing. That's perfectly OK.", + full_field_path, + ) + return False + + # Certainly down from here the field is present (value is not None) + # so we should only return True from now on + + self._sanity_checks( + children_validation_specs=children_validation_specs, + field_type=field_type, + full_field_path=full_field_path, + regexp=regexp, + allow_empty=allow_empty, + custom_validation=custom_validation, + value=value, + ) + + if allow_empty is False: + self._validate_is_empty(full_field_path, value) + if regexp: + self._validate_regexp(full_field_path, regexp, value) + elif field_type == "dict": + if not isinstance(value, dict): + raise GcpFieldValidationException( + "The field '{}' should be of dictionary type according to the " + "specification '{}' but it is '{}'".format( + full_field_path, validation_spec, value + ) + ) + if children_validation_specs is None: + self.log.debug( + "The dict field '%s' has no nested fields defined in the " + "specification '%s'. That's perfectly ok - it's content will " + "not be validated.", + full_field_path, + validation_spec, + ) + else: + self._validate_dict(children_validation_specs, full_field_path, value) + elif field_type == "union": + if not children_validation_specs: + raise GcpValidationSpecificationException( + "The union field '{}' has no nested fields " + "defined in specification '{}'. Unions should have at least one " + "nested field defined.".format(full_field_path, validation_spec) + ) + self._validate_union( + children_validation_specs, full_field_path, dictionary_to_validate + ) + elif field_type == "list": + if not isinstance(value, list): + raise GcpFieldValidationException( + "The field '{}' should be of list type according to the " + "specification '{}' but it is '{}'".format( + full_field_path, validation_spec, value + ) + ) + elif custom_validation: + try: + custom_validation(value) + except Exception as e: + raise GcpFieldValidationException( + "Error while validating custom field '{}' specified by '{}': '{}'".format( + full_field_path, validation_spec, e + ) + ) + elif field_type is None: + self.log.debug( + "The type of field '%s' is not specified in '%s'. 
Not validating its content.", + full_field_path, + validation_spec, + ) + else: + raise GcpValidationSpecificationException( + "The field '{}' is of type '{}' in specification '{}'." + "This type is unknown to validation!".format( + full_field_path, field_type, validation_spec + ) + ) + return True + + def validate(self, body_to_validate: dict) -> None: + """ + Validates if the body (dictionary) follows specification that the validator was + instantiated with. Raises ValidationSpecificationException or + ValidationFieldException in case of problems with specification or the + body not conforming to the specification respectively. + + :param body_to_validate: body that must follow the specification + :type body_to_validate: dict + :return: None + """ + try: + for validation_spec in self._validation_specs: + self._validate_field( + validation_spec=validation_spec, + dictionary_to_validate=body_to_validate, + ) + except GcpFieldValidationException as e: + raise GcpFieldValidationException( + f"There was an error when validating: body '{body_to_validate}': '{e}'" + ) + all_field_names = [ + spec["name"] + for spec in self._validation_specs + if spec.get("type") != "union" + and spec.get("api_version") != self._api_version + ] + all_union_fields = [ + spec for spec in self._validation_specs if spec.get("type") == "union" + ] + for union_field in all_union_fields: + all_field_names.extend( + [ + nested_union_spec["name"] + for nested_union_spec in union_field["fields"] + if nested_union_spec.get("type") != "union" + and nested_union_spec.get("api_version") != self._api_version + ] + ) + for field_name in body_to_validate.keys(): + if field_name not in all_field_names: + self.log.warning( + "The field '%s' is in the body, but is not specified in the " + "validation specification '%s'. " + "This might be because you are using newer API version and " + "new field names defined for that version. Then the warning " + "can be safely ignored, or you might want to upgrade the operator" + "to the version that supports the new API version.", + field_name, + self._validation_specs, + ) diff --git a/reference/providers/google/cloud/utils/mlengine_operator_utils.py b/reference/providers/google/cloud/utils/mlengine_operator_utils.py new file mode 100644 index 0000000..fdf2919 --- /dev/null +++ b/reference/providers/google/cloud/utils/mlengine_operator_utils.py @@ -0,0 +1,282 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
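+#
+# Illustrative sketch only: the DAG object, bucket paths and helper functions below
+# (``my_dag``, ``get_metric_fn_and_keys``, ``validate_err_and_count``) are placeholders
+# mirroring the examples in the ``create_evaluate_ops`` docstring, not symbols defined
+# in this module:
+#
+#     pred, summary, validate = create_evaluate_ops(
+#         task_prefix="eval",
+#         data_format="TEXT",
+#         input_paths=["gs://my-bucket/eval-input/*"],
+#         prediction_path="gs://my-bucket/eval-output",
+#         metric_fn_and_keys=get_metric_fn_and_keys(),
+#         validate_fn=validate_err_and_count,
+#         model_name="my_model",
+#         version_name="v1",
+#         dag=my_dag,
+#     )
+#
+# The three returned operators are already chained (prediction -> summary -> validation),
+# so callers only attach their own upstream and downstream tasks.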
+ +# +"""This module contains helper functions for MLEngine operators.""" + +import base64 +import json +import os +import re +from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar +from urllib.parse import urlsplit + +import dill +from airflow import DAG +from airflow.exceptions import AirflowException +from airflow.operators.python import PythonOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.cloud.operators.dataflow import ( + DataflowCreatePythonJobOperator, +) +from airflow.providers.google.cloud.operators.mlengine import ( + MLEngineStartBatchPredictionJobOperator, +) + +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def create_evaluate_ops( # pylint: disable=too-many-arguments + task_prefix: str, + data_format: str, + input_paths: List[str], + prediction_path: str, + metric_fn_and_keys: Tuple[T, Iterable[str]], + validate_fn: T, + batch_prediction_job_id: Optional[str] = None, + region: Optional[str] = None, + project_id: Optional[str] = None, + dataflow_options: Optional[Dict] = None, + model_uri: Optional[str] = None, + model_name: Optional[str] = None, + version_name: Optional[str] = None, + dag: Optional[DAG] = None, + py_interpreter="python3", +): + """ + Creates Operators needed for model evaluation and returns. + + It gets prediction over inputs via Cloud ML Engine BatchPrediction API by + calling MLEngineBatchPredictionOperator, then summarize and validate + the result via Cloud Dataflow using DataFlowPythonOperator. + + For details and pricing about Batch prediction, please refer to the website + https://cloud.google.com/ml-engine/docs/how-tos/batch-predict + and for Cloud Dataflow, https://cloud.google.com/dataflow/docs/ + + It returns three chained operators for prediction, summary, and validation, + named as ``-prediction``, ``-summary``, and ``-validation``, + respectively. + (```` should contain only alphanumeric characters or hyphen.) + + The upstream and downstream can be set accordingly like: + + .. code-block:: python + + pred, _, val = create_evaluate_ops(...) + pred.set_upstream(upstream_op) + ... + downstream_op.set_upstream(val) + + Callers will provide two python callables, metric_fn and validate_fn, in + order to customize the evaluation behavior as they wish. + + - metric_fn receives a dictionary per instance derived from json in the + batch prediction result. The keys might vary depending on the model. + It should return a tuple of metrics. + - validation_fn receives a dictionary of the averaged metrics that metric_fn + generated over all instances. + The key/value of the dictionary matches to what's given by + metric_fn_and_keys arg. + The dictionary contains an additional metric, 'count' to represent the + total number of instances received for evaluation. + The function would raise an exception to mark the task as failed, in a + case the validation result is not okay to proceed (i.e. to set the trained + version as default). + + Typical examples are like this: + + .. code-block:: python + + def get_metric_fn_and_keys(): + import math # imports should be outside of the metric_fn below. + def error_and_squared_error(inst): + label = float(inst['input_label']) + classes = float(inst['classes']) # 0 or 1 + err = abs(classes-label) + squared_err = math.pow(classes-label, 2) + return (err, squared_err) # returns a tuple. + return error_and_squared_error, ['err', 'mse'] # key order must match. 
+ + def validate_err_and_count(summary): + if summary['err'] > 0.2: + raise ValueError('Too high err>0.2; summary=%s' % summary) + if summary['mse'] > 0.05: + raise ValueError('Too high mse>0.05; summary=%s' % summary) + if summary['count'] < 1000: + raise ValueError('Too few instances<1000; summary=%s' % summary) + return summary + + For the details on the other BatchPrediction-related arguments (project_id, + job_id, region, data_format, input_paths, prediction_path, model_uri), + please refer to MLEngineBatchPredictionOperator too. + + :param task_prefix: a prefix for the tasks. Only alphanumeric characters and + hyphen are allowed (no underscores), since this will be used as dataflow + job name, which doesn't allow other characters. + :type task_prefix: str + + :param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP' + :type data_format: str + + :param input_paths: a list of input paths to be sent to BatchPrediction. + :type input_paths: list[str] + + :param prediction_path: GCS path to put the prediction results in. + :type prediction_path: str + + :param metric_fn_and_keys: a tuple of metric_fn and metric_keys: + + - metric_fn is a function that accepts a dictionary (for an instance), + and returns a tuple of metric(s) that it calculates. + + - metric_keys is a list of strings to denote the key of each metric. + :type metric_fn_and_keys: tuple of a function and a list[str] + + :param validate_fn: a function to validate whether the averaged metric(s) is + good enough to push the model. + :type validate_fn: function + + :param batch_prediction_job_id: the id to use for the Cloud ML Batch + prediction job. Passed directly to the MLEngineBatchPredictionOperator as + the job_id argument. + :type batch_prediction_job_id: str + + :param project_id: the Google Cloud project id in which to execute + Cloud ML Batch Prediction and Dataflow jobs. If None, then the `dag`'s + `default_args['project_id']` will be used. + :type project_id: str + + :param region: the Google Cloud region in which to execute Cloud ML + Batch Prediction and Dataflow jobs. If None, then the `dag`'s + `default_args['region']` will be used. + :type region: str + + :param dataflow_options: options to run Dataflow jobs. If None, then the + `dag`'s `default_args['dataflow_default_options']` will be used. + :type dataflow_options: dictionary + + :param model_uri: GCS path of the model exported by Tensorflow using + ``tensorflow.estimator.export_savedmodel()``. It cannot be used with + model_name or version_name below. See MLEngineBatchPredictionOperator for + more detail. + :type model_uri: str + + :param model_name: Used to indicate a model to use for prediction. Can be + used in combination with version_name, but cannot be used together with + model_uri. See MLEngineBatchPredictionOperator for more detail. If None, + then the `dag`'s `default_args['model_name']` will be used. + :type model_name: str + + :param version_name: Used to indicate a model version to use for prediction, + in combination with model_name. Cannot be used together with model_uri. + See MLEngineBatchPredictionOperator for more detail. If None, then the + `dag`'s `default_args['version_name']` will be used. + :type version_name: str + + :param dag: The `DAG` to use for all Operators. + :type dag: airflow.models.DAG + + :param py_interpreter: Python version of the beam pipeline. + If None, this defaults to the python3. 
+ To track python versions supported by beam and related + issues check: https://issues.apache.org/jira/browse/BEAM-1251 + :type py_interpreter: str + + :returns: a tuple of three operators, (prediction, summary, validation) + :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator, + PythonOperator) + """ + batch_prediction_job_id = batch_prediction_job_id or "" + dataflow_options = dataflow_options or {} + region = region or "" + + # Verify that task_prefix doesn't have any special characters except hyphen + # '-', which is the only allowed non-alphanumeric character by Dataflow. + if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix): + raise AirflowException( + "Malformed task_id for DataFlowPythonOperator (only alphanumeric " + "and hyphens are allowed but got: " + task_prefix + ) + + metric_fn, metric_keys = metric_fn_and_keys + if not callable(metric_fn): + raise AirflowException("`metric_fn` param must be callable.") + if not callable(validate_fn): + raise AirflowException("`validate_fn` param must be callable.") + + if dag is not None and dag.default_args is not None: + default_args = dag.default_args + project_id = project_id or default_args.get("project_id") + region = region or default_args["region"] + model_name = model_name or default_args.get("model_name") + version_name = version_name or default_args.get("version_name") + dataflow_options = dataflow_options or default_args.get( + "dataflow_default_options" + ) + + evaluate_prediction = MLEngineStartBatchPredictionJobOperator( + task_id=(task_prefix + "-prediction"), + project_id=project_id, + job_id=batch_prediction_job_id, + region=region, + data_format=data_format, + input_paths=input_paths, + output_path=prediction_path, + uri=model_uri, + model_name=model_name, + version_name=version_name, + dag=dag, + ) + + metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True)).decode() + evaluate_summary = DataflowCreatePythonJobOperator( + task_id=(task_prefix + "-summary"), + py_file=os.path.join( + os.path.dirname(__file__), "mlengine_prediction_summary.py" + ), + dataflow_default_options=dataflow_options, + options={ + "prediction_path": prediction_path, + "metric_fn_encoded": metric_fn_encoded, + "metric_keys": ",".join(metric_keys), + }, + py_interpreter=py_interpreter, + py_requirements=["apache-beam[gcp]>=2.14.0"], + dag=dag, + ) + evaluate_summary.set_upstream(evaluate_prediction) + + def apply_validate_fn(*args, templates_dict, **kwargs): + prediction_path = templates_dict["prediction_path"] + scheme, bucket, obj, _, _ = urlsplit(prediction_path) + if scheme != "gs" or not bucket or not obj: + raise ValueError(f"Wrong format prediction_path: {prediction_path}") + summary = os.path.join(obj.strip("/"), "prediction.summary.json") + gcs_hook = GCSHook() + summary = json.loads(gcs_hook.download(bucket, summary)) + return validate_fn(summary) + + evaluate_validation = PythonOperator( + task_id=(task_prefix + "-validation"), + python_callable=apply_validate_fn, + templates_dict={"prediction_path": prediction_path}, + dag=dag, + ) + evaluate_validation.set_upstream(evaluate_summary) + + return evaluate_prediction, evaluate_summary, evaluate_validation diff --git a/reference/providers/google/cloud/utils/mlengine_prediction_summary.py b/reference/providers/google/cloud/utils/mlengine_prediction_summary.py new file mode 100644 index 0000000..946d17e --- /dev/null +++ b/reference/providers/google/cloud/utils/mlengine_prediction_summary.py @@ -0,0 +1,216 @@ +# flake8: noqa: F841 +# +# Licensed to the Apache 
Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +A template called by DataFlowPythonOperator to summarize BatchPrediction. + +It accepts a user function to calculate the metric(s) per instance in +the prediction results, then aggregates to output as a summary. + +It accepts the following arguments: + +- ``--prediction_path``: + The GCS folder that contains BatchPrediction results, containing + ``prediction.results-NNNNN-of-NNNNN`` files in the json format. + Output will be also stored in this folder, as 'prediction.summary.json'. +- ``--metric_fn_encoded``: + An encoded function that calculates and returns a tuple of metric(s) + for a given instance (as a dictionary). It should be encoded + via ``base64.b64encode(dill.dumps(fn, recurse=True))``. +- ``--metric_keys``: + A comma-separated key(s) of the aggregated metric(s) in the summary + output. The order and the size of the keys must match to the output + of metric_fn. + The summary will have an additional key, 'count', to represent the + total number of instances, so the keys shouldn't include 'count'. + + +Usage example: + +.. code-block: python + + from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator + + + def get_metric_fn(): + import math # all imports must be outside of the function to be passed. + def metric_fn(inst): + label = float(inst["input_label"]) + classes = float(inst["classes"]) + prediction = float(inst["scores"][1]) + log_loss = math.log(1 + math.exp( + -(label * 2 - 1) * math.log(prediction / (1 - prediction)))) + squared_err = (classes-label)**2 + return (log_loss, squared_err) + return metric_fn + metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True)) + DataflowCreatePythonJobOperator( + task_id="summary-prediction", + py_options=["-m"], + py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary", + options={ + "prediction_path": prediction_path, + "metric_fn_encoded": metric_fn_encoded, + "metric_keys": "log_loss,mse" + }, + dataflow_default_options={ + "project": "xxx", "region": "us-east1", + "staging_location": "gs://yy", "temp_location": "gs://zz", + } + ) >> dag + +When the input file is like the following:: + + {"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]} + {"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]} + {"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]} + {"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]} + +The output file will be:: + + {"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25} + +To test outside of the dag: + +.. 
code-block:: python + + subprocess.check_call(["python", + "-m", + "airflow.providers.google.cloud.utils.mlengine_prediction_summary", + "--prediction_path=gs://...", + "--metric_fn_encoded=" + metric_fn_encoded, + "--metric_keys=log_loss,mse", + "--runner=DataflowRunner", + "--staging_location=gs://...", + "--temp_location=gs://...", + ]) +""" + +import argparse +import base64 +import json +import logging +import os + +import apache_beam as beam +import dill # pylint: disable=wrong-import-order + + +class JsonCoder: + """JSON encoder/decoder.""" + + @staticmethod + def encode(x): + """JSON encoder.""" + return json.dumps(x).encode() + + @staticmethod + def decode(x): + """JSON decoder.""" + return json.loads(x) + + +@beam.ptransform_fn +def MakeSummary(pcoll, metric_fn, metric_keys): # pylint: disable=invalid-name + """Summary PTransform used in Dataflow.""" + return ( + pcoll + | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn) + | "PairWith1" >> beam.Map(lambda tup: tup + (1,)) + | "SumTuple" + >> beam.CombineGlobally( + beam.combiners.TupleCombineFn(*([sum] * (len(metric_keys) + 1))) + ) + | "AverageAndMakeDict" + >> beam.Map( + lambda tup: dict( + [(name, tup[i] / tup[-1]) for i, name in enumerate(metric_keys)] + + [("count", tup[-1])] + ) + ) + ) + + +def run(argv=None): + """Helper for obtaining prediction summary.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--prediction_path", + required=True, + help=( + "The GCS folder that contains BatchPrediction results, containing " + "prediction.results-NNNNN-of-NNNNN files in the json format. " + "Output will be also stored in this folder, as a file" + "'prediction.summary.json'." + ), + ) + parser.add_argument( + "--metric_fn_encoded", + required=True, + help=( + "An encoded function that calculates and returns a tuple of " + "metric(s) for a given instance (as a dictionary). It should be " + "encoded via base64.b64encode(dill.dumps(fn, recurse=True))." + ), + ) + parser.add_argument( + "--metric_keys", + required=True, + help=( + "A comma-separated keys of the aggregated metric(s) in the summary " + "output. The order and the size of the keys must match to the " + "output of metric_fn. The summary will have an additional key, " + "'count', to represent the total number of instances, so this flag " + "shouldn't include 'count'." + ), + ) + known_args, pipeline_args = parser.parse_known_args(argv) + + metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded)) + if not callable(metric_fn): + raise ValueError("--metric_fn_encoded must be an encoded callable.") + metric_keys = known_args.metric_keys.split(",") + + with beam.Pipeline(options=beam.pipeline.PipelineOptions(pipeline_args)) as pipe: + # pylint: disable=no-value-for-parameter + prediction_result_pattern = os.path.join( + known_args.prediction_path, "prediction.results-*-of-*" + ) + prediction_summary_path = os.path.join( + known_args.prediction_path, "prediction.summary.json" + ) + # This is apache-beam ptransform's convention + _ = ( + pipe + | "ReadPredictionResult" + >> beam.io.ReadFromText(prediction_result_pattern, coder=JsonCoder()) + | "Summary" >> MakeSummary(metric_fn, metric_keys) + | "Write" + >> beam.io.WriteToText( + prediction_summary_path, + shard_name_template="", # without trailing -NNNNN-of-NNNNN. + coder=JsonCoder(), + ) + ) + + +if __name__ == "__main__": + # Dataflow does not print anything on the screen by default. Good practice says to configure the logger + # to be able to track the progress. 
This code is run in a separate process, so it's safe. + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/reference/providers/google/common/__init__.py b/reference/providers/google/common/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/common/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/common/auth_backend/__init__.py b/reference/providers/google/common/auth_backend/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/common/auth_backend/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/common/auth_backend/google_openid.py b/reference/providers/google/common/auth_backend/google_openid.py new file mode 100644 index 0000000..e78aaaa --- /dev/null +++ b/reference/providers/google/common/auth_backend/google_openid.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Authentication backend that use Google credentials for authorization.""" +import logging +from functools import wraps +from typing import Callable, Optional, TypeVar, cast + +import google +import google.auth.transport.requests +import google.oauth2.id_token +from airflow.configuration import conf +from airflow.providers.google.common.utils.id_token_credentials import ( + get_default_id_token_credentials, +) +from flask import Response, _request_ctx_stack, current_app +from flask import request as flask_request # type: ignore +from google.auth import exceptions +from google.auth.transport.requests import AuthorizedSession +from google.oauth2 import service_account + +log = logging.getLogger(__name__) + +_GOOGLE_ISSUERS = ("accounts.google.com", "https://accounts.google.com") +AUDIENCE = conf.get("api", "google_oauth2_audience") + + +def create_client_session(): + """Create a HTTP authorized client.""" + service_account_path = conf.get("api", "google_key_path") + if service_account_path: + id_token_credentials = ( + service_account.IDTokenCredentials.from_service_account_file( + service_account_path + ) + ) + else: + id_token_credentials = get_default_id_token_credentials( + target_audience=AUDIENCE + ) + return AuthorizedSession(credentials=id_token_credentials) + + +def init_app(_): + """Initializes authentication.""" + + +def _get_id_token_from_request(request) -> Optional[str]: + authorization_header = request.headers.get("Authorization") + + if not authorization_header: + return None + + authorization_header_parts = authorization_header.split(" ", 2) + + if ( + len(authorization_header_parts) != 2 + or authorization_header_parts[0].lower() != "bearer" + ): + return None + + id_token = authorization_header_parts[1] + return id_token + + +def _verify_id_token(id_token: str) -> Optional[str]: + try: + request_adapter = google.auth.transport.requests.Request() + id_info = google.oauth2.id_token.verify_token( + id_token, request_adapter, AUDIENCE + ) + except exceptions.GoogleAuthError: + return None + + # This check is part of google-auth v1.19.0 (2020-07-09), In order not to create strong version + # requirements to too new version, we check it in our code too. + # One day, we may delete this code and set minimum version in requirements. 
+ if id_info.get("iss") not in _GOOGLE_ISSUERS: + return None + + if not id_info.get("email_verified", False): + return None + + return id_info.get("email") + + +def _lookup_user(user_email: str): + security_manager = current_app.appbuilder.sm + user = security_manager.find_user(email=user_email) + + if not user: + return None + + if not user.is_active: + return None + + return user + + +def _set_current_user(user): + ctx = _request_ctx_stack.top + ctx.user = user + + +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def requires_authentication(function: T): + """Decorator for functions that require authentication.""" + + @wraps(function) + def decorated(*args, **kwargs): + access_token = _get_id_token_from_request(flask_request) + if not access_token: + log.debug("Missing ID Token") + return Response("Forbidden", 403) + + userid = _verify_id_token(access_token) + if not userid: + log.debug("Invalid ID Token") + return Response("Forbidden", 403) + + log.debug("Looking for user with e-mail: %s", userid) + + user = _lookup_user(userid) + if not user: + return Response("Forbidden", 403) + + log.debug("Found user: %s", user) + + _set_current_user(user) + + return function(*args, **kwargs) + + return cast(T, decorated) diff --git a/reference/providers/google/common/hooks/__init__.py b/reference/providers/google/common/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/common/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/common/hooks/base_google.py b/reference/providers/google/common/hooks/base_google.py new file mode 100644 index 0000000..847abf8 --- /dev/null +++ b/reference/providers/google/common/hooks/base_google.py @@ -0,0 +1,593 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains a Google Cloud API base hook.""" +import functools +import json +import logging +import os +import tempfile +from contextlib import ExitStack, contextmanager +from subprocess import check_output +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union, cast + +import google.auth +import google.auth.credentials +import google.oauth2.service_account +import google_auth_httplib2 +import tenacity +from airflow import version +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.providers.google.cloud.utils.credentials_provider import ( + _get_scopes, + _get_target_principal_and_delegates, + get_credentials_and_project_id, +) +from airflow.utils.process_utils import patch_environ +from google.api_core.exceptions import Forbidden, ResourceExhausted, TooManyRequests +from google.api_core.gapic_v1.client_info import ClientInfo +from google.auth import _cloud_sdk +from google.auth.environment_vars import CLOUD_SDK_CONFIG_DIR, CREDENTIALS +from googleapiclient import discovery +from googleapiclient.errors import HttpError +from googleapiclient.http import MediaIoBaseDownload, build_http, set_user_agent + +log = logging.getLogger(__name__) + + +# Constants used by the mechanism of repeating requests in reaction to exceeding the temporary quota. +INVALID_KEYS = [ + "DefaultRequestsPerMinutePerProject", + "DefaultRequestsPerMinutePerUser", + "RequestsPerMinutePerProject", + "Resource has been exhausted (e.g. check quota).", +] +INVALID_REASONS = [ + "userRateLimitExceeded", +] + + +def is_soft_quota_exception(exception: Exception): + """ + API for Google services does not have a standardized way to report quota violation errors. + The function has been adapted by trial and error to the following services: + + * Google Translate + * Google Vision + * Google Text-to-Speech + * Google Speech-to-Text + * Google Natural Language + * Google Video Intelligence + """ + if isinstance(exception, Forbidden): + return any( + reason in error.details() + for reason in INVALID_REASONS + for error in exception.errors + ) + + if isinstance(exception, (ResourceExhausted, TooManyRequests)): + return any( + key in error.details() for key in INVALID_KEYS for error in exception.errors + ) + + return False + + +def is_operation_in_progress_exception(exception: Exception) -> bool: + """ + Some of the calls return 429 (too many requests!) or 409 errors (Conflict) + in case of operation in progress. + + * Google Cloud SQL + """ + if isinstance(exception, HttpError): + return exception.resp.status == 429 or exception.resp.status == 409 + return False + + +class retry_if_temporary_quota( + tenacity.retry_if_exception +): # pylint: disable=invalid-name + """Retries if there was an exception for exceeding the temporary quote limit.""" + + def __init__(self): + super().__init__(is_soft_quota_exception) + + +class retry_if_operation_in_progress( + tenacity.retry_if_exception +): # pylint: disable=invalid-name + """Retries if there was an exception for exceeding the temporary quote limit.""" + + def __init__(self): + super().__init__(is_operation_in_progress_exception) + + +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name +RT = TypeVar("RT") # pylint: disable=invalid-name + + +class GoogleBaseHook(BaseHook): + """ + A base hook for Google cloud-related hooks. Google cloud has a shared REST + API client that is built in the same way no matter which service you use. 
+ This class helps construct and authorize the credentials needed to then + call googleapiclient.discovery.build() to actually discover and build a client + for a Google cloud service. + + The class also contains some miscellaneous helper functions. + + All hook derived from this base hook use the 'Google Cloud' connection + type. Three ways of authentication are supported: + + Default credentials: Only the 'Project Id' is required. You'll need to + have set up default credentials, such as by the + ``GOOGLE_APPLICATION_DEFAULT`` environment variable or from the metadata + server on Google Compute Engine. + + JSON key file: Specify 'Project Id', 'Keyfile Path' and 'Scope'. + + Legacy P12 key files are not supported. + + JSON data provided in the UI: Specify 'Keyfile JSON'. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + conn_name_attr = "gcp_conn_id" + default_conn_name = "google_cloud_default" + conn_type = "google_cloud_platform" + hook_name = "Google Cloud" + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import ( + BS3PasswordFieldWidget, + BS3TextFieldWidget, + ) + from flask_babel import lazy_gettext + from wtforms import IntegerField, PasswordField, StringField + from wtforms.validators import NumberRange + + return { + "extra__google_cloud_platform__project": StringField( + lazy_gettext("Project Id"), widget=BS3TextFieldWidget() + ), + "extra__google_cloud_platform__key_path": StringField( + lazy_gettext("Keyfile Path"), widget=BS3TextFieldWidget() + ), + "extra__google_cloud_platform__keyfile_dict": PasswordField( + lazy_gettext("Keyfile JSON"), widget=BS3PasswordFieldWidget() + ), + "extra__google_cloud_platform__scope": StringField( + lazy_gettext("Scopes (comma separated)"), widget=BS3TextFieldWidget() + ), + "extra__google_cloud_platform__num_retries": IntegerField( + lazy_gettext("Number of Retries"), + validators=[NumberRange(min=0)], + widget=BS3TextFieldWidget(), + default=5, + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], + "relabeling": {}, + } + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__() + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.extras = 
self.get_connection(self.gcp_conn_id).extra_dejson # type: Dict + self._cached_credentials: Optional[google.auth.credentials.Credentials] = None + self._cached_project_id: Optional[str] = None + + def _get_credentials_and_project_id( + self, + ) -> Tuple[google.auth.credentials.Credentials, Optional[str]]: + """Returns the Credentials object for Google API and the associated project_id""" + if self._cached_credentials is not None: + return self._cached_credentials, self._cached_project_id + + key_path: Optional[str] = self._get_field("key_path", None) + try: + keyfile_dict: Optional[str] = self._get_field("keyfile_dict", None) + keyfile_dict_json: Optional[Dict[str, str]] = None + if keyfile_dict: + keyfile_dict_json = json.loads(keyfile_dict) + except json.decoder.JSONDecodeError: + raise AirflowException("Invalid key JSON.") + + target_principal, delegates = _get_target_principal_and_delegates( + self.impersonation_chain + ) + + credentials, project_id = get_credentials_and_project_id( + key_path=key_path, + keyfile_dict=keyfile_dict_json, + scopes=self.scopes, + delegate_to=self.delegate_to, + target_principal=target_principal, + delegates=delegates, + ) + + overridden_project_id = self._get_field("project") + if overridden_project_id: + project_id = overridden_project_id + + self._cached_credentials = credentials + self._cached_project_id = project_id + + return credentials, project_id + + def _get_credentials(self) -> google.auth.credentials.Credentials: + """Returns the Credentials object for Google API""" + credentials, _ = self._get_credentials_and_project_id() + return credentials + + def _get_access_token(self) -> str: + """Returns a valid access token from Google API Credentials""" + return self._get_credentials().token + + @functools.lru_cache(maxsize=None) + def _get_credentials_email(self) -> str: + """ + Returns the email address associated with the currently logged in account + + If a service account is used, it returns the service account. + If user authentication (e.g. gcloud auth) is used, it returns the e-mail account of that user. + """ + credentials = self._get_credentials() + service_account_email = getattr(credentials, "service_account_email", None) + if service_account_email: + return service_account_email + + http_authorized = self._authorize() + oauth2_client = discovery.build( + "oauth2", "v1", http=http_authorized, cache_discovery=False + ) + return oauth2_client.tokeninfo().execute()["email"] # pylint: disable=no-member + + def _authorize(self) -> google_auth_httplib2.AuthorizedHttp: + """ + Returns an authorized HTTP object to be used to build a Google cloud + service hook connection. + """ + credentials = self._get_credentials() + http = build_http() + http = set_user_agent(http, "airflow/" + version.version) + authed_http = google_auth_httplib2.AuthorizedHttp(credentials, http=http) + return authed_http + + def _get_field(self, f: str, default: Any = None) -> Any: + """ + Fetches a field from extras, and returns it. This is some Airflow + magic. The google_cloud_platform hook type adds custom UI elements + to the hook page, which allow admins to specify service_account, + key_path, etc. They get formatted as shown below. + """ + long_f = f"extra__google_cloud_platform__{f}" + if hasattr(self, "extras") and long_f in self.extras: + return self.extras[long_f] + else: + return default + + @property + def project_id(self) -> Optional[str]: + """ + Returns project id. 
+ + :return: id of the project + :rtype: str + """ + _, project_id = self._get_credentials_and_project_id() + return project_id + + @property + def num_retries(self) -> int: + """ + Returns num_retries from Connection. + + :return: the number of times each API request should be retried + :rtype: int + """ + field_value = self._get_field("num_retries", default=5) + if field_value is None: + return 5 + if isinstance(field_value, str) and field_value.strip() == "": + return 5 + try: + return int(field_value) + except ValueError: + raise AirflowException( + f"The num_retries field should be a integer. " + f'Current value: "{field_value}" (type: {type(field_value)}). ' + f"Please check the connection configuration." + ) + + @property + def client_info(self) -> ClientInfo: + """ + Return client information used to generate a user-agent for API calls. + + It allows for better errors tracking. + + This object is only used by the google-cloud-* libraries that are built specifically for + the Google Cloud. It is not supported by The Google APIs Python Client that use Discovery + based APIs. + """ + client_info = ClientInfo(client_library_version="airflow_v" + version.version) + return client_info + + @property + def scopes(self) -> Sequence[str]: + """ + Return OAuth 2.0 scopes. + + :return: Returns the scope defined in the connection configuration, or the default scope + :rtype: Sequence[str] + """ + scope_value = self._get_field("scope", None) # type: Optional[str] + + return _get_scopes(scope_value) + + @staticmethod + def quota_retry(*args, **kwargs) -> Callable: + """ + A decorator that provides a mechanism to repeat requests in response to exceeding a temporary quote + limit. + """ + + def decorator(fun: Callable): + default_kwargs = { + "wait": tenacity.wait_exponential(multiplier=1, max=100), + "retry": retry_if_temporary_quota(), + "before": tenacity.before_log(log, logging.DEBUG), + "after": tenacity.after_log(log, logging.DEBUG), + } + default_kwargs.update(**kwargs) + return tenacity.retry(*args, **default_kwargs)(fun) + + return decorator + + @staticmethod + def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]: + """ + A decorator that provides a mechanism to repeat requests in response to + operation in progress (HTTP 409) + limit. + """ + + def decorator(fun: T): + default_kwargs = { + "wait": tenacity.wait_exponential(multiplier=1, max=300), + "retry": retry_if_operation_in_progress(), + "before": tenacity.before_log(log, logging.DEBUG), + "after": tenacity.after_log(log, logging.DEBUG), + } + default_kwargs.update(**kwargs) + return cast(T, tenacity.retry(*args, **default_kwargs)(fun)) + + return decorator + + @staticmethod + def fallback_to_default_project_id(func: Callable[..., RT]) -> Callable[..., RT]: + """ + Decorator that provides fallback for Google Cloud project id. If + the project is None it will be replaced with the project_id from the + service account the Hook is authenticated with. Project id can be specified + either via project_id kwarg or via first parameter in positional args. 
+ + :param func: function to wrap + :return: result of the function call + """ + + @functools.wraps(func) + def inner_wrapper(self: GoogleBaseHook, *args, **kwargs) -> RT: + if args: + raise AirflowException( + "You must use keyword arguments in this methods rather than positional" + ) + if "project_id" in kwargs: + kwargs["project_id"] = kwargs["project_id"] or self.project_id + else: + kwargs["project_id"] = self.project_id + if not kwargs["project_id"]: + raise AirflowException( + "The project id must be passed either as " + "keyword project_id parameter or as project_id extra " + "in Google Cloud connection definition. Both are not set!" + ) + return func(self, *args, **kwargs) + + return inner_wrapper + + @staticmethod + def provide_gcp_credential_file(func: T) -> T: + """ + Function decorator that provides a Google Cloud credentials for application supporting Application + Default Credentials (ADC) strategy. + + It is recommended to use ``provide_gcp_credential_file_as_context`` context manager to limit the + scope when authorization data is available. Using context manager also + makes it easier to use multiple connection in one function. + """ + + @functools.wraps(func) + def wrapper(self: GoogleBaseHook, *args, **kwargs): + with self.provide_gcp_credential_file_as_context(): + return func(self, *args, **kwargs) + + return cast(T, wrapper) + + @contextmanager + def provide_gcp_credential_file_as_context(self): + """ + Context manager that provides a Google Cloud credentials for application supporting `Application + Default Credentials (ADC) strategy `__. + + It can be used to provide credentials for external programs (e.g. gcloud) that expect authorization + file in ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable. + """ + key_path = self._get_field( + "key_path", None + ) # type: Optional[str] # noqa: E501 # pylint: disable=protected-access + keyfile_dict = self._get_field( + "keyfile_dict", None + ) # type: Optional[Dict] # noqa: E501 # pylint: disable=protected-access + if key_path and keyfile_dict: + raise AirflowException( + "The `keyfile_dict` and `key_path` fields are mutually exclusive. " + "Please provide only one value." + ) + elif key_path: + if key_path.endswith(".p12"): + raise AirflowException( + "Legacy P12 key file are not supported, use a JSON key file." + ) + with patch_environ({CREDENTIALS: key_path}): + yield key_path + elif keyfile_dict: + with tempfile.NamedTemporaryFile(mode="w+t") as conf_file: + conf_file.write(keyfile_dict) + conf_file.flush() + with patch_environ({CREDENTIALS: conf_file.name}): + yield conf_file.name + else: + # We will use the default service account credentials. + yield None + + @contextmanager + def provide_authorized_gcloud(self): + """ + Provides a separate gcloud configuration with current credentials. + + The gcloud tool allows you to login to Google Cloud only - ``gcloud auth login`` and + for the needs of Application Default Credentials ``gcloud auth application-default login``. + In our case, we want all commands to use only the credentials from ADCm so + we need to configure the credentials in gcloud manually. 
+ """ + credentials_path = _cloud_sdk.get_application_default_credentials_path() + project_id = self.project_id + + with ExitStack() as exit_stack: + exit_stack.enter_context(self.provide_gcp_credential_file_as_context()) + gcloud_config_tmp = exit_stack.enter_context(tempfile.TemporaryDirectory()) + exit_stack.enter_context( + patch_environ({CLOUD_SDK_CONFIG_DIR: gcloud_config_tmp}) + ) + + if project_id: + # Don't display stdout/stderr for security reason + check_output(["gcloud", "config", "set", "core/project", project_id]) + if CREDENTIALS in os.environ: + # This solves most cases when we are logged in using the service key in Airflow. + # Don't display stdout/stderr for security reason + check_output( + [ + "gcloud", + "auth", + "activate-service-account", + f"--key-file={os.environ[CREDENTIALS]}", + ] + ) + elif os.path.exists(credentials_path): + # If we are logged in by `gcloud auth application-default` then we need to log in manually. + # This will make the `gcloud auth application-default` and `gcloud auth` credentials equals. + with open(credentials_path) as creds_file: + creds_content = json.loads(creds_file.read()) + # Don't display stdout/stderr for security reason + check_output( + [ + "gcloud", + "config", + "set", + "auth/client_id", + creds_content["client_id"], + ] + ) + # Don't display stdout/stderr for security reason + check_output( + [ + "gcloud", + "config", + "set", + "auth/client_secret", + creds_content["client_secret"], + ] + ) + # Don't display stdout/stderr for security reason + check_output( + [ + "gcloud", + "auth", + "activate-refresh-token", + creds_content["client_id"], + creds_content["refresh_token"], + ] + ) + yield + + @staticmethod + def download_content_from_request( + file_handle, request: dict, chunk_size: int + ) -> None: + """ + Download media resources. + Note that the Python file object is compatible with io.Base and can be used with this class also. + + :param file_handle: io.Base or file object. The stream in which to write the downloaded + bytes. + :type file_handle: io.Base or file object + :param request: googleapiclient.http.HttpRequest, the media request to perform in chunks. + :type request: Dict + :param chunk_size: int, File will be downloaded in chunks of this many bytes. + :type chunk_size: int + """ + downloader = MediaIoBaseDownload(file_handle, request, chunksize=chunk_size) + done = False + while done is False: + _, done = downloader.next_chunk() + file_handle.flush() diff --git a/reference/providers/google/common/hooks/discovery_api.py b/reference/providers/google/common/hooks/discovery_api.py new file mode 100644 index 0000000..d9abc9a --- /dev/null +++ b/reference/providers/google/common/hooks/discovery_api.py @@ -0,0 +1,192 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +"""This module allows you to connect to the Google Discovery API Service and query it.""" +from typing import Optional, Sequence, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import Resource, build + + +class GoogleDiscoveryApiHook(GoogleBaseHook): + """ + A hook to use the Google API Discovery Service. + + :param api_service_name: The name of the api service that is needed to get the data + for example 'youtube'. + :type api_service_name: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + _conn = None # type: Optional[Resource] + + def __init__( + self, + api_service_name: str, + api_version: str, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.api_service_name = api_service_name + self.api_version = api_version + + def get_conn(self) -> Re# + """ + Creates an authenticated api client for the given api service name and credentials. + + :return: the authenticated api service. + :rtype: Resource + """ + self.log.info("Authenticating Google API Client") + + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + serviceName=self.api_service_name, + version=self.api_version, + http=http_authorized, + cache_discovery=False, + ) + return self._conn + + def query( + self, endpoint: str, data: dict, paginate: bool = False, num_retries: int = 0 + ) -> dict: + """ + Creates a dynamic API call to any Google API registered in Google's API Client Library + and queries it. + + :param endpoint: The client libraries path to the api call's executing method. + For example: 'analyticsreporting.reports.batchGet' + + .. seealso:: https://developers.google.com/apis-explorer + for more information on what methods are available. + :type endpoint: str + :param data: The data (endpoint params) needed for the specific request to given endpoint. + :type data: dict + :param paginate: If set to True, it will collect all pages of data. + :type paginate: bool + :param num_retries: Define the number of retries for the requests being made if it fails. + :type num_retries: int + :return: the API response from the passed endpoint. 
+ :rtype: dict + """ + google_api_conn_client = self.get_conn() + + api_response = self._call_api_request( + google_api_conn_client, endpoint, data, paginate, num_retries + ) + return api_response + + def _call_api_request( + self, google_api_conn_client, endpoint, data, paginate, num_retries + ): + api_endpoint_parts = endpoint.split(".") + + google_api_endpoint_instance = self._build_api_request( + google_api_conn_client, + api_sub_functions=api_endpoint_parts[1:], + api_endpoint_params=data, + ) + + if paginate: + return self._paginate_api( + google_api_endpoint_instance, + google_api_conn_client, + api_endpoint_parts, + num_retries, + ) + + return google_api_endpoint_instance.execute(num_retries=num_retries) + + def _build_api_request( + self, google_api_conn_client, api_sub_functions, api_endpoint_params + ): + for sub_function in api_sub_functions: + google_api_conn_client = getattr(google_api_conn_client, sub_function) + if sub_function != api_sub_functions[-1]: + google_api_conn_client = google_api_conn_client() + else: + google_api_conn_client = google_api_conn_client(**api_endpoint_params) + + return google_api_conn_client + + def _paginate_api( + self, + google_api_endpoint_instance, + google_api_conn_client, + api_endpoint_parts, + num_retries, + ): + api_responses = [] + + while google_api_endpoint_instance: + api_response = google_api_endpoint_instance.execute(num_retries=num_retries) + api_responses.append(api_response) + + google_api_endpoint_instance = self._build_next_api_request( + google_api_conn_client, + api_endpoint_parts[1:], + google_api_endpoint_instance, + api_response, + ) + + return api_responses + + def _build_next_api_request( + self, + google_api_conn_client, + api_sub_functions, + api_endpoint_instance, + api_response, + ): + for sub_function in api_sub_functions: + if sub_function != api_sub_functions[-1]: + google_api_conn_client = getattr(google_api_conn_client, sub_function) + google_api_conn_client = google_api_conn_client() + else: + google_api_conn_client = getattr( + google_api_conn_client, sub_function + "_next" + ) + google_api_conn_client = google_api_conn_client( + api_endpoint_instance, api_response + ) + + return google_api_conn_client diff --git a/reference/providers/google/common/utils/__init__.py b/reference/providers/google/common/utils/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/common/utils/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
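Editor's note: GoogleDiscoveryApiHook.query above walks the dotted endpoint path on the discovery client and, when paginate=True, keeps issuing the generated *_next requests until no further pages are returned. The following is a minimal sketch of calling it from a task callable, not code from this commit; it assumes the google provider package is installed, a google_cloud_default connection with suitable Analytics scopes exists, and the view ID is a placeholder.

from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryApiHook

def fetch_sessions_report():
    # api_service_name/api_version map directly to googleapiclient.discovery.build().
    hook = GoogleDiscoveryApiHook(api_service_name="analyticsreporting", api_version="v4")
    # The endpoint string mirrors the client-library call chain: reports().batchGet(**data).
    return hook.query(
        endpoint="analyticsreporting.reports.batchGet",
        data={
            "body": {
                "reportRequests": [
                    {
                        "viewId": "00000000",  # placeholder view ID
                        "dateRanges": [{"startDate": "7daysAgo", "endDate": "today"}],
                        "metrics": [{"expression": "ga:sessions"}],
                    }
                ]
            }
        },
        num_retries=hook.num_retries,
    )

The dict passed as data becomes the keyword arguments of the final method in the chain, which is why the batchGet request body is nested under a "body" key here.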
diff --git a/reference/providers/google/common/utils/id_token_credentials.py b/reference/providers/google/common/utils/id_token_credentials.py new file mode 100644 index 0000000..f082d79 --- /dev/null +++ b/reference/providers/google/common/utils/id_token_credentials.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +You can execute this module to get ID Token. + + python -m airflow.providers.google.common.utils.id_token_credentials_provider + +To obtain info about this token, run the following commands: + + ID_TOKEN="$(python -m airflow.providers.google.common.utils.id_token_credentials)" + curl "https://www.googleapis.com/oauth2/v3/tokeninfo?id_token=${ID_TOKEN}" -v +""" + +import json +import os +from typing import Optional + +import google.auth.transport +import google.oauth2 +from google.auth import credentials as google_auth_credentials +from google.auth import environment_vars, exceptions +from google.auth._default import ( + _AUTHORIZED_USER_TYPE, + _HELP_MESSAGE, + _SERVICE_ACCOUNT_TYPE, + _VALID_TYPES, +) +from google.oauth2 import credentials as oauth2_credentials +from google.oauth2 import service_account + + +class IDTokenCredentialsAdapter(google_auth_credentials.Credentials): + """Convert Credentials with "openid" scope to IDTokenCredentials.""" + + def __init__(self, credentials: oauth2_credentials.Credentials): + super().__init__() + self.credentials = credentials + self.token = credentials.id_token + + @property + def expired(self): + return self.credentials.expired + + def refresh(self, request): + self.credentials.refresh(request) + self.token = self.credentials.id_token + + +def _load_credentials_from_file( + filename: str, target_audience: Optional[str] +) -> Optional[google_auth_credentials.Credentials]: + """ + Loads credentials from a file. + + The credentials file must be a service account key or a stored authorized user credential. + + :param filename: The full path to the credentials file. + :type filename: str + :return: Loaded credentials + :rtype: google.auth.credentials.Credentials + :raise google.auth.exceptions.DefaultCredentialsError: if the file is in the wrong format or is missing. + """ + if not os.path.exists(filename): + raise exceptions.DefaultCredentialsError(f"File {filename} was not found.") + + with open(filename) as file_obj: + try: + info = json.load(file_obj) + except json.JSONDecodeError: + raise exceptions.DefaultCredentialsError( + f"File {filename} is not a valid json file." + ) + + # The type key should indicate that the file is either a service account + # credentials file or an authorized user credentials file. 
+ credential_type = info.get("type") + + if credential_type == _AUTHORIZED_USER_TYPE: + current_credentials = oauth2_credentials.Credentials.from_authorized_user_info( + info, scopes=["openid", "email"] + ) + current_credentials = IDTokenCredentialsAdapter(credentials=current_credentials) + + return current_credentials + + elif credential_type == _SERVICE_ACCOUNT_TYPE: + try: + return service_account.IDTokenCredentials.from_service_account_info( + info, target_audience=target_audience + ) + except ValueError: + raise exceptions.DefaultCredentialsError( + f"Failed to load service account credentials from {filename}" + ) + + raise exceptions.DefaultCredentialsError( + f"The file {filename} does not have a valid type. Type is {credential_type}, " + f"expected one of {_VALID_TYPES}." + ) + + +def _get_explicit_environ_credentials( + target_audience: Optional[str], +) -> Optional[google_auth_credentials.Credentials]: + """Gets credentials from the GOOGLE_APPLICATION_CREDENTIALS environment variable.""" + explicit_file = os.environ.get(environment_vars.CREDENTIALS) + + if explicit_file is None: + return None + + current_credentials = _load_credentials_from_file( + os.environ[environment_vars.CREDENTIALS], target_audience=target_audience + ) + + return current_credentials + + +def _get_gcloud_sdk_credentials( + target_audience: Optional[str], +) -> Optional[google_auth_credentials.Credentials]: + """Gets the credentials and project ID from the Cloud SDK.""" + from google.auth import _cloud_sdk + + # Check if application default credentials exist. + credentials_filename = _cloud_sdk.get_application_default_credentials_path() + + if not os.path.isfile(credentials_filename): + return None + + current_credentials = _load_credentials_from_file( + credentials_filename, target_audience + ) + + return current_credentials + + +def _get_gce_credentials( + target_audience: Optional[str], + request: Optional[google.auth.transport.Request] = None, +) -> Optional[google_auth_credentials.Credentials]: + """Gets credentials and project ID from the GCE Metadata Service.""" + # Ping requires a transport, but we want application default credentials + # to require no arguments. So, we'll use the _http_client transport which + # uses http.client. This is only acceptable because the metadata server + # doesn't do SSL and never requires proxies. + + # While this library is normally bundled with compute_engine, there are + # some cases where it's not available, so we tolerate ImportError. + try: + from google.auth import compute_engine + from google.auth.compute_engine import _metadata + except ImportError: + return None + from google.auth.transport import _http_client + + if request is None: + request = _http_client.Request() + + if _metadata.ping(request=request): + return compute_engine.IDTokenCredentials( + request, target_audience, use_metadata_identity_endpoint=True + ) + + return None + + +def get_default_id_token_credentials( + target_audience: Optional[str], request: google.auth.transport.Request = None +) -> google_auth_credentials.Credentials: + """Gets the default ID Token credentials for the current environment. + + `Application Default Credentials`_ provides an easy way to obtain credentials to call Google APIs for + server-to-server or local applications. + + .. _Application Default Credentials: https://developers.google.com\ + /identity/protocols/application-default-credentials + + :param target_audience: The intended audience for these credentials. 
+ :type target_audience: Sequence[str] + :param request: An object used to make HTTP requests. This is used to detect whether the application + is running on Compute Engine. If not specified, then it will use the standard library http client + to make requests. + :type request: google.auth.transport.Request + :return: the current environment's credentials. + :rtype: google.auth.credentials.Credentials + :raises ~google.auth.exceptions.DefaultCredentialsError: + If no credentials were found, or if the credentials found were invalid. + """ + checkers = ( + lambda: _get_explicit_environ_credentials(target_audience), + lambda: _get_gcloud_sdk_credentials(target_audience), + lambda: _get_gce_credentials(target_audience, request), + ) + + for checker in checkers: + current_credentials = checker() + if current_credentials is not None: + return current_credentials + + raise exceptions.DefaultCredentialsError(_HELP_MESSAGE) + + +if __name__ == "__main__": + from google.auth.transport import requests + + request_adapter = requests.Request() + + creds = get_default_id_token_credentials(target_audience=None) + creds.refresh(request=request_adapter) + print(creds.token) diff --git a/reference/providers/google/config_templates/config.yml b/reference/providers/google/config_templates/config.yml new file mode 100644 index 0000000..ddd7ec9 --- /dev/null +++ b/reference/providers/google/config_templates/config.yml @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +--- +- name: providers_google + description: Options for google provider + options: + - name: verbose_logging + description: | + Sets verbose logging for google provider + version_added: 2.0.0 + type: boolean + example: ~ + default: "False" diff --git a/reference/providers/google/config_templates/default_config.cfg b/reference/providers/google/config_templates/default_config.cfg new file mode 100644 index 0000000..cdc264b --- /dev/null +++ b/reference/providers/google/config_templates/default_config.cfg @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + + +# This is the template for Airflow's default configuration. When Airflow is +# imported, it looks for a configuration file at $AIRFLOW_HOME/airflow.cfg. If +# it doesn't exist, Airflow uses this template to generate it by replacing +# variables in curly braces with their global values from configuration.py. + +# Users should not modify this file; they should customize the generated +# airflow.cfg instead. + + +# ----------------------- TEMPLATE BEGINS HERE ----------------------- + +[providers_google] + +# Options for google provider +# Sets verbose logging for google provider +verbose_logging = False diff --git a/reference/providers/google/firebase/__init__.py b/reference/providers/google/firebase/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/firebase/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/firebase/example_dags/__init__.py b/reference/providers/google/firebase/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/firebase/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/firebase/example_dags/example_firestore.py b/reference/providers/google/firebase/example_dags/example_firestore.py new file mode 100644 index 0000000..6151f1c --- /dev/null +++ b/reference/providers/google/firebase/example_dags/example_firestore.py @@ -0,0 +1,139 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that shows interactions with Google Cloud Firestore. + +Prerequisites +============= + +This example uses two Google Cloud projects: + +* ``GCP_PROJECT_ID`` - It contains a bucket and a firestore database. +* ``G_FIRESTORE_PROJECT_ID`` - it contains the Data Warehouse based on the BigQuery service. + +Saving in a bucket should be possible from the ``G_FIRESTORE_PROJECT_ID`` project. +Reading from a bucket should be possible from the ``GCP_PROJECT_ID`` project. + +The bucket and dataset should be located in the same region. + +If you want to run this example, you must do the following: + +1. Create Google Cloud project and enable the BigQuery API +2. Create the Firebase project +3. Create a bucket in the same location as the Firebase project +4. Grant Firebase admin account permissions to manage BigQuery. This is required to create a dataset. +5. Create a bucket in Firebase project and +6. Give read/write access for Firebase admin to bucket to step no. 5. +7. Create collection in the Firestore database. +""" + +import os +from urllib.parse import urlparse + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateExternalTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryExecuteQueryOperator, +) +from airflow.providers.google.firebase.operators.firestore import ( + CloudFirestoreExportDatabaseOperator, +) +from airflow.utils import dates + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-gcp-project") +FIRESTORE_PROJECT_ID = os.environ.get( + "G_FIRESTORE_PROJECT_ID", "example-firebase-project" +) + +EXPORT_DESTINATION_URL = os.environ.get( + "GCP_FIRESTORE_ARCHIVE_URL", "gs://airflow-firestore/namespace/" +) +BUCKET_NAME = urlparse(EXPORT_DESTINATION_URL).hostname +EXPORT_PREFIX = urlparse(EXPORT_DESTINATION_URL).path + +EXPORT_COLLECTION_ID = os.environ.get( + "GCP_FIRESTORE_COLLECTION_ID", "firestore_collection_id" +) +DATASET_NAME = os.environ.get("GCP_FIRESTORE_DATASET_NAME", "test_firestore_export") +DATASET_LOCATION = os.environ.get("GCP_FIRESTORE_DATASET_LOCATION", "EU") + +if BUCKET_NAME is None: + raise ValueError( + "Bucket name is required. Please set GCP_FIRESTORE_ARCHIVE_URL env variable." 
+ ) + +with models.DAG( + "example_google_firestore", + default_args=dict(start_date=dates.days_ago(1)), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_operator_export_database_to_gcs] + export_database_to_gcs = CloudFirestoreExportDatabaseOperator( + task_id="export_database_to_gcs", + project_id=FIRESTORE_PROJECT_ID, + body={ + "outputUriPrefix": EXPORT_DESTINATION_URL, + "collectionIds": [EXPORT_COLLECTION_ID], + }, + ) + # [END howto_operator_export_database_to_gcs] + + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create_dataset", + dataset_id=DATASET_NAME, + location=DATASET_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", + dataset_id=DATASET_NAME, + project_id=GCP_PROJECT_ID, + delete_contents=True, + ) + + # [START howto_operator_create_external_table_multiple_types] + create_external_table_multiple_types = BigQueryCreateExternalTableOperator( + task_id="create_external_table", + bucket=BUCKET_NAME, + source_objects=[ + f"{EXPORT_PREFIX}/all_namespaces/kind_{EXPORT_COLLECTION_ID}" + f"/all_namespaces_kind_{EXPORT_COLLECTION_ID}.export_metadata" + ], + source_format="DATASTORE_BACKUP", + destination_project_dataset_table=f"{GCP_PROJECT_ID}.{DATASET_NAME}.firestore_data", + ) + # [END howto_operator_create_external_table_multiple_types] + + read_data_from_gcs_multiple_types = BigQueryExecuteQueryOperator( + task_id="execute_query", + sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.firestore_data`", + use_legacy_sql=False, + ) + + # Firestore + export_database_to_gcs >> create_dataset + + # BigQuery + create_dataset >> create_external_table_multiple_types + create_external_table_multiple_types >> read_data_from_gcs_multiple_types + read_data_from_gcs_multiple_types >> delete_dataset diff --git a/reference/providers/google/firebase/hooks/__init__.py b/reference/providers/google/firebase/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/firebase/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/firebase/hooks/firestore.py b/reference/providers/google/firebase/hooks/firestore.py new file mode 100644 index 0000000..43b8bae --- /dev/null +++ b/reference/providers/google/firebase/hooks/firestore.py @@ -0,0 +1,157 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Google Cloud Firestore service""" + +import time +from typing import Any, Dict, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import build, build_from_document + +# Time to sleep between active checks of the operation results +TIME_TO_SLEEP_IN_SECONDS = 5 + + +class CloudFirestoreHook(GoogleBaseHook): + """ + Hook for the Google Firestore APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param api_version: API version used (for example v1 or v1beta1). + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + _conn = None # type: Optional[Any] + + def __init__( + self, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.api_version = api_version + + def get_conn(self): + """ + Retrieves the connection to Cloud Firestore. + + :return: Google Cloud Firestore services object. + """ + if not self._conn: + http_authorized = self._authorize() + # We cannot use an Authorized Client to retrieve discovery document due to an error in the API. + # When the authorized customer will send a request to the address below + # https://www.googleapis.com/discovery/v1/apis/firestore/v1/rest + # then it will get the message below: + # > Request contains an invalid argument. + # At the same time, the Non-Authorized Client has no problems. 
+ non_authorized_conn = build( + "firestore", self.api_version, cache_discovery=False + ) + self._conn = build_from_document( + non_authorized_conn._rootDesc, + http=http_authorized, # pylint: disable=protected-access + ) + return self._conn + + @GoogleBaseHook.fallback_to_default_project_id + def export_documents( + self, + body: Dict, + database_id: str = "(default)", + project_id: Optional[str] = None, + ) -> None: + """ + Starts a export with the specified configuration. + + :param database_id: The Database ID. + :type database_id: str + :param body: The request body. + See: + https://firebase.google.com/docs/firestore/reference/rest/v1beta1/projects.databases/exportDocuments + :type body: dict + :param project_id: Optional, Google Cloud Project project_id where the database belongs. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :type project_id: str + """ + service = self.get_conn() + + name = f"projects/{project_id}/databases/{database_id}" + + operation = ( + service.projects() # pylint: disable=no-member + .databases() + .exportDocuments(name=name, body=body) + .execute(num_retries=self.num_retries) + ) + + self._wait_for_operation_to_complete(operation["name"]) + + def _wait_for_operation_to_complete(self, operation_name: str) -> None: + """ + Waits for the named operation to complete - checks status of the + asynchronous call. + + :param operation_name: The name of the operation. + :type operation_name: str + :return: The response returned by the operation. + :rtype: dict + :exception: AirflowException in case error is returned. + """ + service = self.get_conn() + while True: + operation_response = ( + service.projects() # pylint: disable=no-member + .databases() + .operations() + .get(name=operation_name) + .execute(num_retries=self.num_retries) + ) + if operation_response.get("done"): + response = operation_response.get("response") + error = operation_response.get("error") + # Note, according to documentation always either response or error is + # set when "done" == True + if error: + raise AirflowException(str(error)) + return response + time.sleep(TIME_TO_SLEEP_IN_SECONDS) diff --git a/reference/providers/google/firebase/operators/__init__.py b/reference/providers/google/firebase/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/firebase/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
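A minimal usage sketch of the CloudFirestoreHook above (not part of the original files; the connection ID is Airflow's default, and the project, bucket, and collection names are placeholders):

    from airflow.providers.google.firebase.hooks.firestore import CloudFirestoreHook

    hook = CloudFirestoreHook(api_version="v1", gcp_conn_id="google_cloud_default")
    hook.export_documents(
        body={
            "outputUriPrefix": "gs://example-export-bucket/firestore-backup",  # placeholder bucket
            "collectionIds": ["example_collection"],  # placeholder collection
        },
        database_id="(default)",
        project_id="example-project",  # optional: falls back to the connection's default project
    )

As the code above shows, export_documents blocks until the export operation reports done, polling every TIME_TO_SLEEP_IN_SECONDS (5) seconds, and raises AirflowException if the operation returns an error.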
diff --git a/reference/providers/google/firebase/operators/firestore.py b/reference/providers/google/firebase/operators/firestore.py new file mode 100644 index 0000000..d551fbe --- /dev/null +++ b/reference/providers/google/firebase/operators/firestore.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.firebase.hooks.firestore import CloudFirestoreHook +from airflow.utils.decorators import apply_defaults + + +class CloudFirestoreExportDatabaseOperator(BaseOperator): + """ + Exports a copy of all or a subset of documents from Google Cloud Firestore to another storage system, + such as Google Cloud Storage. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudFirestoreExportDatabaseOperator` + + :param database_id: The Database ID. + :type database_id: str + :param body: The request body. + See: + https://firebase.google.com/docs/firestore/reference/rest/v1beta1/projects.databases/exportDocuments + :type body: dict + :param project_id: ID of the Google Cloud project if None then + default project_id is used. + :type project_id: str + :param gcp_conn_id: The connection ID to use to connect to Google Cloud. + :type gcp_conn_id: str + :param api_version: API version used (for example v1 or v1beta1). + :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + body: Dict, + database_id: str = "(default)", + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.database_id = database_id + self.body = body + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self._validate_inputs() + self.impersonation_chain = impersonation_chain + + def _validate_inputs(self) -> None: + if not self.body: + raise AirflowException("The required parameter 'body' is missing") + + def execute(self, context): + hook = CloudFirestoreHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + return hook.export_documents( + database_id=self.database_id, body=self.body, project_id=self.project_id + ) diff --git a/reference/providers/google/leveldb/__init__.py b/reference/providers/google/leveldb/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/leveldb/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/leveldb/example_dags/__init__.py b/reference/providers/google/leveldb/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/google/leveldb/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
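A minimal sketch of the operator above with a chained impersonation, to make the long impersonation_chain description concrete (not part of the original files; assumes it sits inside a `with models.DAG(...)` block, and the project, bucket, collection, and service-account emails are placeholders):

    export_backup = CloudFirestoreExportDatabaseOperator(
        task_id="export_backup",
        project_id="example-project",  # placeholder
        body={
            "outputUriPrefix": "gs://example-bucket/firestore-backup",
            "collectionIds": ["users"],  # placeholder collection
        },
        # Each account in the chain must grant the Service Account Token Creator
        # role to the identity directly preceding it, as described in the docstring above.
        impersonation_chain=[
            "intermediate-sa@example-project.iam.gserviceaccount.com",
            "target-sa@example-project.iam.gserviceaccount.com",
        ],
    )

Because `body` is listed in template_fields, the output URI prefix could also be templated (for example with `{{ ds }}`) so each run writes its export under a dated prefix.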
diff --git a/reference/providers/google/leveldb/example_dags/example_leveldb.py b/reference/providers/google/leveldb/example_dags/example_leveldb.py new file mode 100644 index 0000000..01641a5 --- /dev/null +++ b/reference/providers/google/leveldb/example_dags/example_leveldb.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example use of LevelDB operators. +""" + +from airflow import models +from airflow.providers.google.leveldb.operators.leveldb import LevelDBOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", +} + +with models.DAG( + "example_leveldb", + default_args=default_args, + start_date=days_ago(2), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_operator_leveldb_get_key] + get_key_leveldb_task = LevelDBOperator( + task_id="get_key_leveldb", + leveldb_conn_id="leveldb_default", + command="get", + key=b"key", + dag=dag, + ) + # [END howto_operator_leveldb_get_key] + # [START howto_operator_leveldb_put_key] + put_key_leveldb_task = LevelDBOperator( + task_id="put_key_leveldb", + leveldb_conn_id="leveldb_default", + command="put", + key=b"another_key", + value=b"another_value", + dag=dag, + ) + # [END howto_operator_leveldb_put_key] + get_key_leveldb_task >> put_key_leveldb_task diff --git a/reference/providers/google/leveldb/hooks/__init__.py b/reference/providers/google/leveldb/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/leveldb/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/leveldb/hooks/leveldb.py b/reference/providers/google/leveldb/hooks/leveldb.py new file mode 100644 index 0000000..ef6eb42 --- /dev/null +++ b/reference/providers/google/leveldb/hooks/leveldb.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Level DB""" +from typing import List, Optional + +import plyvel +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from plyvel import DB + + +class LevelDBHookException(AirflowException): + """Exception specific for LevelDB""" + + +class LevelDBHook(BaseHook): + """ + Plyvel Wrapper to Interact With LevelDB Database + `LevelDB Connection Documentation `__ + """ + + conn_name_attr = "leveldb_conn_id" + default_conn_name = "leveldb_default" + conn_type = "leveldb" + hook_name = "LevelDB" + + def __init__(self, leveldb_conn_id: str = default_conn_name): + super().__init__() + self.leveldb_conn_id = leveldb_conn_id + self.connection = self.get_connection(leveldb_conn_id) + self.db = None + + def get_conn( + self, name: str = "/tmp/testdb/", create_if_missing: bool = False, **kwargs + ) -> DB: + """ + Creates `Plyvel DB `__ + + :param name: path to create database e.g. `/tmp/testdb/`) + :type name: str + :param create_if_missing: whether a new database should be created if needed + :type create_if_missing: bool + :param kwargs: other options of creation plyvel.DB. See more in the link above. + :type kwargs: Dict[str, Any] + :returns: DB + :rtype: plyvel.DB + """ + if self.db is not None: + return self.db + self.db = plyvel.DB(name=name, create_if_missing=create_if_missing, **kwargs) + return self.db + + def close_conn(self) -> None: + """Closes connection""" + db = self.db + if db is not None: + db.close() + self.db = None + + def run( + self, + command: str, + key: bytes, + value: bytes = None, + keys: List[bytes] = None, + values: List[bytes] = None, + ) -> Optional[bytes]: + """ + Execute operation with leveldb + + :param command: command of plyvel(python wrap for leveldb) for DB object e.g. + ``"put"``, ``"get"``, ``"delete"``, ``"write_batch"``. + :type command: str + :param key: key for command(put,get,delete) execution(, e.g. ``b'key'``, ``b'another-key'``) + :type key: bytes + :param value: value for command(put) execution(bytes, e.g. ``b'value'``, ``b'another-value'``) + :type value: bytes + :param keys: keys for command(write_batch) execution(List[bytes], e.g. ``[b'key', b'another-key'])`` + :type keys: List[bytes] + :param values: values for command(write_batch) execution e.g. ``[b'value'``, ``b'another-value']`` + :type values: List[bytes] + :returns: value from get or None + :rtype: Optional[bytes] + """ + if command == "put": + return self.put(key, value) + elif command == "get": + return self.get(key) + elif command == "delete": + return self.delete(key) + elif command == "write_batch": + return self.write_batch(keys, values) + else: + raise LevelDBHookException("Unknown command for LevelDB hook") + + def put(self, key: bytes, value: bytes): + """ + Put a single value into a leveldb db by key + + :param key: key for put execution, e.g. 
``b'key'``, ``b'another-key'`` + :type key: bytes + :param value: value for put execution e.g. ``b'value'``, ``b'another-value'`` + :type value: bytes + """ + self.db.put(key, value) + + def get(self, key: bytes) -> bytes: + """ + Get a single value into a leveldb db by key + + :param key: key for get execution, e.g. ``b'key'``, ``b'another-key'`` + :type key: bytes + :returns: value of key from db.get + :rtype: bytes + """ + return self.db.get(key) + + def delete(self, key: bytes): + """ + Delete a single value in a leveldb db by key. + + :param key: key for delete execution, e.g. ``b'key'``, ``b'another-key'`` + :type key: bytes + """ + self.db.delete(key) + + def write_batch(self, keys: List[bytes], values: List[bytes]): + """ + Write batch of values in a leveldb db by keys + + :param keys: keys for write_batch execution e.g. ``[b'key', b'another-key']`` + :type keys: List[bytes] + :param values: values for write_batch execution e.g. ``[b'value', b'another-value']`` + :type values: List[bytes] + """ + with self.db.write_batch() as batch: + for i, key in enumerate(keys): + batch.put(key, values[i]) diff --git a/reference/providers/google/leveldb/operators/__init__.py b/reference/providers/google/leveldb/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/leveldb/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/leveldb/operators/leveldb.py b/reference/providers/google/leveldb/operators/leveldb.py new file mode 100644 index 0000000..9317257 --- /dev/null +++ b/reference/providers/google/leveldb/operators/leveldb.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, List, Optional + +from airflow.models import BaseOperator +from airflow.providers.google.leveldb.hooks.leveldb import LevelDBHook +from airflow.utils.decorators import apply_defaults + + +class LevelDBOperator(BaseOperator): + """ + Execute command in LevelDB + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:LevelDBOperator` + + :param command: command of plyvel(python wrap for leveldb) for DB object e.g. + ``"put"``, ``"get"``, ``"delete"``, ``"write_batch"``. + :type command: str + :param key: key for command(put,get,delete) execution(, e.g. ``b'key'``, ``b'another-key'``) + :type key: bytes + :param value: value for command(put) execution(bytes, e.g. ``b'value'``, ``b'another-value'``) + :type value: bytes + :param keys: keys for command(write_batch) execution(List[bytes], e.g. ``[b'key', b'another-key'])`` + :type keys: List[bytes] + :param values: values for command(write_batch) execution e.g. ``[b'value'``, ``b'another-value']`` + :type values: List[bytes] + :param leveldb_conn_id: + :type leveldb_conn_id: str + :param create_if_missing: whether a new database should be created if needed + :type create_if_missing: bool + :param create_db_extra_options: extra options of creation LevelDBOperator. See more in the link below + `Plyvel DB `__ + :type create_db_extra_options: Optional[Dict[str, Any]] + """ + + @apply_defaults + def __init__( + self, + *, + command: str, + key: bytes, + value: bytes = None, + keys: List[bytes] = None, + values: List[bytes] = None, + leveldb_conn_id: str = "leveldb_default", + name: str = "/tmp/testdb/", + create_if_missing: bool = True, + create_db_extra_options: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.command = command + self.key = key + self.value = value + self.keys = keys + self.values = values + self.leveldb_conn_id = leveldb_conn_id + self.name = name + self.create_if_missing = create_if_missing + self.create_db_extra_options = create_db_extra_options or {} + + def execute(self, context) -> Optional[str]: + """ + Execute command in LevelDB + + :returns: value from get(str, not bytes, to prevent error in json.dumps in serialize_value in xcom.py) + or None(Optional[str]) + :rtype: Optional[str] + """ + leveldb_hook = LevelDBHook(leveldb_conn_id=self.leveldb_conn_id) + leveldb_hook.get_conn( + name=self.name, + create_if_missing=self.create_if_missing, + **self.create_db_extra_options, + ) + value = leveldb_hook.run( + command=self.command, + key=self.key, + value=self.value, + keys=self.keys, + values=self.values, + ) + self.log.info("Done. Returned value was: %s", str(value)) + leveldb_hook.close_conn() + value = value if value is None else value.decode() + return value diff --git a/reference/providers/google/marketing_platform/__init__.py b/reference/providers/google/marketing_platform/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/marketing_platform/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/marketing_platform/example_dags/__init__.py b/reference/providers/google/marketing_platform/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/marketing_platform/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/marketing_platform/example_dags/example_analytics.py b/reference/providers/google/marketing_platform/example_dags/example_analytics.py new file mode 100644 index 0000000..5bdd3a0 --- /dev/null +++ b/reference/providers/google/marketing_platform/example_dags/example_analytics.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that shows how to use Google Analytics 360. 
+""" +import os + +from airflow import models +from airflow.providers.google.marketing_platform.operators.analytics import ( + GoogleAnalyticsDataImportUploadOperator, + GoogleAnalyticsDeletePreviousDataUploadsOperator, + GoogleAnalyticsGetAdsLinkOperator, + GoogleAnalyticsListAccountsOperator, + GoogleAnalyticsModifyFileHeadersDataImportOperator, + GoogleAnalyticsRetrieveAdsLinksListOperator, +) +from airflow.utils import dates + +ACCOUNT_ID = os.environ.get("GA_ACCOUNT_ID", "123456789") + +BUCKET = os.environ.get("GMP_ANALYTICS_BUCKET", "test-airflow-analytics-bucket") +BUCKET_FILENAME = "data.csv" +WEB_PROPERTY_ID = os.environ.get("GA_WEB_PROPERTY", "UA-12345678-1") +WEB_PROPERTY_AD_WORDS_LINK_ID = os.environ.get( + "GA_WEB_PROPERTY_AD_WORDS_LINK_ID", "rQafFTPOQdmkx4U-fxUfhj" +) +DATA_ID = "kjdDu3_tQa6n8Q1kXFtSmg" + +with models.DAG( + "example_google_analytics", + schedule_interval=None, # Override to match your needs, + start_date=dates.days_ago(1), +) as dag: + # [START howto_marketing_platform_list_accounts_operator] + list_account = GoogleAnalyticsListAccountsOperator(task_id="list_account") + # [END howto_marketing_platform_list_accounts_operator] + + # [START howto_marketing_platform_get_ads_link_operator] + get_ad_words_link = GoogleAnalyticsGetAdsLinkOperator( + web_property_ad_words_link_id=WEB_PROPERTY_AD_WORDS_LINK_ID, + web_property_id=WEB_PROPERTY_ID, + account_id=ACCOUNT_ID, + task_id="get_ad_words_link", + ) + # [END howto_marketing_platform_get_ads_link_operator] + + # [START howto_marketing_platform_retrieve_ads_links_list_operator] + list_ad_words_link = GoogleAnalyticsRetrieveAdsLinksListOperator( + task_id="list_ad_link", account_id=ACCOUNT_ID, web_property_id=WEB_PROPERTY_ID + ) + # [END howto_marketing_platform_retrieve_ads_links_list_operator] + + upload = GoogleAnalyticsDataImportUploadOperator( + task_id="upload", + storage_bucket=BUCKET, + storage_name_object=BUCKET_FILENAME, + account_id=ACCOUNT_ID, + web_property_id=WEB_PROPERTY_ID, + custom_data_source_id=DATA_ID, + ) + + delete = GoogleAnalyticsDeletePreviousDataUploadsOperator( + task_id="delete", + account_id=ACCOUNT_ID, + web_property_id=WEB_PROPERTY_ID, + custom_data_source_id=DATA_ID, + ) + + transform = GoogleAnalyticsModifyFileHeadersDataImportOperator( + task_id="transform", + storage_bucket=BUCKET, + storage_name_object=BUCKET_FILENAME, + ) + + upload >> [delete, transform] diff --git a/reference/providers/google/marketing_platform/example_dags/example_campaign_manager.py b/reference/providers/google/marketing_platform/example_dags/example_campaign_manager.py new file mode 100644 index 0000000..ee8bf89 --- /dev/null +++ b/reference/providers/google/marketing_platform/example_dags/example_campaign_manager.py @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that shows how to use CampaignManager. +""" +import os +import time + +from airflow import models +from airflow.providers.google.marketing_platform.operators.campaign_manager import ( + GoogleCampaignManagerBatchInsertConversionsOperator, + GoogleCampaignManagerBatchUpdateConversionsOperator, + GoogleCampaignManagerDeleteReportOperator, + GoogleCampaignManagerDownloadReportOperator, + GoogleCampaignManagerInsertReportOperator, + GoogleCampaignManagerRunReportOperator, +) +from airflow.providers.google.marketing_platform.sensors.campaign_manager import ( + GoogleCampaignManagerReportSensor, +) +from airflow.utils import dates +from airflow.utils.state import State + +PROFILE_ID = os.environ.get("MARKETING_PROFILE_ID", "123456789") +FLOODLIGHT_ACTIVITY_ID = int(os.environ.get("FLOODLIGHT_ACTIVITY_ID", 12345)) +FLOODLIGHT_CONFIGURATION_ID = int(os.environ.get("FLOODLIGHT_CONFIGURATION_ID", 12345)) +ENCRYPTION_ENTITY_ID = int(os.environ.get("ENCRYPTION_ENTITY_ID", 12345)) +DEVICE_ID = os.environ.get("DEVICE_ID", "12345") +BUCKET = os.environ.get("MARKETING_BUCKET", "test-cm-bucket") +REPORT_NAME = "test-report" +REPORT = { + "type": "STANDARD", + "name": REPORT_NAME, + "criteria": { + "dateRange": { + "kind": "dfareporting#dateRange", + "relativeDateRange": "LAST_365_DAYS", + }, + "dimensions": [ + {"kind": "dfareporting#sortedDimension", "name": "dfa:advertiser"} + ], + "metricNames": ["dfa:activeViewImpressionDistributionViewable"], + }, +} + +CONVERSION = { + "kind": "dfareporting#conversion", + "floodlightActivityId": FLOODLIGHT_ACTIVITY_ID, + "floodlightConfigurationId": FLOODLIGHT_CONFIGURATION_ID, + "mobileDeviceId": DEVICE_ID, + "ordinal": "0", + "quantity": 42, + "value": 123.4, + "timestampMicros": int(time.time()) * 1000000, + "customVariables": [ + { + "kind": "dfareporting#customFloodlightVariable", + "type": "U4", + "value": "value", + } + ], +} + +CONVERSION_UPDATE = { + "kind": "dfareporting#conversion", + "floodlightActivityId": FLOODLIGHT_ACTIVITY_ID, + "floodlightConfigurationId": FLOODLIGHT_CONFIGURATION_ID, + "mobileDeviceId": DEVICE_ID, + "ordinal": "0", + "quantity": 42, + "value": 123.4, +} + +with models.DAG( + "example_campaign_manager", + schedule_interval=None, # Override to match your needs, + start_date=dates.days_ago(1), +) as dag: + # [START howto_campaign_manager_insert_report_operator] + create_report = GoogleCampaignManagerInsertReportOperator( + profile_id=PROFILE_ID, report=REPORT, task_id="create_report" + ) + report_id = "{{ task_instance.xcom_pull('create_report')['id'] }}" + # [END howto_campaign_manager_insert_report_operator] + + # [START howto_campaign_manager_run_report_operator] + run_report = GoogleCampaignManagerRunReportOperator( + profile_id=PROFILE_ID, report_id=report_id, task_id="run_report" + ) + file_id = "{{ task_instance.xcom_pull('run_report')['id'] }}" + # [END howto_campaign_manager_run_report_operator] + + # [START howto_campaign_manager_wait_for_operation] + wait_for_report = GoogleCampaignManagerReportSensor( + task_id="wait_for_report", + profile_id=PROFILE_ID, + report_id=report_id, + file_id=file_id, + ) + # [END howto_campaign_manager_wait_for_operation] + + # [START howto_campaign_manager_get_report_operator] + get_report = GoogleCampaignManagerDownloadReportOperator( + task_id="get_report", + profile_id=PROFILE_ID, + report_id=report_id, + file_id=file_id, + report_name="test_report.csv", + 
bucket_name=BUCKET, + ) + # [END howto_campaign_manager_get_report_operator] + + # [START howto_campaign_manager_delete_report_operator] + delete_report = GoogleCampaignManagerDeleteReportOperator( + profile_id=PROFILE_ID, report_name=REPORT_NAME, task_id="delete_report" + ) + # [END howto_campaign_manager_delete_report_operator] + + create_report >> run_report >> wait_for_report >> get_report >> delete_report + + # [START howto_campaign_manager_insert_conversions] + insert_conversion = GoogleCampaignManagerBatchInsertConversionsOperator( + task_id="insert_conversion", + profile_id=PROFILE_ID, + conversions=[CONVERSION], + encryption_source="AD_SERVING", + encryption_entity_type="DCM_ADVERTISER", + encryption_entity_id=ENCRYPTION_ENTITY_ID, + ) + # [END howto_campaign_manager_insert_conversions] + + # [START howto_campaign_manager_update_conversions] + update_conversion = GoogleCampaignManagerBatchUpdateConversionsOperator( + task_id="update_conversion", + profile_id=PROFILE_ID, + conversions=[CONVERSION_UPDATE], + encryption_source="AD_SERVING", + encryption_entity_type="DCM_ADVERTISER", + encryption_entity_id=ENCRYPTION_ENTITY_ID, + max_failed_updates=1, + ) + # [END howto_campaign_manager_update_conversions] + + insert_conversion >> update_conversion + +if __name__ == "__main__": + dag.clear(dag_run_state=State.NONE) + dag.run() diff --git a/reference/providers/google/marketing_platform/example_dags/example_display_video.py b/reference/providers/google/marketing_platform/example_dags/example_display_video.py new file mode 100644 index 0000000..780c6ad --- /dev/null +++ b/reference/providers/google/marketing_platform/example_dags/example_display_video.py @@ -0,0 +1,215 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that shows how to use DisplayVideo. 
+""" +import os +from typing import Dict + +from airflow import models +from airflow.providers.google.cloud.transfers.gcs_to_bigquery import ( + GCSToBigQueryOperator, +) +from airflow.providers.google.marketing_platform.hooks.display_video import ( + GoogleDisplayVideo360Hook, +) +from airflow.providers.google.marketing_platform.operators.display_video import ( + GoogleDisplayVideo360CreateReportOperator, + GoogleDisplayVideo360CreateSDFDownloadTaskOperator, + GoogleDisplayVideo360DeleteReportOperator, + GoogleDisplayVideo360DownloadLineItemsOperator, + GoogleDisplayVideo360DownloadReportOperator, + GoogleDisplayVideo360RunReportOperator, + GoogleDisplayVideo360SDFtoGCSOperator, + GoogleDisplayVideo360UploadLineItemsOperator, +) +from airflow.providers.google.marketing_platform.sensors.display_video import ( + GoogleDisplayVideo360GetSDFDownloadOperationSensor, + GoogleDisplayVideo360ReportSensor, +) +from airflow.utils import dates + +# [START howto_display_video_env_variables] +BUCKET = os.environ.get("GMP_DISPLAY_VIDEO_BUCKET", "gs://test-display-video-bucket") +ADVERTISER_ID = os.environ.get("GMP_ADVERTISER_ID", 1234567) +OBJECT_NAME = os.environ.get("GMP_OBJECT_NAME", "files/report.csv") +PATH_TO_UPLOAD_FILE = os.environ.get( + "GCP_GCS_PATH_TO_UPLOAD_FILE", "test-gcs-example.txt" +) +PATH_TO_SAVED_FILE = os.environ.get( + "GCP_GCS_PATH_TO_SAVED_FILE", "test-gcs-example-download.txt" +) +BUCKET_FILE_LOCATION = PATH_TO_UPLOAD_FILE.rpartition("/")[-1] +SDF_VERSION = os.environ.get("GMP_SDF_VERSION", "SDF_VERSION_5_1") +BQ_DATA_SET = os.environ.get("GMP_BQ_DATA_SET", "airflow_test") +GMP_PARTNER_ID = os.environ.get("GMP_PARTNER_ID", 123) +ENTITY_TYPE = os.environ.get("GMP_ENTITY_TYPE", "LineItem") +ERF_SOURCE_OBJECT = GoogleDisplayVideo360Hook.erf_uri(GMP_PARTNER_ID, ENTITY_TYPE) + +REPORT = { + "kind": "doubleclickbidmanager#query", + "metadata": { + "title": "Polidea Test Report", + "dataRange": "LAST_7_DAYS", + "format": "CSV", + "sendNotification": False, + }, + "params": { + "type": "TYPE_GENERAL", + "groupBys": ["FILTER_DATE", "FILTER_PARTNER"], + "filters": [{"type": "FILTER_PARTNER", "value": 1486931}], + "metrics": ["METRIC_IMPRESSIONS", "METRIC_CLICKS"], + "includeInviteData": True, + }, + "schedule": {"frequency": "ONE_TIME"}, +} + +PARAMS = {"dataRange": "LAST_14_DAYS", "timezoneCode": "America/New_York"} + +CREATE_SDF_DOWNLOAD_TASK_BODY_REQUEST: Dict = { + "version": SDF_VERSION, + "advertiserId": ADVERTISER_ID, + "inventorySourceFilter": {"inventorySourceIds": []}, +} + +DOWNLOAD_LINE_ITEMS_REQUEST: Dict = { + "filterType": ADVERTISER_ID, + "format": "CSV", + "fileSpec": "EWF", +} +# [END howto_display_video_env_variables] + +with models.DAG( + "example_display_video", + schedule_interval=None, # Override to match your needs, + start_date=dates.days_ago(1), +) as dag1: + # [START howto_google_display_video_createquery_report_operator] + create_report = GoogleDisplayVideo360CreateReportOperator( + body=REPORT, task_id="create_report" + ) + report_id = "{{ task_instance.xcom_pull('create_report', key='report_id') }}" + # [END howto_google_display_video_createquery_report_operator] + + # [START howto_google_display_video_runquery_report_operator] + run_report = GoogleDisplayVideo360RunReportOperator( + report_id=report_id, params=PARAMS, task_id="run_report" + ) + # [END howto_google_display_video_runquery_report_operator] + + # [START howto_google_display_video_wait_report_operator] + wait_for_report = GoogleDisplayVideo360ReportSensor( + task_id="wait_for_report", 
report_id=report_id + ) + # [END howto_google_display_video_wait_report_operator] + + # [START howto_google_display_video_getquery_report_operator] + get_report = GoogleDisplayVideo360DownloadReportOperator( + report_id=report_id, + task_id="get_report", + bucket_name=BUCKET, + report_name="test1.csv", + ) + # [END howto_google_display_video_getquery_report_operator] + + # [START howto_google_display_video_deletequery_report_operator] + delete_report = GoogleDisplayVideo360DeleteReportOperator( + report_id=report_id, task_id="delete_report" + ) + # [END howto_google_display_video_deletequery_report_operator] + + create_report >> run_report >> wait_for_report >> get_report >> delete_report + +with models.DAG( + "example_display_video_misc", + schedule_interval=None, # Override to match your needs, + start_date=dates.days_ago(1), +) as dag2: + # [START howto_google_display_video_upload_multiple_entity_read_files_to_big_query] + upload_erf_to_bq = GCSToBigQueryOperator( + task_id="upload_erf_to_bq", + bucket=BUCKET, + source_objects=ERF_SOURCE_OBJECT, + destination_project_dataset_table=f"{BQ_DATA_SET}.gcs_to_bq_table", + write_disposition="WRITE_TRUNCATE", + ) + # [END howto_google_display_video_upload_multiple_entity_read_files_to_big_query] + + # [START howto_google_display_video_download_line_items_operator] + download_line_items = GoogleDisplayVideo360DownloadLineItemsOperator( + task_id="download_line_items", + request_body=DOWNLOAD_LINE_ITEMS_REQUEST, + bucket_name=BUCKET, + object_name=OBJECT_NAME, + gzip=False, + ) + # [END howto_google_display_video_download_line_items_operator] + + # [START howto_google_display_video_upload_line_items_operator] + upload_line_items = GoogleDisplayVideo360UploadLineItemsOperator( + task_id="upload_line_items", + bucket_name=BUCKET, + object_name=BUCKET_FILE_LOCATION, + ) + # [END howto_google_display_video_upload_line_items_operator] + +with models.DAG( + "example_display_video_sdf", + schedule_interval=None, # Override to match your needs, + start_date=dates.days_ago(1), +) as dag3: + # [START howto_google_display_video_create_sdf_download_task_operator] + create_sdf_download_task = GoogleDisplayVideo360CreateSDFDownloadTaskOperator( + task_id="create_sdf_download_task", + body_request=CREATE_SDF_DOWNLOAD_TASK_BODY_REQUEST, + ) + operation_name = '{{ task_instance.xcom_pull("create_sdf_download_task")["name"] }}' + # [END howto_google_display_video_create_sdf_download_task_operator] + + # [START howto_google_display_video_wait_for_operation_sensor] + wait_for_operation = GoogleDisplayVideo360GetSDFDownloadOperationSensor( + task_id="wait_for_operation", + operation_name=operation_name, + ) + # [END howto_google_display_video_wait_for_operation_sensor] + + # [START howto_google_display_video_save_sdf_in_gcs_operator] + save_sdf_in_gcs = GoogleDisplayVideo360SDFtoGCSOperator( + task_id="save_sdf_in_gcs", + operation_name=operation_name, + bucket_name=BUCKET, + object_name=BUCKET_FILE_LOCATION, + gzip=False, + ) + # [END howto_google_display_video_save_sdf_in_gcs_operator] + + # [START howto_google_display_video_gcs_to_big_query_operator] + upload_sdf_to_big_query = GCSToBigQueryOperator( + task_id="upload_sdf_to_big_query", + bucket=BUCKET, + source_objects=['{{ task_instance.xcom_pull("upload_sdf_to_bigquery")}}'], + destination_project_dataset_table=f"{BQ_DATA_SET}.gcs_to_bq_table", + schema_fields=[ + {"name": "name", "type": "STRING", "mode": "NULLABLE"}, + {"name": "post_abbr", "type": "STRING", "mode": "NULLABLE"}, + ], + 
write_disposition="WRITE_TRUNCATE", + ) + # [END howto_google_display_video_gcs_to_big_query_operator] + + create_sdf_download_task >> wait_for_operation >> save_sdf_in_gcs >> upload_sdf_to_big_query diff --git a/reference/providers/google/marketing_platform/example_dags/example_search_ads.py b/reference/providers/google/marketing_platform/example_dags/example_search_ads.py new file mode 100644 index 0000000..10e0e40 --- /dev/null +++ b/reference/providers/google/marketing_platform/example_dags/example_search_ads.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that shows how to use SearchAds. +""" +import os + +from airflow import models +from airflow.providers.google.marketing_platform.operators.search_ads import ( + GoogleSearchAdsDownloadReportOperator, + GoogleSearchAdsInsertReportOperator, +) +from airflow.providers.google.marketing_platform.sensors.search_ads import ( + GoogleSearchAdsReportSensor, +) +from airflow.utils import dates + +# [START howto_search_ads_env_variables] +AGENCY_ID = os.environ.get("GMP_AGENCY_ID") +ADVERTISER_ID = os.environ.get("GMP_ADVERTISER_ID") +GCS_BUCKET = os.environ.get("GMP_GCS_BUCKET", "test-cm-bucket") + +REPORT = { + "reportScope": {"agencyId": AGENCY_ID, "advertiserId": ADVERTISER_ID}, + "reportType": "account", + "columns": [{"columnName": "agency"}, {"columnName": "lastModifiedTimestamp"}], + "includeRemovedEntities": False, + "statisticsCurrency": "usd", + "maxRowsPerFile": 1000000, + "downloadFormat": "csv", +} +# [END howto_search_ads_env_variables] + +with models.DAG( + "example_search_ads", + schedule_interval=None, # Override to match your needs, + start_date=dates.days_ago(1), +) as dag: + # [START howto_search_ads_generate_report_operator] + generate_report = GoogleSearchAdsInsertReportOperator( + report=REPORT, task_id="generate_report" + ) + # [END howto_search_ads_generate_report_operator] + + # [START howto_search_ads_get_report_id] + report_id = "{{ task_instance.xcom_pull('generate_report', key='report_id') }}" + # [END howto_search_ads_get_report_id] + + # [START howto_search_ads_get_report_operator] + wait_for_report = GoogleSearchAdsReportSensor( + report_id=report_id, task_id="wait_for_report" + ) + # [END howto_search_ads_get_report_operator] + + # [START howto_search_ads_getfile_report_operator] + download_report = GoogleSearchAdsDownloadReportOperator( + report_id=report_id, bucket_name=GCS_BUCKET, task_id="download_report" + ) + # [END howto_search_ads_getfile_report_operator] + + generate_report >> wait_for_report >> download_report diff --git a/reference/providers/google/marketing_platform/hooks/__init__.py b/reference/providers/google/marketing_platform/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ 
b/reference/providers/google/marketing_platform/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/marketing_platform/hooks/analytics.py b/reference/providers/google/marketing_platform/hooks/analytics.py new file mode 100644 index 0000000..95616d4 --- /dev/null +++ b/reference/providers/google/marketing_platform/hooks/analytics.py @@ -0,0 +1,229 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
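A usage sketch of the hook defined in the file below (not part of the original files; `GoogleAnalyticsHook` is the class that follows, the connection ID is Airflow's default, and the IDs reuse the placeholder values from the example Analytics DAG above):

    from airflow.providers.google.marketing_platform.hooks.analytics import GoogleAnalyticsHook

    hook = GoogleAnalyticsHook(api_version="v3", gcp_conn_id="google_cloud_default")
    accounts = hook.list_accounts()  # paginated internally via _paginate()
    ad_words_link = hook.get_ad_words_link(
        account_id="123456789",  # placeholder IDs, matching the example DAG defaults
        web_property_id="UA-12345678-1",
        web_property_ad_words_link_id="rQafFTPOQdmkx4U-fxUfhj",
    )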
+from typing import Any, Dict, List, Optional
+
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from googleapiclient.discovery import Resource, build
+from googleapiclient.http import MediaFileUpload
+
+
+class GoogleAnalyticsHook(GoogleBaseHook):
+    """Hook for Google Analytics 360."""
+
+    def __init__(self, api_version: str = "v3", *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.api_version = api_version
+        self._conn = None
+
+    def _paginate(
+        self, resource: Resource, list_args: Optional[Dict[str, Any]] = None
+    ) -> List[dict]:
+        list_args = list_args or {}
+        result: List[dict] = []
+        while True:
+            # start index has value 1
+            request = resource.list(
+                start_index=len(result) + 1, **list_args
+            )  # pylint: disable=no-member
+            response = request.execute(num_retries=self.num_retries)
+            result.extend(response.get("items", []))
+            # totalResults is the number of links available in Analytics;
+            # once all of them have been added to result, the loop breaks
+            if response["totalResults"] <= len(result):
+                break
+        return result
+
+    def get_conn(self) -> Resource:
+        """Retrieves connection to Google Analytics 360."""
+        if not self._conn:
+            http_authorized = self._authorize()
+            self._conn = build(
+                "analytics",
+                self.api_version,
+                http=http_authorized,
+                cache_discovery=False,
+            )
+        return self._conn
+
+    def list_accounts(self) -> List[Dict[str, Any]]:
+        """Lists accounts from Google Analytics 360."""
+        self.log.info("Retrieving accounts list...")
+        conn = self.get_conn()
+        accounts = conn.management().accounts()  # pylint: disable=no-member
+        result = self._paginate(accounts)
+        return result
+
+    def get_ad_words_link(
+        self, account_id: str, web_property_id: str, web_property_ad_words_link_id: str
+    ) -> Dict[str, Any]:
+        """
+        Returns a web property-Google Ads link to which the user has access.
+
+        :param account_id: ID of the account which the given web property belongs to.
+        :type account_id: string
+        :param web_property_id: Web property-Google Ads link UA-string.
+        :type web_property_id: string
+        :param web_property_ad_words_link_id: Web property-Google Ads link ID to retrieve the Google Ads link for.
+        :type web_property_ad_words_link_id: string
+
+        :returns: web property-Google Ads link
+        :rtype: Dict
+        """
+        self.log.info("Retrieving ad words links...")
+        ad_words_link = (
+            self.get_conn()  # pylint: disable=no-member
+            .management()
+            .webPropertyAdWordsLinks()
+            .get(
+                accountId=account_id,
+                webPropertyId=web_property_id,
+                webPropertyAdWordsLinkId=web_property_ad_words_link_id,
+            )
+            .execute(num_retries=self.num_retries)
+        )
+        return ad_words_link
+
+    def list_ad_words_links(
+        self, account_id: str, web_property_id: str
+    ) -> List[Dict[str, Any]]:
+        """
+        Lists webProperty-Google Ads links for a given web property.
+
+        :param account_id: ID of the account which the given web property belongs to.
+        :type account_id: str
+        :param web_property_id: Web property UA-string to retrieve the Google Ads links for.
+        :type web_property_id: str
+
+        :returns: list of entity Google Ads links.
+ :rtype: list + """ + self.log.info("Retrieving ad words list...") + conn = self.get_conn() + ads_links = ( + conn.management().webPropertyAdWordsLinks() + ) # pylint: disable=no-member + list_args = {"accountId": account_id, "webPropertyId": web_property_id} + result = self._paginate(ads_links, list_args) + return result + + def upload_data( + self, + file_location: str, + account_id: str, + web_property_id: str, + custom_data_source_id: str, + resumable_upload: bool = False, + ) -> None: + """ + Uploads file to GA via the Data Import API + + :param file_location: The path and name of the file to upload. + :type file_location: str + :param account_id: The GA account Id to which the data upload belongs. + :type account_id: str + :param web_property_id: UA-string associated with the upload. + :type web_property_id: str + :param custom_data_source_id: Custom Data Source Id to which this data import belongs. + :type custom_data_source_id: str + :param resumable_upload: flag to upload the file in a resumable fashion, using a + series of at least two requests. + :type resumable_upload: bool + """ + media = MediaFileUpload( + file_location, + mimetype="application/octet-stream", + resumable=resumable_upload, + ) + + self.log.info( + "Uploading file to GA file for accountId: %s, webPropertyId:%s and customDataSourceId:%s ", + account_id, + web_property_id, + custom_data_source_id, + ) + + self.get_conn().management().uploads().uploadData( # pylint: disable=no-member + accountId=account_id, + webPropertyId=web_property_id, + customDataSourceId=custom_data_source_id, + media_body=media, + ).execute() + + def delete_upload_data( + self, + account_id: str, + web_property_id: str, + custom_data_source_id: str, + delete_request_body: Dict[str, Any], + ) -> None: + """ + Deletes the uploaded data for a given account/property/dataset + + :param account_id: The GA account Id to which the data upload belongs. + :type account_id: str + :param web_property_id: UA-string associated with the upload. + :type web_property_id: str + :param custom_data_source_id: Custom Data Source Id to which this data import belongs. + :type custom_data_source_id: str + :param delete_request_body: Dict of customDataImportUids to delete. + :type delete_request_body: dict + """ + self.log.info( + "Deleting previous uploads to GA file for accountId:%s, " + "webPropertyId:%s and customDataSourceId:%s ", + account_id, + web_property_id, + custom_data_source_id, + ) + + self.get_conn().management().uploads().deleteUploadData( # pylint: disable=no-member + accountId=account_id, + webPropertyId=web_property_id, + customDataSourceId=custom_data_source_id, + body=delete_request_body, + ).execute() + + def list_uploads( + self, account_id, web_property_id, custom_data_source_id + ) -> List[Dict[str, Any]]: + """ + Get list of data upload from GA + + :param account_id: The GA account Id to which the data upload belongs. + :type account_id: str + :param web_property_id: UA-string associated with the upload. + :type web_property_id: str + :param custom_data_source_id: Custom Data Source Id to which this data import belongs. 
+        :type custom_data_source_id: str
+        """
+        self.log.info(
+            "Getting list of uploads for accountId:%s, webPropertyId:%s and customDataSourceId:%s ",
+            account_id,
+            web_property_id,
+            custom_data_source_id,
+        )
+
+        uploads = self.get_conn().management().uploads()  # pylint: disable=no-member
+        list_args = {
+            "accountId": account_id,
+            "webPropertyId": web_property_id,
+            "customDataSourceId": custom_data_source_id,
+        }
+        result = self._paginate(uploads, list_args)
+        return result
diff --git a/reference/providers/google/marketing_platform/hooks/campaign_manager.py b/reference/providers/google/marketing_platform/hooks/campaign_manager.py
new file mode 100644
index 0000000..8e94e61
--- /dev/null
+++ b/reference/providers/google/marketing_platform/hooks/campaign_manager.py
@@ -0,0 +1,353 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Google Campaign Manager hook."""
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from googleapiclient import http
+from googleapiclient.discovery import Resource, build
+
+
+class GoogleCampaignManagerHook(GoogleBaseHook):
+    """Hook for Google Campaign Manager."""
+
+    _conn = None  # type: Optional[Resource]
+
+    def __init__(
+        self,
+        api_version: str = "v3.3",
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+    ) -> None:
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            delegate_to=delegate_to,
+            impersonation_chain=impersonation_chain,
+        )
+        self.api_version = api_version
+
+    def get_conn(self) -> Resource:
+        """Retrieves connection to Campaign Manager."""
+        if not self._conn:
+            http_authorized = self._authorize()
+            self._conn = build(
+                "dfareporting",
+                self.api_version,
+                http=http_authorized,
+                cache_discovery=False,
+            )
+        return self._conn
+
+    def delete_report(self, profile_id: str, report_id: str) -> Any:
+        """
+        Deletes a report by its ID.
+
+        :param profile_id: The DFA user profile ID.
+        :type profile_id: str
+        :param report_id: The ID of the report.
+        :type report_id: str
+        """
+        response = (
+            self.get_conn()  # pylint: disable=no-member
+            .reports()
+            .delete(profileId=profile_id, reportId=report_id)
+            .execute(num_retries=self.num_retries)
+        )
+        return response
+
+    def insert_report(self, profile_id: str, report: Dict[str, Any]) -> Any:
+        """
+        Creates a report.
+
+        :param profile_id: The DFA user profile ID.
+        :type profile_id: str
+        :param report: The report resource to be inserted.
+ :type report: Dict[str, Any] + """ + response = ( + self.get_conn() # pylint: disable=no-member + .reports() + .insert(profileId=profile_id, body=report) + .execute(num_retries=self.num_retries) + ) + return response + + def list_reports( + self, + profile_id: str, + max_results: Optional[int] = None, + scope: Optional[str] = None, + sort_field: Optional[str] = None, + sort_order: Optional[str] = None, + ) -> List[dict]: + """ + Retrieves list of reports. + + :param profile_id: The DFA user profile ID. + :type profile_id: str + :param max_results: Maximum number of results to return. + :type max_results: Optional[int] + :param scope: The scope that defines which results are returned. + :type scope: Optional[str] + :param sort_field: The field by which to sort the list. + :type sort_field: Optional[str] + :param sort_order: Order of sorted results. + :type sort_order: Optional[str] + """ + reports: List[dict] = [] + conn = self.get_conn() + request = conn.reports().list( # pylint: disable=no-member + profileId=profile_id, + maxResults=max_results, + scope=scope, + sortField=sort_field, + sortOrder=sort_order, + ) + while request is not None: + response = request.execute(num_retries=self.num_retries) + reports.extend(response.get("items", [])) + request = conn.reports().list_next( # pylint: disable=no-member + previous_request=request, previous_response=response + ) + + return reports + + def patch_report(self, profile_id: str, report_id: str, update_mask: dict) -> Any: + """ + Updates a report. This method supports patch semantics. + + :param profile_id: The DFA user profile ID. + :type profile_id: str + :param report_id: The ID of the report. + :type report_id: str + :param update_mask: The relevant portions of a report resource, + according to the rules of patch semantics. + :type update_mask: Dict + """ + response = ( + self.get_conn() # pylint: disable=no-member + .reports() + .patch(profileId=profile_id, reportId=report_id, body=update_mask) + .execute(num_retries=self.num_retries) + ) + return response + + def run_report( + self, profile_id: str, report_id: str, synchronous: Optional[bool] = None + ) -> Any: + """ + Runs a report. + + :param profile_id: The DFA profile ID. + :type profile_id: str + :param report_id: The ID of the report. + :type report_id: str + :param synchronous: If set and true, tries to run the report synchronously. + :type synchronous: Optional[bool] + """ + response = ( + self.get_conn() # pylint: disable=no-member + .reports() + .run(profileId=profile_id, reportId=report_id, synchronous=synchronous) + .execute(num_retries=self.num_retries) + ) + return response + + def update_report(self, profile_id: str, report_id: str) -> Any: + """ + Updates a report. + + :param profile_id: The DFA user profile ID. + :type profile_id: str + :param report_id: The ID of the report. + :type report_id: str + """ + response = ( + self.get_conn() # pylint: disable=no-member + .reports() + .update(profileId=profile_id, reportId=report_id) + .execute(num_retries=self.num_retries) + ) + return response + + def get_report(self, file_id: str, profile_id: str, report_id: str) -> Any: + """ + Retrieves a report file. + + :param profile_id: The DFA user profile ID. + :type profile_id: str + :param report_id: The ID of the report. + :type report_id: str + :param file_id: The ID of the report file. 
+        :type file_id: str
+        """
+        response = (
+            self.get_conn()  # pylint: disable=no-member
+            .reports()
+            .files()
+            .get(fileId=file_id, profileId=profile_id, reportId=report_id)
+            .execute(num_retries=self.num_retries)
+        )
+        return response
+
+    def get_report_file(
+        self, file_id: str, profile_id: str, report_id: str
+    ) -> http.HttpRequest:
+        """
+        Retrieves a media part of report file.
+
+        :param profile_id: The DFA user profile ID.
+        :type profile_id: str
+        :param report_id: The ID of the report.
+        :type report_id: str
+        :param file_id: The ID of the report file.
+        :type file_id: str
+        :return: googleapiclient.http.HttpRequest
+        """
+        request = (
+            self.get_conn()  # pylint: disable=no-member
+            .reports()
+            .files()
+            .get_media(fileId=file_id, profileId=profile_id, reportId=report_id)
+        )
+        return request
+
+    @staticmethod
+    def _conversions_batch_request(
+        conversions: List[Dict[str, Any]],
+        encryption_entity_type: str,
+        encryption_entity_id: int,
+        encryption_source: str,
+        kind: str,
+    ) -> Dict[str, Any]:
+        return {
+            "kind": kind,
+            "conversions": conversions,
+            "encryptionInfo": {
+                "kind": "dfareporting#encryptionInfo",
+                "encryptionEntityType": encryption_entity_type,
+                "encryptionEntityId": encryption_entity_id,
+                "encryptionSource": encryption_source,
+            },
+        }
+
+    def conversions_batch_insert(
+        self,
+        profile_id: str,
+        conversions: List[Dict[str, Any]],
+        encryption_entity_type: str,
+        encryption_entity_id: int,
+        encryption_source: str,
+        max_failed_inserts: int = 0,
+    ) -> Any:
+        """
+        Inserts conversions.
+
+        :param profile_id: User profile ID associated with this request.
+        :type profile_id: str
+        :param conversions: Conversions to insert; should be of type Conversion:
+            https://developers.google.com/doubleclick-advertisers/v3.3/conversions#resource
+        :type conversions: List[Dict[str, Any]]
+        :param encryption_entity_type: The encryption entity type. This should match the encryption
+            configuration for ad serving or Data Transfer.
+        :type encryption_entity_type: str
+        :param encryption_entity_id: The encryption entity ID. This should match the encryption
+            configuration for ad serving or Data Transfer.
+        :type encryption_entity_id: int
+        :param encryption_source: Describes whether the encrypted cookie was received from ad serving
+            (the %m macro) or from Data Transfer.
+        :type encryption_source: str
+        :param max_failed_inserts: The maximum number of conversions that failed to be inserted
+        :type max_failed_inserts: int
+        """
+        response = (
+            self.get_conn()  # pylint: disable=no-member
+            .conversions()
+            .batchinsert(
+                profileId=profile_id,
+                body=self._conversions_batch_request(
+                    conversions=conversions,
+                    encryption_entity_type=encryption_entity_type,
+                    encryption_entity_id=encryption_entity_id,
+                    encryption_source=encryption_source,
+                    kind="dfareporting#conversionsBatchInsertRequest",
+                ),
+            )
+            .execute(num_retries=self.num_retries)
+        )
+        if response.get("hasFailures", False):
+            errored_conversions = [
+                stat["errors"] for stat in response["status"] if "errors" in stat
+            ]
+            if len(errored_conversions) > max_failed_inserts:
+                raise AirflowException(errored_conversions)
+        return response
+
+    def conversions_batch_update(
+        self,
+        profile_id: str,
+        conversions: List[Dict[str, Any]],
+        encryption_entity_type: str,
+        encryption_entity_id: int,
+        encryption_source: str,
+        max_failed_updates: int = 0,
+    ) -> Any:
+        """
+        Updates existing conversions.
+
+        :param profile_id: User profile ID associated with this request.
+        :type profile_id: str
+        :param conversions: Conversions to update; should be of type Conversion:
+            https://developers.google.com/doubleclick-advertisers/v3.3/conversions#resource
+        :type conversions: List[Dict[str, Any]]
+        :param encryption_entity_type: The encryption entity type. This should match the encryption
+            configuration for ad serving or Data Transfer.
+        :type encryption_entity_type: str
+        :param encryption_entity_id: The encryption entity ID. This should match the encryption
+            configuration for ad serving or Data Transfer.
+        :type encryption_entity_id: int
+        :param encryption_source: Describes whether the encrypted cookie was received from ad serving
+            (the %m macro) or from Data Transfer.
+        :type encryption_source: str
+        :param max_failed_updates: The maximum number of conversions that failed to be updated
+        :type max_failed_updates: int
+        """
+        response = (
+            self.get_conn()  # pylint: disable=no-member
+            .conversions()
+            .batchupdate(
+                profileId=profile_id,
+                body=self._conversions_batch_request(
+                    conversions=conversions,
+                    encryption_entity_type=encryption_entity_type,
+                    encryption_entity_id=encryption_entity_id,
+                    encryption_source=encryption_source,
+                    kind="dfareporting#conversionsBatchUpdateRequest",
+                ),
+            )
+            .execute(num_retries=self.num_retries)
+        )
+        if response.get("hasFailures", False):
+            errored_conversions = [
+                stat["errors"] for stat in response["status"] if "errors" in stat
+            ]
+            if len(errored_conversions) > max_failed_updates:
+                raise AirflowException(errored_conversions)
+        return response
diff --git a/reference/providers/google/marketing_platform/hooks/display_video.py b/reference/providers/google/marketing_platform/hooks/display_video.py
new file mode 100644
index 0000000..75e96fc
--- /dev/null
+++ b/reference/providers/google/marketing_platform/hooks/display_video.py
@@ -0,0 +1,250 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
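
A minimal usage sketch of the GoogleCampaignManagerHook committed above (not part of the diff itself). It assumes the module is importable from the standard airflow.providers path rather than the reference/ directory shown here, that the "google_cloud_default" connection has Campaign Manager access, and that all IDs are hypothetical placeholders.

# Sketch only: exercising GoogleCampaignManagerHook outside an operator.
# Import path and connection name are assumptions; IDs below are placeholders.
from airflow.providers.google.marketing_platform.hooks.campaign_manager import (
    GoogleCampaignManagerHook,
)

hook = GoogleCampaignManagerHook(api_version="v3.3")
reports = hook.list_reports(profile_id="1234567")  # hypothetical DFA profile ID
if reports:
    # Trigger the first report asynchronously; the response carries the report file ID
    # that GoogleCampaignManagerDownloadReportOperator-style code would later download.
    run = hook.run_report(profile_id="1234567", report_id=reports[0]["id"])
    print(run.get("id"))
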
+"""This module contains Google DisplayVideo hook.""" + +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import Resource, build + + +class GoogleDisplayVideo360Hook(GoogleBaseHook): + """Hook for Google Display & Video 360.""" + + _conn = None # type: Optional[Any] + + def __init__( + self, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.api_version = api_version + + def get_conn(self) -> Re# + """Retrieves connection to DisplayVideo.""" + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + "doubleclickbidmanager", + self.api_version, + http=http_authorized, + cache_discovery=False, + ) + return self._conn + + def get_conn_to_display_video(self) -> Re# + """Retrieves connection to DisplayVideo.""" + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + "displayvideo", + self.api_version, + http=http_authorized, + cache_discovery=False, + ) + return self._conn + + @staticmethod + def erf_uri(partner_id, entity_type) -> List[str]: + """ + Return URI for all Entity Read Files in bucket. + + For example, if you were generating a file name to retrieve the entity read file + for partner 123 accessing the line_item table from April 2, 2013, your filename + would look something like this: + gdbm-123/entity/20130402.0.LineItem.json + + More information: + https://developers.google.com/bid-manager/guides/entity-read/overview + + :param partner_id The numeric ID of your Partner. + :type partner_id: int + :param entity_type: The type of file Partner, Advertiser, InsertionOrder, + LineItem, Creative, Pixel, InventorySource, UserList, UniversalChannel, and summary. + :type entity_type: str + """ + return [f"gdbm-{partner_id}/entity/{{{{ ds_nodash }}}}.*.{entity_type}.json"] + + def create_query(self, query: Dict[str, Any]) -> dict: + """ + Creates a query. + + :param query: Query object to be passed to request body. + :type query: Dict[str, Any] + """ + response = ( + self.get_conn() # pylint: disable=no-member + .queries() + .createquery(body=query) + .execute(num_retries=self.num_retries) + ) + return response + + def delete_query(self, query_id: str) -> None: + """ + Deletes a stored query as well as the associated stored reports. + + :param query_id: Query ID to delete. + :type query_id: str + """ + ( + self.get_conn() # pylint: disable=no-member + .queries() + .deletequery(queryId=query_id) + .execute(num_retries=self.num_retries) + ) + + def get_query(self, query_id: str) -> dict: + """ + Retrieves a stored query. + + :param query_id: Query ID to retrieve. + :type query_id: str + """ + response = ( + self.get_conn() # pylint: disable=no-member + .queries() + .getquery(queryId=query_id) + .execute(num_retries=self.num_retries) + ) + return response + + def list_queries( + self, + ) -> List[Dict]: + """Retrieves stored queries.""" + response = ( + self.get_conn() # pylint: disable=no-member + .queries() + .listqueries() + .execute(num_retries=self.num_retries) + ) + return response.get("queries", []) + + def run_query(self, query_id: str, params: Dict[str, Any]) -> None: + """ + Runs a stored query to generate a report. + + :param query_id: Query ID to run. 
+        :type query_id: str
+        :param params: Parameters for the report.
+        :type params: Dict[str, Any]
+        """
+        (
+            self.get_conn()  # pylint: disable=no-member
+            .queries()
+            .runquery(queryId=query_id, body=params)
+            .execute(num_retries=self.num_retries)
+        )
+
+    def upload_line_items(self, line_items: Any) -> List[Dict[str, Any]]:
+        """
+        Uploads line items in CSV format.
+
+        :param line_items: downloaded data from GCS and passed to the body request
+        :type line_items: Any
+        :return: response body.
+        :rtype: List[Dict[str, Any]]
+        """
+        request_body = {
+            "lineItems": line_items,
+            "dryRun": False,
+            "format": "CSV",
+        }
+
+        response = (
+            self.get_conn()  # pylint: disable=no-member
+            .lineitems()
+            .uploadlineitems(body=request_body)
+            .execute(num_retries=self.num_retries)
+        )
+        return response
+
+    def download_line_items(self, request_body: Dict[str, Any]) -> List[Any]:
+        """
+        Retrieves line items in CSV format.
+
+        :param request_body: dictionary with parameters that should be passed into.
+            More information about it can be found here:
+            https://developers.google.com/bid-manager/v1.1/lineitems/downloadlineitems
+        :type request_body: Dict[str, Any]
+        """
+        response = (
+            self.get_conn()  # pylint: disable=no-member
+            .lineitems()
+            .downloadlineitems(body=request_body)
+            .execute(num_retries=self.num_retries)
+        )
+        return response["lineItems"]
+
+    def create_sdf_download_operation(
+        self, body_request: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """
+        Creates an SDF Download Task and returns an Operation.
+
+        :param body_request: Body request.
+        :type body_request: Dict[str, Any]
+
+        More information about the body request can be found here:
+        https://developers.google.com/display-video/api/reference/rest/v1/sdfdownloadtasks/create
+        """
+        result = (
+            self.get_conn_to_display_video()  # pylint: disable=no-member
+            .sdfdownloadtasks()
+            .create(body=body_request)
+            .execute(num_retries=self.num_retries)
+        )
+        return result
+
+    def get_sdf_download_operation(self, operation_name: str):
+        """
+        Gets the latest state of an asynchronous SDF download task operation.
+
+        :param operation_name: The name of the operation resource.
+        :type operation_name: str
+        """
+        result = (
+            self.get_conn_to_display_video()  # pylint: disable=no-member
+            .sdfdownloadtasks()
+            .operations()
+            .get(name=operation_name)
+            .execute(num_retries=self.num_retries)
+        )
+        return result
+
+    def download_media(self, resource_name: str):
+        """
+        Downloads media.
+
+        :param resource_name: Name of the media that is being downloaded.
+        :type resource_name: str
+        """
+        request = (
+            self.get_conn_to_display_video()  # pylint: disable=no-member
+            .media()
+            .download_media(resource_name=resource_name)
+        )
+        return request
diff --git a/reference/providers/google/marketing_platform/hooks/search_ads.py b/reference/providers/google/marketing_platform/hooks/search_ads.py
new file mode 100644
index 0000000..037186b
--- /dev/null
+++ b/reference/providers/google/marketing_platform/hooks/search_ads.py
@@ -0,0 +1,101 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Search Ads 360 hook.""" +from typing import Any, Dict, Optional, Sequence, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import build + + +class GoogleSearchAdsHook(GoogleBaseHook): + """Hook for Google Search Ads 360.""" + + _conn = None # type: Optional[Any] + + def __init__( + self, + api_version: str = "v2", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.api_version = api_version + + def get_conn(self): + """Retrieves connection to Google SearchAds.""" + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + "doubleclicksearch", + self.api_version, + http=http_authorized, + cache_discovery=False, + ) + return self._conn + + def insert_report(self, report: Dict[str, Any]) -> Any: + """ + Inserts a report request into the reporting system. + + :param report: Report to be generated. + :type report: Dict[str, Any] + """ + response = ( + self.get_conn() # pylint: disable=no-member + .reports() + .request(body=report) + .execute(num_retries=self.num_retries) + ) + return response + + def get(self, report_id: str) -> Any: + """ + Polls for the status of a report request. + + :param report_id: ID of the report request being polled. + :type report_id: str + """ + response = ( + self.get_conn() # pylint: disable=no-member + .reports() + .get(reportId=report_id) + .execute(num_retries=self.num_retries) + ) + return response + + def get_file(self, report_fragment: int, report_id: str) -> Any: + """ + Downloads a report file encoded in UTF-8. + + :param report_fragment: The index of the report fragment to download. + :type report_fragment: int + :param report_id: ID of the report. + :type report_id: str + """ + response = ( + self.get_conn() # pylint: disable=no-member + .reports() + .getFile(reportFragment=report_fragment, reportId=report_id) + .execute(num_retries=self.num_retries) + ) + return response diff --git a/reference/providers/google/marketing_platform/operators/__init__.py b/reference/providers/google/marketing_platform/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/marketing_platform/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
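
A minimal polling sketch for the GoogleSearchAdsHook defined above (not part of the committed diff): the report definition is an empty placeholder that must be replaced with a valid Search Ads 360 Reports.request body, the import path and connection name are assumptions, and the readiness flag name follows the Search Ads 360 Reports resource.

# Sketch only: request a Search Ads 360 report, poll until ready, fetch the first fragment.
import time

from airflow.providers.google.marketing_platform.hooks.search_ads import GoogleSearchAdsHook

hook = GoogleSearchAdsHook(api_version="v2")
report_definition = {}  # placeholder: a real Search Ads 360 report request body goes here
report_id = hook.insert_report(report=report_definition)["id"]

while not hook.get(report_id=report_id).get("isReportReady", False):
    time.sleep(30)  # poll until the API marks the report as ready

fragment = hook.get_file(report_fragment=0, report_id=report_id)
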
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/marketing_platform/operators/analytics.py b/reference/providers/google/marketing_platform/operators/analytics.py new file mode 100644 index 0000000..4d15635 --- /dev/null +++ b/reference/providers/google/marketing_platform/operators/analytics.py @@ -0,0 +1,543 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Analytics 360 operators.""" +import csv +from tempfile import NamedTemporaryFile +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.marketing_platform.hooks.analytics import ( + GoogleAnalyticsHook, +) +from airflow.utils.decorators import apply_defaults + + +class GoogleAnalyticsListAccountsOperator(BaseOperator): + """ + Lists all accounts to which the user has access. + + .. seealso:: + Check official API docs: + https://developers.google.com/analytics/devguides/config/mgmt/v3/mgmtReference/management/accounts/list + and for python client + http://googleapis.github.io/google-api-python-client/docs/dyn/analytics_v3.management.accounts.html#list + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleAnalyticsListAccountsOperator` + + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "api_version", + "gcp_conn_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + api_version: str = "v3", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> List[Dict[str, Any]]: + hook = GoogleAnalyticsHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + result = hook.list_accounts() + return result + + +class GoogleAnalyticsGetAdsLinkOperator(BaseOperator): + """ + Returns a web property-Google Ads link to which the user has access. + + .. seealso:: + Check official API docs: + https://developers.google.com/analytics/devguides/config/mgmt/v3/mgmtReference/management/webPropertyAdWordsLinks/get + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleAnalyticsGetAdsLinkOperator` + + :param account_id: ID of the account which the given web property belongs to. + :type account_id: str + :param web_property_ad_words_link_id: Web property-Google Ads link ID. + :type web_property_ad_words_link_id: str + :param web_property_id: Web property ID to retrieve the Google Ads link for. + :type web_property_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "api_version", + "gcp_conn_id", + "account_id", + "web_property_ad_words_link_id", + "web_property_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + account_id: str, + web_property_ad_words_link_id: str, + web_property_id: str, + api_version: str = "v3", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.account_id = account_id + self.web_property_ad_words_link_id = web_property_ad_words_link_id + self.web_property_id = web_property_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> Dict[str, Any]: + hook = GoogleAnalyticsHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + result = hook.get_ad_words_link( + account_id=self.account_id, + web_property_id=self.web_property_id, + web_property_ad_words_link_id=self.web_property_ad_words_link_id, + ) + return result + + +class GoogleAnalyticsRetrieveAdsLinksListOperator(BaseOperator): + """ + Lists webProperty-Google Ads links for a given web property + + .. 
seealso:: + Check official API docs: + https://developers.google.com/analytics/devguides/config/mgmt/v3/mgmtReference/management/webPropertyAdWordsLinks/list#http-request + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleAnalyticsRetrieveAdsLinksListOperator` + + :param account_id: ID of the account which the given web property belongs to. + :type account_id: str + :param web_property_id: Web property UA-string to retrieve the Google Ads links for. + :type web_property_id: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "api_version", + "gcp_conn_id", + "account_id", + "web_property_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + account_id: str, + web_property_id: str, + api_version: str = "v3", + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.account_id = account_id + self.web_property_id = web_property_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> List[Dict[str, Any]]: + hook = GoogleAnalyticsHook( + api_version=self.api_version, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + result = hook.list_ad_words_links( + account_id=self.account_id, + web_property_id=self.web_property_id, + ) + return result + + +class GoogleAnalyticsDataImportUploadOperator(BaseOperator): + """ + Take a file from Cloud Storage and uploads it to GA via data import API. + + :param storage_bucket: The Google cloud storage bucket where the file is stored. + :type storage_bucket: str + :param storage_name_object: The name of the object in the desired Google cloud + storage bucket. (templated) If the destination points to an existing + folder, the file will be taken from the specified folder. + :type storage_name_object: str + :param account_id: The GA account Id (long) to which the data upload belongs. + :type account_id: str + :param web_property_id: The web property UA-string associated with the upload. + :type web_property_id: str + :param custom_data_source_id: The id to which the data import belongs + :type custom_data_source_id: str + :param resumable_upload: flag to upload the file in a resumable fashion, using a + series of at least two requests. + :type resumable_upload: bool + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param api_version: The version of the api that will be requested for example 'v3'. 
+ :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "storage_bucket", + "storage_name_object", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + storage_bucket: str, + storage_name_object: str, + account_id: str, + web_property_id: str, + custom_data_source_id: str, + resumable_upload: bool = False, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + api_version: str = "v3", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.storage_bucket = storage_bucket + self.storage_name_object = storage_name_object + self.account_id = account_id + self.web_property_id = web_property_id + self.custom_data_source_id = custom_data_source_id + self.resumable_upload = resumable_upload + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.api_version = api_version + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + ga_hook = GoogleAnalyticsHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + + with NamedTemporaryFile("w+") as tmp_file: + self.log.info( + "Downloading file from GCS: %s/%s ", + self.storage_bucket, + self.storage_name_object, + ) + gcs_hook.download( + bucket_name=self.storage_bucket, + object_name=self.storage_name_object, + filename=tmp_file.name, + ) + + ga_hook.upload_data( + tmp_file.name, + self.account_id, + self.web_property_id, + self.custom_data_source_id, + self.resumable_upload, + ) + + +class GoogleAnalyticsDeletePreviousDataUploadsOperator(BaseOperator): + """ + Deletes previous GA uploads to leave the latest file to control the size of the Data Set Quota. + + :param account_id: The GA account Id (long) to which the data upload belongs. + :type account_id: str + :param web_property_id: The web property UA-string associated with the upload. + :type web_property_id: str + :param custom_data_source_id: The id to which the data import belongs. + :type custom_data_source_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param api_version: The version of the api that will be requested for example 'v3'. 
+ :type api_version: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ("impersonation_chain",) + + def __init__( + self, + account_id: str, + web_property_id: str, + custom_data_source_id: str, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + api_version: str = "v3", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.account_id = account_id + self.web_property_id = web_property_id + self.custom_data_source_id = custom_data_source_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.api_version = api_version + self.impersonation_chain = impersonation_chain + + def execute(self, context) -> None: + ga_hook = GoogleAnalyticsHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + + uploads = ga_hook.list_uploads( + account_id=self.account_id, + web_property_id=self.web_property_id, + custom_data_source_id=self.custom_data_source_id, + ) + + cids = [upload["id"] for upload in uploads] + delete_request_body = {"customDataImportUids": cids} + + ga_hook.delete_upload_data( + self.account_id, + self.web_property_id, + self.custom_data_source_id, + delete_request_body, + ) + + +class GoogleAnalyticsModifyFileHeadersDataImportOperator(BaseOperator): + """ + GA has a very particular naming convention for Data Import. Ability to + prefix "ga:" to all column headers and also a dict to rename columns to + match the custom dimension ID in GA i.e clientId : dimensionX. + + :param storage_bucket: The Google cloud storage bucket where the file is stored. + :type storage_bucket: str + :param storage_name_object: The name of the object in the desired Google cloud + storage bucket. (templated) If the destination points to an existing + folder, the file will be taken from the specified folder. + :type storage_name_object: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param custom_dimension_header_mapping: Dictionary to handle when uploading + custom dimensions which have generic IDs ie. 'dimensionX' which are + set by GA. Dictionary maps the current CSV header to GA ID which will + be the new header for the CSV to upload to GA eg clientId : dimension1. + :type custom_dimension_header_mapping: dict + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "storage_bucket", + "storage_name_object", + "impersonation_chain", + ) + + def __init__( + self, + storage_bucket: str, + storage_name_object: str, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + custom_dimension_header_mapping: Optional[Dict[str, str]] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.storage_bucket = storage_bucket + self.storage_name_object = storage_name_object + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.custom_dimension_header_mapping = custom_dimension_header_mapping or {} + self.impersonation_chain = impersonation_chain + + def _modify_column_headers( + self, tmp_file_location: str, custom_dimension_header_mapping: Dict[str, str] + ) -> None: + # Check headers + self.log.info("Checking if file contains headers") + with open(tmp_file_location) as check_header_file: + has_header = csv.Sniffer().has_header(check_header_file.read(1024)) + if not has_header: + raise NameError( + "CSV does not contain headers, please add them " + "to use the modify column headers functionality" + ) + + # Transform + self.log.info("Modifying column headers to be compatible for data upload") + with open(tmp_file_location) as read_file: + reader = csv.reader(read_file) + headers = next(reader) + new_headers = [] + for header in headers: + if header in custom_dimension_header_mapping: + header = custom_dimension_header_mapping.get(header) # type: ignore + new_header = f"ga:{header}" + new_headers.append(new_header) + all_data = read_file.readlines() + final_headers = ",".join(new_headers) + "\n" + all_data.insert(0, final_headers) + + # Save result + self.log.info("Saving transformed file") + with open(tmp_file_location, "w") as write_file: + write_file.writelines(all_data) + + def execute(self, context) -> None: + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + with NamedTemporaryFile("w+") as tmp_file: + # Download file from GCS + self.log.info( + "Downloading file from GCS: %s/%s ", + self.storage_bucket, + self.storage_name_object, + ) + + gcs_hook.download( + bucket_name=self.storage_bucket, + object_name=self.storage_name_object, + filename=tmp_file.name, + ) + + # Modify file + self.log.info("Modifying temporary file %s", tmp_file.name) + self._modify_column_headers( + tmp_file_location=tmp_file.name, + custom_dimension_header_mapping=self.custom_dimension_header_mapping, + ) + + # Upload newly formatted file to cloud storage + self.log.info( + "Uploading file to GCS: %s/%s ", + self.storage_bucket, + self.storage_name_object, + ) + gcs_hook.upload( + bucket_name=self.storage_bucket, + object_name=self.storage_name_object, + filename=tmp_file.name, + ) diff --git a/reference/providers/google/marketing_platform/operators/campaign_manager.py b/reference/providers/google/marketing_platform/operators/campaign_manager.py new file mode 100644 index 0000000..3a59bf9 --- /dev/null +++ 
b/reference/providers/google/marketing_platform/operators/campaign_manager.py @@ -0,0 +1,654 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google CampaignManager operators.""" +import json +import tempfile +import uuid +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.marketing_platform.hooks.campaign_manager import ( + GoogleCampaignManagerHook, +) +from airflow.utils.decorators import apply_defaults +from googleapiclient import http + + +class GoogleCampaignManagerDeleteReportOperator(BaseOperator): + """ + Deletes a report by its ID. + + .. seealso:: + Check official API docs: + https://developers.google.com/doubleclick-advertisers/v3.3/reports/delete + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleCampaignManagerDeleteReportOperator` + + :param profile_id: The DFA user profile ID. + :type profile_id: str + :param report_name: The name of the report to delete. + :type report_name: str + :param report_id: The ID of the report. + :type report_id: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "profile_id", + "report_id", + "report_name", + "api_version", + "gcp_conn_id", + "delegate_to", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + profile_id: str, + report_name: Optional[str] = None, + report_id: Optional[str] = None, + api_version: str = "v3.3", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if not (report_name or report_id): + raise AirflowException("Please provide `report_name` or `report_id`.") + if report_name and report_id: + raise AirflowException( + "Please provide only one parameter `report_name` or `report_id`." + ) + + self.profile_id = profile_id + self.report_name = report_name + self.report_id = report_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = GoogleCampaignManagerHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + if self.report_name: + reports = hook.list_reports(profile_id=self.profile_id) + reports_with_name = [r for r in reports if r["name"] == self.report_name] + for report in reports_with_name: + report_id = report["id"] + self.log.info("Deleting Campaign Manager report: %s", report_id) + hook.delete_report(profile_id=self.profile_id, report_id=report_id) + self.log.info("Report deleted.") + elif self.report_id: + self.log.info("Deleting Campaign Manager report: %s", self.report_id) + hook.delete_report(profile_id=self.profile_id, report_id=self.report_id) + self.log.info("Report deleted.") + + +class GoogleCampaignManagerDownloadReportOperator(BaseOperator): + """ + Retrieves a report and uploads it to GCS bucket. + + .. seealso:: + Check official API docs: + https://developers.google.com/doubleclick-advertisers/v3.3/reports/files/get + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleCampaignManagerDownloadReportOperator` + + :param profile_id: The DFA user profile ID. + :type profile_id: str + :param report_id: The ID of the report. + :type report_id: str + :param file_id: The ID of the report file. + :type file_id: str + :param bucket_name: The bucket to upload to. + :type bucket_name: str + :param report_name: The report name to set when uploading the local file. + :type report_name: str + :param gzip: Option to compress local file or file data for upload + :type gzip: bool + :param chunk_size: File will be downloaded in chunks of this many bytes. + :type chunk_size: int + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "profile_id", + "report_id", + "file_id", + "bucket_name", + "report_name", + "chunk_size", + "api_version", + "gcp_conn_id", + "delegate_to", + "impersonation_chain", + ) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + profile_id: str, + report_id: str, + file_id: str, + bucket_name: str, + report_name: Optional[str] = None, + gzip: bool = True, + chunk_size: int = 10 * 1024 * 1024, + api_version: str = "v3.3", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.profile_id = profile_id + self.report_id = report_id + self.file_id = file_id + self.api_version = api_version + self.chunk_size = chunk_size + self.gzip = gzip + self.bucket_name = self._set_bucket_name(bucket_name) + self.report_name = report_name + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def _resolve_file_name(self, name: str) -> str: + csv = ".csv" + gzip = ".gz" + if not name.endswith(csv): + name += csv + if self.gzip: + name += gzip + return name + + @staticmethod + def _set_bucket_name(name: str) -> str: + bucket = name if not name.startswith("gs://") else name[5:] + return bucket.strip("/") + + def execute(self, context: dict) -> None: + hook = GoogleCampaignManagerHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + # Get name of the report + report = hook.get_report( + file_id=self.file_id, profile_id=self.profile_id, report_id=self.report_id + ) + report_name = self.report_name or report.get("fileName", str(uuid.uuid4())) + report_name = self._resolve_file_name(report_name) + + # Download the report + self.log.info("Starting downloading report %s", self.report_id) + request = hook.get_report_file( + profile_id=self.profile_id, report_id=self.report_id, file_id=self.file_id + ) + with tempfile.NamedTemporaryFile() as temp_file: + downloader = http.MediaIoBaseDownload( + fd=temp_file, request=request, chunksize=self.chunk_size + ) + download_finished = False + while not download_finished: + _, download_finished = downloader.next_chunk() + + temp_file.flush() + # Upload the local file to bucket + gcs_hook.upload( + bucket_name=self.bucket_name, + object_name=report_name, + gzip=self.gzip, + filename=temp_file.name, + mime_type="text/csv", + ) + + self.xcom_push(context, key="report_name", value=report_name) + + +class GoogleCampaignManagerInsertReportOperator(BaseOperator): + """ + Creates a report. + + .. seealso:: + Check official API docs: + https://developers.google.com/doubleclick-advertisers/v3.3/reports/insert + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleCampaignManagerInsertReportOperator` + + :param profile_id: The DFA user profile ID. + :type profile_id: str + :param report: Report to be created. + :type report: Dict[str, Any] + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "profile_id", + "report", + "api_version", + "gcp_conn_id", + "delegate_to", + "impersonation_chain", + ) + + template_ext = (".json",) + + @apply_defaults + def __init__( + self, + *, + profile_id: str, + report: Dict[str, Any], + api_version: str = "v3.3", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.profile_id = profile_id + self.report = report + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def prepare_template(self) -> None: + # If .json is passed then we have to read the file + if isinstance(self.report, str) and self.report.endswith(".json"): + with open(self.report) as file: + self.report = json.load(file) + + def execute(self, context: dict): + hook = GoogleCampaignManagerHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Inserting Campaign Manager report.") + response = hook.insert_report(profile_id=self.profile_id, report=self.report) + report_id = response.get("id") + self.xcom_push(context, key="report_id", value=report_id) + self.log.info("Report successfully inserted. Report id: %s", report_id) + return response + + +class GoogleCampaignManagerRunReportOperator(BaseOperator): + """ + Runs a report. + + .. seealso:: + Check official API docs: + https://developers.google.com/doubleclick-advertisers/v3.3/reports/run + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleCampaignManagerRunReportOperator` + + :param profile_id: The DFA profile ID. + :type profile_id: str + :param report_id: The ID of the report. + :type report_id: str + :param synchronous: If set and true, tries to run the report synchronously. + :type synchronous: bool + :param api_version: The version of the api that will be requested for example 'v3'. 
+    :type api_version: str
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        "profile_id",
+        "report_id",
+        "synchronous",
+        "api_version",
+        "gcp_conn_id",
+        "delegate_to",
+        "impersonation_chain",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        profile_id: str,
+        report_id: str,
+        synchronous: bool = False,
+        api_version: str = "v3.3",
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.profile_id = profile_id
+        self.report_id = report_id
+        self.synchronous = synchronous
+        self.api_version = api_version
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: dict):
+        hook = GoogleCampaignManagerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            api_version=self.api_version,
+            impersonation_chain=self.impersonation_chain,
+        )
+        self.log.info("Running report %s", self.report_id)
+        response = hook.run_report(
+            profile_id=self.profile_id,
+            report_id=self.report_id,
+            synchronous=self.synchronous,
+        )
+        file_id = response.get("id")
+        self.xcom_push(context, key="file_id", value=file_id)
+        self.log.info("Report file id: %s", file_id)
+        return response
+
+
+class GoogleCampaignManagerBatchInsertConversionsOperator(BaseOperator):
+    """
+    Inserts conversions.
+
+    .. seealso::
+        Check official API docs:
+        https://developers.google.com/doubleclick-advertisers/v3.3/conversions/batchinsert
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:GoogleCampaignManagerBatchInsertConversionsOperator`
+
+    :param profile_id: User profile ID associated with this request.
+    :type profile_id: str
+    :param conversions: Conversions to insert; should be of type Conversion:
+        https://developers.google.com/doubleclick-advertisers/v3.3/conversions#resource
+    :type conversions: List[Dict[str, Any]]
+    :param encryption_entity_type: The encryption entity type. This should match the encryption
+        configuration for ad serving or Data Transfer.
+    :type encryption_entity_type: str
+    :param encryption_entity_id: The encryption entity ID. This should match the encryption
+        configuration for ad serving or Data Transfer.
+    :type encryption_entity_id: int
+    :param encryption_source: Describes whether the encrypted cookie was received from ad serving
+        (the %m macro) or from Data Transfer.
+ :type encryption_source: str + :param max_failed_inserts: The maximum number of conversions that failed to be inserted + :type max_failed_inserts: int + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "profile_id", + "conversions", + "encryption_entity_type", + "encryption_entity_id", + "encryption_source", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + profile_id: str, + conversions: List[Dict[str, Any]], + encryption_entity_type: str, + encryption_entity_id: int, + encryption_source: str, + max_failed_inserts: int = 0, + api_version: str = "v3.3", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.profile_id = profile_id + self.conversions = conversions + self.encryption_entity_type = encryption_entity_type + self.encryption_entity_id = encryption_entity_id + self.encryption_source = encryption_source + self.max_failed_inserts = max_failed_inserts + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + hook = GoogleCampaignManagerHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + response = hook.conversions_batch_insert( + profile_id=self.profile_id, + conversions=self.conversions, + encryption_entity_type=self.encryption_entity_type, + encryption_entity_id=self.encryption_entity_id, + encryption_source=self.encryption_source, + max_failed_inserts=self.max_failed_inserts, + ) + return response + + +class GoogleCampaignManagerBatchUpdateConversionsOperator(BaseOperator): + """ + Updates existing conversions. + + .. seealso:: + Check official API docs: + https://developers.google.com/doubleclick-advertisers/v3.3/conversions/batchupdate + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleCampaignManagerBatchUpdateConversionsOperator` + + :param profile_id: User profile ID associated with this request.
+ :type profile_id: str + :param conversions: Conversions to update, should be of type Conversion: + https://developers.google.com/doubleclick-advertisers/v3.3/conversions#resource + :type conversions: List[Dict[str, Any]] + :param encryption_entity_type: The encryption entity type. This should match the encryption + configuration for ad serving or Data Transfer. + :type encryption_entity_type: str + :param encryption_entity_id: The encryption entity ID. This should match the encryption + configuration for ad serving or Data Transfer. + :type encryption_entity_id: int + :param encryption_source: Describes whether the encrypted cookie was received from ad serving + (the %m macro) or from Data Transfer. + :type encryption_source: str + :param max_failed_updates: The maximum number of conversions that failed to be updated + :type max_failed_updates: int + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated).
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "profile_id", + "conversions", + "encryption_entity_type", + "encryption_entity_id", + "encryption_source", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + profile_id: str, + conversions: List[Dict[str, Any]], + encryption_entity_type: str, + encryption_entity_id: int, + encryption_source: str, + max_failed_updates: int = 0, + api_version: str = "v3.3", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.profile_id = profile_id + self.conversions = conversions + self.encryption_entity_type = encryption_entity_type + self.encryption_entity_id = encryption_entity_id + self.encryption_source = encryption_source + self.max_failed_updates = max_failed_updates + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict): + hook = GoogleCampaignManagerHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + response = hook.conversions_batch_update( + profile_id=self.profile_id, + conversions=self.conversions, + encryption_entity_type=self.encryption_entity_type, + encryption_entity_id=self.encryption_entity_id, + encryption_source=self.encryption_source, + max_failed_updates=self.max_failed_updates, + ) + return response diff --git a/reference/providers/google/marketing_platform/operators/display_video.py b/reference/providers/google/marketing_platform/operators/display_video.py new file mode 100644 index 0000000..cdbf75c --- /dev/null +++ b/reference/providers/google/marketing_platform/operators/display_video.py @@ -0,0 +1,749 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google DisplayVideo operators.""" +import csv +import json +import shutil +import tempfile +import urllib.request +from typing import Any, Dict, List, Optional, Sequence, Union +from urllib.parse import urlparse + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.marketing_platform.hooks.display_video import ( + GoogleDisplayVideo360Hook, +) +from airflow.utils.decorators import apply_defaults + + +class GoogleDisplayVideo360CreateReportOperator(BaseOperator): + """ + Creates a query. + + ..
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDisplayVideo360CreateReportOperator` + + .. seealso:: + Check also the official API docs: + `https://developers.google.com/bid-manager/v1/queries/createquery` + + :param body: Report object passed to the request's body as described here: + https://developers.google.com/bid-manager/v1/queries#resource + :type body: Dict[str, Any] + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "body", + "impersonation_chain", + ) + template_ext = (".json",) + + @apply_defaults + def __init__( + self, + *, + body: Dict[str, Any], + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.body = body + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def prepare_template(self) -> None: + # If .json is passed then we have to read the file + if isinstance(self.body, str) and self.body.endswith(".json"): + with open(self.body) as file: + self.body = json.load(file) + + def execute(self, context: dict) -> dict: + hook = GoogleDisplayVideo360Hook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Creating Display & Video 360 report.") + response = hook.create_query(query=self.body) + report_id = response["queryId"] + self.xcom_push(context, key="report_id", value=report_id) + self.log.info("Created report with ID: %s", report_id) + return response + + +class GoogleDisplayVideo360DeleteReportOperator(BaseOperator): + """ + Deletes a stored query as well as the associated stored reports. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDisplayVideo360DeleteReportOperator` + + .. seealso:: + Check also the official API docs: + `https://developers.google.com/bid-manager/v1/queries/deletequery` + + :param report_id: Report ID to delete. + :type report_id: str + :param report_name: Name of the report to delete. + :type report_name: str + :param api_version: The version of the api that will be requested for example 'v3'. 
+ :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "report_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + report_id: Optional[str] = None, + report_name: Optional[str] = None, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.report_id = report_id + self.report_name = report_name + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + if report_name and report_id: + raise AirflowException("Use only one value - `report_name` or `report_id`.") + + if not (report_name or report_id): + raise AirflowException( + "Provide one of the values: `report_name` or `report_id`." + ) + + def execute(self, context: dict) -> None: + hook = GoogleDisplayVideo360Hook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + if self.report_id: + reports_ids_to_delete = [self.report_id] + else: + reports = hook.list_queries() + reports_ids_to_delete = [ + report["queryId"] + for report in reports + if report["metadata"]["title"] == self.report_name + ] + + for report_id in reports_ids_to_delete: + self.log.info("Deleting report with id: %s", report_id) + hook.delete_query(query_id=report_id) + self.log.info("Report deleted.") + + +class GoogleDisplayVideo360DownloadReportOperator(BaseOperator): + """ + Retrieves a stored query. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDisplayVideo360DownloadReportOperator` + + .. seealso:: + Check also the official API docs: + `https://developers.google.com/bid-manager/v1/queries/getquery` + + :param report_id: Report ID to retrieve. + :type report_id: str + :param bucket_name: The bucket to upload to. + :type bucket_name: str + :param report_name: The report name to set when uploading the local file. + :type report_name: str + :param chunk_size: File will be downloaded in chunks of this many bytes. + :type chunk_size: int + :param gzip: Option to compress local file or file data for upload + :type gzip: bool + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. 
+ :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "report_id", + "bucket_name", + "report_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + report_id: str, + bucket_name: str, + report_name: Optional[str] = None, + gzip: bool = True, + chunk_size: int = 10 * 1024 * 1024, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.report_id = report_id + self.chunk_size = chunk_size + self.gzip = gzip + self.bucket_name = self._set_bucket_name(bucket_name) + self.report_name = report_name + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def _resolve_file_name(self, name: str) -> str: + new_name = name if name.endswith(".csv") else f"{name}.csv" + new_name = f"{new_name}.gz" if self.gzip else new_name + return new_name + + @staticmethod + def _set_bucket_name(name: str) -> str: + bucket = name if not name.startswith("gs://") else name[5:] + return bucket.strip("/") + + def execute(self, context: dict): + hook = GoogleDisplayVideo360Hook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + resource = hook.get_query(query_id=self.report_id) + # Check if report is ready + if resource["metadata"]["running"]: + raise AirflowException(f"Report {self.report_id} is still running") + + # If no custom report_name provided, use DV360 name + file_url = resource["metadata"]["googleCloudStoragePathForLatestReport"] + report_name = self.report_name or urlparse(file_url).path.split("/")[-1] + report_name = self._resolve_file_name(report_name) + + # Download the report + self.log.info("Starting downloading report %s", self.report_id) + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + with urllib.request.urlopen(file_url) as response: + shutil.copyfileobj(response, temp_file, length=self.chunk_size) + + temp_file.flush() + # Upload the local file to bucket + gcs_hook.upload( + bucket_name=self.bucket_name, + object_name=report_name, + gzip=self.gzip, + filename=temp_file.name, + mime_type="text/csv", + ) + self.log.info( + "Report %s was saved in bucket %s as %s.", + self.report_id, + self.bucket_name, + report_name, + ) + self.xcom_push(context, key="report_name", value=report_name) + + 
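# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the committed file): one way
# the DV360 report operators above could be chained with the report sensor
# added later in this commit -- create the query, run it, wait for it to
# finish, then stage the resulting CSV in GCS. The DAG id, task ids, REPORT
# body and BUCKET value are placeholder assumptions; the import paths follow
# the provider package layout these modules reference.
# ---------------------------------------------------------------------------
from datetime import datetime

from airflow import DAG
from airflow.providers.google.marketing_platform.operators.display_video import (
    GoogleDisplayVideo360CreateReportOperator,
    GoogleDisplayVideo360DownloadReportOperator,
    GoogleDisplayVideo360RunReportOperator,
)
from airflow.providers.google.marketing_platform.sensors.display_video import (
    GoogleDisplayVideo360ReportSensor,
)

REPORT = {"metadata": {"title": "example_report"}}  # placeholder query body
BUCKET = "gs://example-reports-bucket"              # placeholder bucket name

with DAG(
    dag_id="example_display_video_report",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    create_report = GoogleDisplayVideo360CreateReportOperator(
        task_id="create_report", body=REPORT
    )
    # The created query id is pushed to XCom under the "report_id" key.
    report_id = "{{ task_instance.xcom_pull('create_report', key='report_id') }}"

    run_report = GoogleDisplayVideo360RunReportOperator(
        task_id="run_report", report_id=report_id, params={}
    )
    wait_for_report = GoogleDisplayVideo360ReportSensor(
        task_id="wait_for_report", report_id=report_id
    )
    download_report = GoogleDisplayVideo360DownloadReportOperator(
        task_id="download_report", report_id=report_id, bucket_name=BUCKET
    )

    create_report >> run_report >> wait_for_report >> download_report
# --------------------------- end editor's sketch ----------------------------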
+class GoogleDisplayVideo360RunReportOperator(BaseOperator): + """ + Runs a stored query to generate a report. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDisplayVideo360RunReportOperator` + + .. seealso:: + Check also the official API docs: + `https://developers.google.com/bid-manager/v1/queries/runquery` + + :param report_id: Report ID to run. + :type report_id: str + :param params: Parameters for running a report as described here: + https://developers.google.com/bid-manager/v1/queries/runquery + :type params: Dict[str, Any] + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "report_id", + "params", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + report_id: str, + params: Dict[str, Any], + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.report_id = report_id + self.params = params + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + hook = GoogleDisplayVideo360Hook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info( + "Running report %s with the following params:\n %s", + self.report_id, + self.params, + ) + hook.run_query(query_id=self.report_id, params=self.params) + + +class GoogleDisplayVideo360DownloadLineItemsOperator(BaseOperator): + """ + Retrieves line items in CSV format. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDisplayVideo360DownloadLineItemsOperator` + + .. seealso:: + Check also the official API docs: + `https://developers.google.com/bid-manager/v1.1/lineitems/downloadlineitems` + + :param request_body: dictionary with parameters that should be passed into. 
+ More information about it can be found here: + https://developers.google.com/bid-manager/v1.1/lineitems/downloadlineitems + :type request_body: Dict[str, Any] + """ + + template_fields = ( + "request_body", + "bucket_name", + "object_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + request_body: Dict[str, Any], + bucket_name: str, + object_name: str, + gzip: bool = False, + api_version: str = "v1.1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.request_body = request_body + self.object_name = object_name + self.bucket_name = bucket_name + self.gzip = gzip + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> str: + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + hook = GoogleDisplayVideo360Hook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Retrieving report...") + content: List[str] = hook.download_line_items(request_body=self.request_body) + + with tempfile.NamedTemporaryFile("w+") as temp_file: + writer = csv.writer(temp_file) + writer.writerows(content) + temp_file.flush() + gcs_hook.upload( + bucket_name=self.bucket_name, + object_name=self.object_name, + filename=temp_file.name, + mime_type="text/csv", + gzip=self.gzip, + ) + return f"{self.bucket_name}/{self.object_name}" + + +class GoogleDisplayVideo360UploadLineItemsOperator(BaseOperator): + """ + Uploads line items in CSV format. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDisplayVideo360UploadLineItemsOperator` + + .. seealso:: + Check also the official API docs: + `https://developers.google.com/bid-manager/v1.1/lineitems/uploadlineitems` + + :param request_body: request to upload line items. + :type request_body: Dict[str, Any] + :param bucket_name: The bucket from which the line item data is downloaded. + :type bucket_name: str + :param object_name: The object to fetch. + :type object_name: str + :param filename: The filename to fetch. + :type filename: str + :param dry_run: Return the upload status without actually persisting the line items.
+ :type dry_run: bool + """ + + template_fields = ( + "bucket_name", + "object_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + bucket_name: str, + object_name: str, + api_version: str = "v1.1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.bucket_name = bucket_name + self.object_name = object_name + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> None: + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + hook = GoogleDisplayVideo360Hook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Uploading file %s...", self.object_name) + # Save the file in a temporary directory; the file downloaded + # from GCS could be 1 GB in size or even more + with tempfile.NamedTemporaryFile("w+") as f: + line_items = gcs_hook.download( + bucket_name=self.bucket_name, + object_name=self.object_name, + filename=f.name, + ) + f.flush() + hook.upload_line_items(line_items=line_items) + + +class GoogleDisplayVideo360CreateSDFDownloadTaskOperator(BaseOperator): + """ + Creates an SDF download task. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDisplayVideo360CreateSDFDownloadTaskOperator` + + .. seealso:: + Check also the official API docs: + `https://developers.google.com/display-video/api/reference/rest` + + :param version: The SDF version of the downloaded file. + :type version: str + :param partner_id: The ID of the partner to download SDF for. + :type partner_id: str + :param advertiser_id: The ID of the advertiser to download SDF for. + :type advertiser_id: str + :param parent_entity_filter: Filters on selected file types. + :type parent_entity_filter: Dict[str, Any] + :param id_filter: Filters on entities by their entity IDs. + :type id_filter: Dict[str, Any] + :param inventory_source_filter: Filters on Inventory Sources by their IDs. + :type inventory_source_filter: Dict[str, Any] + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated).
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "body_request", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + body_request: Dict[str, Any], + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.body_request = body_request + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> Dict[str, Any]: + hook = GoogleDisplayVideo360Hook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Creating operation for SDF download task...") + operation = hook.create_sdf_download_operation(body_request=self.body_request) + + return operation + + +class GoogleDisplayVideo360SDFtoGCSOperator(BaseOperator): + """ + Downloads SDF media and saves it in Google Cloud Storage. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDisplayVideo360SDFtoGCSOperator` + + .. seealso:: + Check also the official API docs: + `https://developers.google.com/display-video/api/reference/rest` + + :param version: The SDF version of the downloaded file. + :type version: str + :param partner_id: The ID of the partner to download SDF for. + :type partner_id: str + :param advertiser_id: The ID of the advertiser to download SDF for. + :type advertiser_id: str + :param parent_entity_filter: Filters on selected file types. + :type parent_entity_filter: Dict[str, Any] + :param id_filter: Filters on entities by their entity IDs. + :type id_filter: Dict[str, Any] + :param inventory_source_filter: Filters on Inventory Sources by their IDs. + :type inventory_source_filter: Dict[str, Any] + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated).
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "operation_name", + "bucket_name", + "object_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + operation_name: str, + bucket_name: str, + object_name: str, + gzip: bool = False, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.operation_name = operation_name + self.bucket_name = bucket_name + self.object_name = object_name + self.gzip = gzip + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: dict) -> str: + hook = GoogleDisplayVideo360Hook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Retrieving operation...") + operation = hook.get_sdf_download_operation(operation_name=self.operation_name) + + self.log.info("Creating file for upload...") + media = hook.download_media(resource_name=operation) + + self.log.info("Sending file to the Google Cloud Storage...") + with tempfile.NamedTemporaryFile() as temp_file: + hook.download_content_from_request(temp_file, media, chunk_size=1024 * 1024) + temp_file.flush() + gcs_hook.upload( + bucket_name=self.bucket_name, + object_name=self.object_name, + filename=temp_file.name, + gzip=self.gzip, + ) + + return f"{self.bucket_name}/{self.object_name}" diff --git a/reference/providers/google/marketing_platform/operators/search_ads.py b/reference/providers/google/marketing_platform/operators/search_ads.py new file mode 100644 index 0000000..8dadd31 --- /dev/null +++ b/reference/providers/google/marketing_platform/operators/search_ads.py @@ -0,0 +1,252 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Search Ads operators.""" +import json +from tempfile import NamedTemporaryFile +from typing import Any, Dict, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.marketing_platform.hooks.search_ads import ( + GoogleSearchAdsHook, +) +from airflow.utils.decorators import apply_defaults + + +class GoogleSearchAdsInsertReportOperator(BaseOperator): + """ + Inserts a report request into the reporting system. + + .. 
seealso:: + For API documentation check: + https://developers.google.com/search-ads/v2/reference/reports/request + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleSearchAdsInsertReportOperator` + + :param report: Report to be generated. + :type report: Dict[str, Any] + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "report", + "impersonation_chain", + ) + template_ext = (".json",) + + @apply_defaults + def __init__( + self, + *, + report: Dict[str, Any], + api_version: str = "v2", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.report = report + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def prepare_template(self) -> None: + # If .json is passed then we have to read the file + if isinstance(self.report, str) and self.report.endswith(".json"): + with open(self.report) as file: + self.report = json.load(file) + + def execute(self, context: dict): + hook = GoogleSearchAdsHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Generating Search Ads report") + response = hook.insert_report(report=self.report) + report_id = response.get("id") + self.xcom_push(context, key="report_id", value=report_id) + self.log.info("Report generated, id: %s", report_id) + return response + + +class GoogleSearchAdsDownloadReportOperator(BaseOperator): + """ + Downloads a report to a GCS bucket. + + .. seealso:: + For API documentation check: + https://developers.google.com/search-ads/v2/reference/reports/getFile + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleSearchAdsGetfileReportOperator` + + :param report_id: ID of the report. + :type report_id: str + :param bucket_name: The bucket to upload to. + :type bucket_name: str + :param report_name: The report name to set when uploading the local file. If not provided then + report_id is used. + :type report_name: str + :param gzip: Option to compress local file or file data for upload + :type gzip: bool + :param api_version: The version of the api that will be requested for example 'v3'.
+ :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "report_name", + "report_id", + "bucket_name", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + report_id: str, + bucket_name: str, + report_name: Optional[str] = None, + gzip: bool = True, + chunk_size: int = 10 * 1024 * 1024, + api_version: str = "v2", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.report_id = report_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.report_id = report_id + self.chunk_size = chunk_size + self.gzip = gzip + self.bucket_name = self._set_bucket_name(bucket_name) + self.report_name = report_name + self.impersonation_chain = impersonation_chain + + def _resolve_file_name(self, name: str) -> str: + csv = ".csv" + gzip = ".gz" + if not name.endswith(csv): + name += csv + if self.gzip: + name += gzip + return name + + @staticmethod + def _set_bucket_name(name: str) -> str: + bucket = name if not name.startswith("gs://") else name[5:] + return bucket.strip("/") + + @staticmethod + def _handle_report_fragment(fragment: bytes) -> bytes: + fragment_records = fragment.split(b"\n", 1) + if len(fragment_records) > 1: + return fragment_records[1] + return b"" + + def execute(self, context: dict): + hook = GoogleSearchAdsHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + # Resolve file name of the report + report_name = self.report_name or self.report_id + report_name = self._resolve_file_name(report_name) + + response = hook.get(report_id=self.report_id) + if not response["isReportReady"]: + raise AirflowException(f"Report {self.report_id} is not ready yet") + + # Resolve report fragments + fragments_count = len(response["files"]) + + # Download chunks of report's data + self.log.info("Downloading Search Ads report %s", self.report_id) + with NamedTemporaryFile() as temp_file: + for i in range(fragments_count): + byte_content = hook.get_file( + report_fragment=i, report_id=self.report_id + ) + fragment = ( + byte_content + if i == 0 + else self._handle_report_fragment(byte_content) + ) + temp_file.write(fragment) + + temp_file.flush() + + gcs_hook.upload( + 
bucket_name=self.bucket_name, + object_name=report_name, + gzip=self.gzip, + filename=temp_file.name, + ) + self.xcom_push(context, key="file_name", value=report_name) diff --git a/reference/providers/google/marketing_platform/sensors/__init__.py b/reference/providers/google/marketing_platform/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/marketing_platform/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/marketing_platform/sensors/campaign_manager.py b/reference/providers/google/marketing_platform/sensors/campaign_manager.py new file mode 100644 index 0000000..77a4503 --- /dev/null +++ b/reference/providers/google/marketing_platform/sensors/campaign_manager.py @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Campaign Manager sensor.""" +from typing import Dict, Optional, Sequence, Union + +from airflow.providers.google.marketing_platform.hooks.campaign_manager import ( + GoogleCampaignManagerHook, +) +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class GoogleCampaignManagerReportSensor(BaseSensorOperator): + """ + Check if report is ready. + + .. seealso:: + Check official API docs: + https://developers.google.com/doubleclick-advertisers/v3.3/reports/get + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleCampaignManagerReportSensor` + + :param profile_id: The DFA user profile ID. + :type profile_id: str + :param report_id: The ID of the report. + :type report_id: str + :param file_id: The ID of the report file. + :type file_id: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. 
+ :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "profile_id", + "report_id", + "file_id", + "impersonation_chain", + ) + + def poke(self, context: Dict) -> bool: + hook = GoogleCampaignManagerHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + response = hook.get_report( + profile_id=self.profile_id, report_id=self.report_id, file_id=self.file_id + ) + self.log.info("Report status: %s", response["status"]) + return response["status"] != "PROCESSING" + + @apply_defaults + def __init__( + self, + *, + profile_id: str, + report_id: str, + file_id: str, + api_version: str = "v3.3", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + mode: str = "reschedule", + poke_interval: int = 60 * 5, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.mode = mode + self.poke_interval = poke_interval + self.profile_id = profile_id + self.report_id = report_id + self.file_id = file_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain diff --git a/reference/providers/google/marketing_platform/sensors/display_video.py b/reference/providers/google/marketing_platform/sensors/display_video.py new file mode 100644 index 0000000..d473087 --- /dev/null +++ b/reference/providers/google/marketing_platform/sensors/display_video.py @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Sensor for detecting the completion of DV360 reports.""" +from typing import Optional, Sequence, Union + +from airflow import AirflowException +from airflow.providers.google.marketing_platform.hooks.display_video import ( + GoogleDisplayVideo360Hook, +) +from airflow.sensors.base import BaseSensorOperator + + +class GoogleDisplayVideo360ReportSensor(BaseSensorOperator): + """ + Sensor for detecting the completion of DV360 reports. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDisplayVideo360ReportSensor` + + :param report_id: Report ID to delete. + :type report_id: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "report_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + report_id: str, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.report_id = report_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def poke(self, context: dict) -> bool: + hook = GoogleDisplayVideo360Hook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + + response = hook.get_query(query_id=self.report_id) + if response and not response.get("metadata", {}).get("running"): + return True + return False + + +class GoogleDisplayVideo360GetSDFDownloadOperationSensor(BaseSensorOperator): + """ + Sensor for detecting the completion of SDF operation. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleDisplayVideo360GetSDFDownloadOperationSensor` + + :param operation_name: The name of the operation resource + :type operation_name: Dict[str, Any] + :param api_version: The version of the api that will be requested for example 'v1'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. 
+ :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + + """ + + template_fields = ( + "operation_name", + "impersonation_chain", + ) + + def __init__( + self, + operation_name: str, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + mode: str = "reschedule", + poke_interval: int = 60 * 5, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.mode = mode + self.poke_interval = poke_interval + self.operation_name = operation_name + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def poke(self, context: dict) -> bool: + hook = GoogleDisplayVideo360Hook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + operation = hook.get_sdf_download_operation(operation_name=self.operation_name) + + if "error" in operation: + raise AirflowException( + f'The operation finished in error with {operation["error"]}' + ) + if operation and operation.get("done"): + return True + return False diff --git a/reference/providers/google/marketing_platform/sensors/search_ads.py b/reference/providers/google/marketing_platform/sensors/search_ads.py new file mode 100644 index 0000000..8330b1e --- /dev/null +++ b/reference/providers/google/marketing_platform/sensors/search_ads.py @@ -0,0 +1,95 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Search Ads sensor.""" +from typing import Optional, Sequence, Union + +from airflow.providers.google.marketing_platform.hooks.search_ads import ( + GoogleSearchAdsHook, +) +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class GoogleSearchAdsReportSensor(BaseSensorOperator): + """ + Polls for the status of a report request. + + .. seealso:: + For API documentation check: + https://developers.google.com/search-ads/v2/reference/reports/get + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleSearchAdsReportSensor` + + :param report_id: ID of the report request being polled. + :type report_id: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "report_id", + "impersonation_chain", + ) + + @apply_defaults + def __init__( + self, + *, + report_id: str, + api_version: str = "v2", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + mode: str = "reschedule", + poke_interval: int = 5 * 60, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(mode=mode, poke_interval=poke_interval, **kwargs) + self.report_id = report_id + self.api_version = api_version + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def poke(self, context: dict): + hook = GoogleSearchAdsHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Checking status of %s report.", self.report_id) + response = hook.get(report_id=self.report_id) + return response["isReportReady"] diff --git a/reference/providers/google/provider.yaml b/reference/providers/google/provider.yaml new file mode 100644 index 0000000..4917083 --- /dev/null +++ b/reference/providers/google/provider.yaml @@ -0,0 +1,739 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +package-name: apache-airflow-providers-google +name: Google +description: | + Google services including: + + - `Google Ads `__ + - `Google Cloud (GCP) `__ + - `Google Firebase `__ + - `Google LevelDB `__ + - `Google Marketing Platform `__ + - `Google Workspace `__ (formerly Google Suite) + +versions: + - 2.1.0 + - 2.0.0 + - 1.0.0 + +integrations: + - integration-name: Google Analytics360 + external-doc-url: https://analytics.google.com/ + logo: /integration-logos/gcp/Google-Analytics.png + how-to-guide: + - /docs/apache-airflow-providers-google/operators/marketing_platform/analytics.rst + tags: [gmp] + - integration-name: Google Ads + external-doc-url: https://ads.google.com/ + logo: /integration-logos/gcp/Google-Ads.png + how-to-guide: + - /docs/apache-airflow-providers-google/operators/ads.rst + tags: [gmp] + - integration-name: Google AutoML + external-doc-url: https://cloud.google.com/automl/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/automl.rst + logo: /integration-logos/gcp/Cloud-AutoML.png + tags: [gcp] + - integration-name: Google BigQuery Data Transfer Service + external-doc-url: https://cloud.google.com/bigquery/transfer/ + logo: /integration-logos/gcp/BigQuery.png + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/bigquery_dts.rst + tags: [gcp] + - integration-name: Google BigQuery + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/bigquery.rst + external-doc-url: https://cloud.google.com/bigquery/ + logo: /integration-logos/gcp/BigQuery.png + tags: [gcp] + - integration-name: Google Bigtable + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/bigtable.rst + external-doc-url: https://cloud.google.com/bigtable/ + logo: /integration-logos/gcp/Cloud-Bigtable.png + tags: [gcp] + - integration-name: Google Cloud Build + external-doc-url: https://cloud.google.com/cloud-build/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/cloud_build.rst + logo: /integration-logos/gcp/Cloud-Build.png + tags: [gcp] + - integration-name: Google Cloud Data Loss Prevention (DLP) + external-doc-url: https://cloud.google.com/dlp/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/data_loss_prevention.rst + logo: /integration-logos/gcp/google-data-loss-prevention.png + tags: [gcp] + - integration-name: Google Cloud Firestore + external-doc-url: https://firebase.google.com/docs/firestore + how-to-guide: + - /docs/apache-airflow-providers-google/operators/firebase/firestore.rst + logo: /integration-logos/gcp/Google-Firestore.png + tags: [gcp] + - integration-name: Google Cloud Functions + external-doc-url: https://cloud.google.com/functions/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/functions.rst + logo: /integration-logos/gcp/Cloud-Functions.png + tags: [gcp] + - integration-name: Google Cloud Key Management Service (KMS) + external-doc-url: https://cloud.google.com/kms/ + logo: /integration-logos/gcp/Key-Management-Service.png + tags: [gcp] + - integration-name: Google Cloud Life Sciences + external-doc-url: https://cloud.google.com/life-sciences/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/life_sciences.rst + logo: /integration-logos/gcp/Google-Cloud-Life-Sciences.png + tags: [gcp] + - integration-name: Google Cloud Memorystore + external-doc-url: https://cloud.google.com/memorystore/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/cloud_memorystore.rst + - 
/docs/apache-airflow-providers-google/operators/cloud/cloud_memorystore_memcached.rst + logo: /integration-logos/gcp/Cloud-Memorystore.png + tags: [gcp] + - integration-name: Google Cloud OS Login + external-doc-url: https://cloud.google.com/compute/docs/oslogin/ + logo: /integration-logos/gcp/Google-Cloud-Generic.png + tags: [gcp] + - integration-name: Google Cloud Pub/Sub + external-doc-url: https://cloud.google.com/pubsub/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/pubsub.rst + logo: /integration-logos/gcp/Cloud-PubSub.png + tags: [gcp] + - integration-name: Google Cloud Secret Manager + external-doc-url: https://cloud.google.com/secret-manager/ + logo: /integration-logos/gcp/Google-Cloud-Secret-Manager.png + tags: [gcp] + - integration-name: Google Cloud Spanner + external-doc-url: https://cloud.google.com/spanner/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/spanner.rst + logo: /integration-logos/gcp/Cloud-Spanner.png + tags: [gcp] + - integration-name: Google Cloud Speech-to-Text + external-doc-url: https://cloud.google.com/speech-to-text/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/speech_to_text.rst + - /docs/apache-airflow-providers-google/operators/cloud/translate_speech.rst + logo: /integration-logos/gcp/Cloud-Speech-to-Text.png + tags: [gcp] + - integration-name: Google Cloud SQL + external-doc-url: https://cloud.google.com/sql/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst + logo: /integration-logos/gcp/Cloud-SQL.png + tags: [gcp] + - integration-name: Google Cloud Stackdriver + external-doc-url: https://cloud.google.com/stackdriver + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/stackdriver.rst + logo: /integration-logos/gcp/Google-Cloud-Stackdriver.png + tags: [gcp] + - integration-name: Google Cloud Storage (GCS) + external-doc-url: https://cloud.google.com/gcs/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/gcs.rst + logo: /integration-logos/gcp/Cloud-Storage.png + tags: [gcp] + - integration-name: Google Cloud Tasks + external-doc-url: https://cloud.google.com/tasks/ + logo: /integration-logos/gcp/Cloud-Tasks.png + tags: [gcp] + - integration-name: Google Cloud Text-to-Speech + external-doc-url: https://cloud.google.com/text-to-speech/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/text_to_speech.rst + logo: /integration-logos/gcp/Cloud-Text-to-Speech.png + tags: [gcp] + - integration-name: Google Cloud Translation + external-doc-url: https://cloud.google.com/translate/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/translate.rst + - /docs/apache-airflow-providers-google/operators/cloud/translate_speech.rst + logo: /integration-logos/gcp/Cloud-Translation-API.png + tags: [gcp] + - integration-name: Google Cloud Video Intelligence + external-doc-url: https://cloud.google.com/video_intelligence/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/video_intelligence.rst + logo: /integration-logos/gcp/Cloud-Video-Intelligence-API.png + tags: [gcp] + - integration-name: Google Cloud Vision + external-doc-url: https://cloud.google.com/vision/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/vision.rst + logo: /integration-logos/gcp/Cloud-Vision-API.png + tags: [gcp] + - integration-name: Google Compute Engine + external-doc-url: https://cloud.google.com/compute/ + how-to-guide: + - 
/docs/apache-airflow-providers-google/operators/cloud/compute.rst + - /docs/apache-airflow-providers-google/operators/cloud/compute_ssh.rst + logo: /integration-logos/gcp/Compute-Engine.png + tags: [gcp] + - integration-name: Google Data Proc + external-doc-url: https://cloud.yandex.com/services/data-proc + logo: /integration-logos/gcp/Google-Data-Proc.png + tags: [gcp] + - integration-name: Google Data Catalog + external-doc-url: https://cloud.google.com/data-catalog + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/datacatalog.rst + logo: /integration-logos/gcp/Google-Data-Catalog.png + tags: [gcp] + - integration-name: Google Dataflow + external-doc-url: https://cloud.google.com/dataflow/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/dataflow.rst + logo: /integration-logos/gcp/Cloud-Dataflow.png + tags: [gcp] + - integration-name: Google Data Fusion + external-doc-url: https://cloud.google.com/data-fusion/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/datafusion.rst + logo: /integration-logos/gcp/Google-Data-Fusion.png + tags: [gcp] + - integration-name: Google Dataprep + external-doc-url: https://cloud.google.com/dataprep/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/dataprep.rst + logo: /integration-logos/gcp/Google-Dataprep.png + tags: [gcp] + - integration-name: Google Dataproc + external-doc-url: https://cloud.google.com/dataproc/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/dataproc.rst + logo: /integration-logos/gcp/Cloud-Dataproc.png + tags: [gcp] + - integration-name: Google Datastore + external-doc-url: https://cloud.google.com/datastore/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/datastore.rst + logo: /integration-logos/gcp/Cloud-Datastore.png + tags: [gcp] + - integration-name: Google Deployment Manager + external-doc-url: https://cloud.google.com/deployment-manager/ + logo: /integration-logos/gcp/Google-Deployment-Manager.png + tags: [gcp] + - integration-name: Google API Python Client + external-doc-url: https://github.com/googleapis/google-api-python-client + logo: /integration-logos/gcp/Google-API-Python-Client.png + tags: [google] + - integration-name: Google Campaign Manager + external-doc-url: https://developers.google.com/doubleclick-advertisers + how-to-guide: + - /docs/apache-airflow-providers-google/operators/marketing_platform/campaign_manager.rst + logo: /integration-logos/gcp/Google-Campaign-Manager.png + tags: [gcp] + - integration-name: Google Cloud + external-doc-url: https://cloud.google.com/ + logo: /integration-logos/gcp/Google-Cloud.png + tags: [gcp] + - integration-name: Google Discovery API + external-doc-url: https://developers.google.com/discovery + logo: /integration-logos/gcp/Google-Cloud-Generic.png + tags: [google] + - integration-name: Google Display&Video 360 + external-doc-url: https://marketingplatform.google.com/about/display-video-360/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/marketing_platform/display_video.rst + logo: /integration-logos/gcp/Google-Display-And-Video-360.png + tags: [gmp] + - integration-name: Google Drive + external-doc-url: https://www.google.com/drive/ + logo: /integration-logos/gcp/Google-Drive.png + tags: [google] + - integration-name: Google Search Ads 360 + external-doc-url: https://marketingplatform.google.com/about/search-ads-360/ + how-to-guide: + - 
/docs/apache-airflow-providers-google/operators/marketing_platform/search_ads.rst + logo: /integration-logos/gcp/Google-Search-Ads360.png + tags: [gmp] + - integration-name: Google + external-doc-url: https://developer.google.com/ + logo: /integration-logos/gcp/Google.png + tags: [google] + - integration-name: Google Spreadsheet + external-doc-url: https://www.google.com/intl/en/sheets/about/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/suite/sheets.rst + logo: /integration-logos/gcp/Google-Spreadsheet.png + tags: [google] + - integration-name: Google Cloud Storage Transfer Service + external-doc-url: https://cloud.google.com/storage/transfer/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/cloud_storage_transfer_service.rst + logo: /integration-logos/gcp/Cloud-Storage.png + tags: [gcp] + - integration-name: Google Kubernetes Engine + external-doc-url: https://cloud.google.com/kubernetes_engine/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst + logo: /integration-logos/gcp/Kubernetes-Engine.png + tags: [gcp] + - integration-name: Google Machine Learning Engine + external-doc-url: https://cloud.google.com/ai-platform/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/mlengine.rst + logo: /integration-logos/gcp/AI-Platform.png + tags: [gcp] + - integration-name: Google Cloud Natural Language + external-doc-url: https://cloud.google.com/natural-language/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/natural_language.rst + logo: /integration-logos/gcp/Cloud-NLP.png + tags: [gcp] + - integration-name: Google Cloud Workflows + external-doc-url: https://cloud.google.com/workflows/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/workflows.rst + tags: [gcp] + - integration-name: Google LevelDB + external-doc-url: https://github.com/google/leveldb/blob/master/doc/index.md + how-to-guide: + - /docs/apache-airflow-providers-google/operators/leveldb/leveldb.rst + tags: [google] + +operators: + - integration-name: Google Ads + python-modules: + - airflow.providers.google.ads.operators.ads + - integration-name: Google AutoML + python-modules: + - airflow.providers.google.cloud.operators.automl + - integration-name: Google BigQuery + python-modules: + - airflow.providers.google.cloud.operators.bigquery + - integration-name: Google BigQuery Data Transfer Service + python-modules: + - airflow.providers.google.cloud.operators.bigquery_dts + - integration-name: Google Bigtable + python-modules: + - airflow.providers.google.cloud.operators.bigtable + - integration-name: Google Cloud Build + python-modules: + - airflow.providers.google.cloud.operators.cloud_build + - integration-name: Google Cloud Memorystore + python-modules: + - airflow.providers.google.cloud.operators.cloud_memorystore + - integration-name: Google Cloud SQL + python-modules: + - airflow.providers.google.cloud.operators.cloud_sql + - integration-name: Google Cloud Storage Transfer Service + python-modules: + - airflow.providers.google.cloud.operators.cloud_storage_transfer_service + - integration-name: Google Compute Engine + python-modules: + - airflow.providers.google.cloud.operators.compute + - integration-name: Google Data Catalog + python-modules: + - airflow.providers.google.cloud.operators.datacatalog + - integration-name: Google Dataflow + python-modules: + - airflow.providers.google.cloud.operators.dataflow + - integration-name: Google Data Fusion + 
python-modules: + - airflow.providers.google.cloud.operators.datafusion + - integration-name: Google Dataprep + python-modules: + - airflow.providers.google.cloud.operators.dataprep + - integration-name: Google Dataproc + python-modules: + - airflow.providers.google.cloud.operators.dataproc + - integration-name: Google Datastore + python-modules: + - airflow.providers.google.cloud.operators.datastore + - integration-name: Google Cloud Data Loss Prevention (DLP) + python-modules: + - airflow.providers.google.cloud.operators.dlp + - integration-name: Google Cloud Functions + python-modules: + - airflow.providers.google.cloud.operators.functions + - integration-name: Google Cloud Storage (GCS) + python-modules: + - airflow.providers.google.cloud.operators.gcs + - integration-name: Google Kubernetes Engine + python-modules: + - airflow.providers.google.cloud.operators.kubernetes_engine + - integration-name: Google Cloud Life Sciences + python-modules: + - airflow.providers.google.cloud.operators.life_sciences + - integration-name: Google Machine Learning Engine + python-modules: + - airflow.providers.google.cloud.operators.mlengine + - integration-name: Google Cloud Natural Language + python-modules: + - airflow.providers.google.cloud.operators.natural_language + - integration-name: Google Cloud Pub/Sub + python-modules: + - airflow.providers.google.cloud.operators.pubsub + - integration-name: Google Cloud Spanner + python-modules: + - airflow.providers.google.cloud.operators.spanner + - integration-name: Google Cloud Speech-to-Text + python-modules: + - airflow.providers.google.cloud.operators.speech_to_text + - integration-name: Google Cloud Stackdriver + python-modules: + - airflow.providers.google.cloud.operators.stackdriver + - integration-name: Google Cloud Tasks + python-modules: + - airflow.providers.google.cloud.operators.tasks + - integration-name: Google Cloud Text-to-Speech + python-modules: + - airflow.providers.google.cloud.operators.text_to_speech + - airflow.providers.google.cloud.operators.translate_speech + - integration-name: Google Cloud Translation + python-modules: + - airflow.providers.google.cloud.operators.translate + - airflow.providers.google.cloud.operators.translate_speech + - integration-name: Google Cloud Video Intelligence + python-modules: + - airflow.providers.google.cloud.operators.video_intelligence + - integration-name: Google Cloud Vision + python-modules: + - airflow.providers.google.cloud.operators.vision + - integration-name: Google Cloud Workflows + python-modules: + - airflow.providers.google.cloud.operators.workflows + - integration-name: Google Cloud Firestore + python-modules: + - airflow.providers.google.firebase.operators.firestore + - integration-name: Google Analytics360 + python-modules: + - airflow.providers.google.marketing_platform.operators.analytics + - integration-name: Google Campaign Manager + python-modules: + - airflow.providers.google.marketing_platform.operators.campaign_manager + - integration-name: Google Display&Video 360 + python-modules: + - airflow.providers.google.marketing_platform.operators.display_video + - integration-name: Google Search Ads 360 + python-modules: + - airflow.providers.google.marketing_platform.operators.search_ads + - integration-name: Google Spreadsheet + python-modules: + - airflow.providers.google.suite.operators.sheets + - integration-name: Google LevelDB + python-modules: + - airflow.providers.google.leveldb.operators.leveldb + +sensors: + - integration-name: Google BigQuery + python-modules: + - 
airflow.providers.google.cloud.sensors.bigquery + - integration-name: Google BigQuery Data Transfer Service + python-modules: + - airflow.providers.google.cloud.sensors.bigquery_dts + - integration-name: Google Bigtable + python-modules: + - airflow.providers.google.cloud.sensors.bigtable + - integration-name: Google Cloud Storage Transfer Service + python-modules: + - airflow.providers.google.cloud.sensors.cloud_storage_transfer_service + - integration-name: Google Dataflow + python-modules: + - airflow.providers.google.cloud.sensors.dataflow + - integration-name: Google Dataproc + python-modules: + - airflow.providers.google.cloud.sensors.dataproc + - integration-name: Google Cloud Storage (GCS) + python-modules: + - airflow.providers.google.cloud.sensors.gcs + - integration-name: Google Cloud Pub/Sub + python-modules: + - airflow.providers.google.cloud.sensors.pubsub + - integration-name: Google Cloud Workflows + python-modules: + - airflow.providers.google.cloud.sensors.workflows + - integration-name: Google Drive + python-modules: + - airflow.providers.google.suite.sensors.drive + - integration-name: Google Campaign Manager + python-modules: + - airflow.providers.google.marketing_platform.sensors.campaign_manager + - integration-name: Google Display&Video 360 + python-modules: + - airflow.providers.google.marketing_platform.sensors.display_video + - integration-name: Google Search Ads 360 + python-modules: + - airflow.providers.google.marketing_platform.sensors.search_ads + +hooks: + - integration-name: Google Ads + python-modules: + - airflow.providers.google.ads.hooks.ads + - integration-name: Google AutoML + python-modules: + - airflow.providers.google.cloud.hooks.automl + - integration-name: Google BigQuery + python-modules: + - airflow.providers.google.cloud.hooks.bigquery + - integration-name: Google BigQuery Data Transfer Service + python-modules: + - airflow.providers.google.cloud.hooks.bigquery_dts + - integration-name: Google Bigtable + python-modules: + - airflow.providers.google.cloud.hooks.bigtable + - integration-name: Google Cloud Build + python-modules: + - airflow.providers.google.cloud.hooks.cloud_build + - integration-name: Google Cloud Memorystore + python-modules: + - airflow.providers.google.cloud.hooks.cloud_memorystore + - integration-name: Google Cloud SQL + python-modules: + - airflow.providers.google.cloud.hooks.cloud_sql + - integration-name: Google Cloud Storage Transfer Service + python-modules: + - airflow.providers.google.cloud.hooks.cloud_storage_transfer_service + - integration-name: Google Compute Engine + python-modules: + - airflow.providers.google.cloud.hooks.compute + - airflow.providers.google.cloud.hooks.compute_ssh + - integration-name: Google Data Catalog + python-modules: + - airflow.providers.google.cloud.hooks.datacatalog + - integration-name: Google Dataflow + python-modules: + - airflow.providers.google.cloud.hooks.dataflow + - integration-name: Google Data Fusion + python-modules: + - airflow.providers.google.cloud.hooks.datafusion + - integration-name: Google Dataprep + python-modules: + - airflow.providers.google.cloud.hooks.dataprep + - integration-name: Google Dataproc + python-modules: + - airflow.providers.google.cloud.hooks.dataproc + - integration-name: Google Datastore + python-modules: + - airflow.providers.google.cloud.hooks.datastore + - integration-name: Google Cloud Data Loss Prevention (DLP) + python-modules: + - airflow.providers.google.cloud.hooks.dlp + - integration-name: Google Cloud Functions + python-modules: + - 
airflow.providers.google.cloud.hooks.functions + - integration-name: Google Cloud Storage (GCS) + python-modules: + - airflow.providers.google.cloud.hooks.gcs + - integration-name: Google Deployment Manager + python-modules: + - airflow.providers.google.cloud.hooks.gdm + - integration-name: Google Cloud Key Management Service (KMS) + python-modules: + - airflow.providers.google.cloud.hooks.kms + - integration-name: Google Kubernetes Engine + python-modules: + - airflow.providers.google.cloud.hooks.kubernetes_engine + - integration-name: Google Cloud Life Sciences + python-modules: + - airflow.providers.google.cloud.hooks.life_sciences + - integration-name: Google Machine Learning Engine + python-modules: + - airflow.providers.google.cloud.hooks.mlengine + - integration-name: Google Cloud Natural Language + python-modules: + - airflow.providers.google.cloud.hooks.natural_language + - integration-name: Google Cloud OS Login + python-modules: + - airflow.providers.google.cloud.hooks.os_login + - integration-name: Google Cloud Pub/Sub + python-modules: + - airflow.providers.google.cloud.hooks.pubsub + - integration-name: Google Cloud Secret Manager + python-modules: + - airflow.providers.google.cloud.hooks.secret_manager + - integration-name: Google Cloud Spanner + python-modules: + - airflow.providers.google.cloud.hooks.spanner + - integration-name: Google Cloud Speech-to-Text + python-modules: + - airflow.providers.google.cloud.hooks.speech_to_text + - integration-name: Google Cloud Stackdriver + python-modules: + - airflow.providers.google.cloud.hooks.stackdriver + - integration-name: Google Cloud Tasks + python-modules: + - airflow.providers.google.cloud.hooks.tasks + - integration-name: Google Cloud Text-to-Speech + python-modules: + - airflow.providers.google.cloud.hooks.text_to_speech + - integration-name: Google Cloud Translation + python-modules: + - airflow.providers.google.cloud.hooks.translate + - integration-name: Google Cloud Video Intelligence + python-modules: + - airflow.providers.google.cloud.hooks.video_intelligence + - integration-name: Google Cloud Vision + python-modules: + - airflow.providers.google.cloud.hooks.vision + - integration-name: Google Cloud Workflows + python-modules: + - airflow.providers.google.cloud.hooks.workflows + - integration-name: Google + python-modules: + - airflow.providers.google.common.hooks.base_google + - integration-name: Google Discovery API + python-modules: + - airflow.providers.google.common.hooks.discovery_api + - integration-name: Google Cloud Firestore + python-modules: + - airflow.providers.google.firebase.hooks.firestore + - integration-name: Google Analytics360 + python-modules: + - airflow.providers.google.marketing_platform.hooks.analytics + - integration-name: Google Campaign Manager + python-modules: + - airflow.providers.google.marketing_platform.hooks.campaign_manager + - integration-name: Google Display&Video 360 + python-modules: + - airflow.providers.google.marketing_platform.hooks.display_video + - integration-name: Google Search Ads 360 + python-modules: + - airflow.providers.google.marketing_platform.hooks.search_ads + - integration-name: Google Drive + python-modules: + - airflow.providers.google.suite.hooks.drive + - integration-name: Google Spreadsheet + python-modules: + - airflow.providers.google.suite.hooks.sheets + - integration-name: Google LevelDB + python-modules: + - airflow.providers.google.leveldb.hooks.leveldb + +transfers: + - source-integration-name: Presto + target-integration-name: Google Cloud Storage 
(GCS) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/presto_to_gcs.rst + python-module: airflow.providers.google.cloud.transfers.presto_to_gcs + - source-integration-name: SQL + target-integration-name: Google Cloud Storage (GCS) + python-module: airflow.providers.google.cloud.transfers.sql_to_gcs + - source-integration-name: Google Cloud Storage (GCS) + target-integration-name: Google Drive + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/gcs_to_gdrive.rst + python-module: airflow.providers.google.suite.transfers.gcs_to_gdrive + - source-integration-name: Google Drive + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/gdrive_to_gcs.rst + python-module: airflow.providers.google.cloud.transfers.gdrive_to_gcs + - source-integration-name: Microsoft SQL Server (MSSQL) + target-integration-name: Google Cloud Storage (GCS) + python-module: airflow.providers.google.cloud.transfers.mssql_to_gcs + - source-integration-name: Microsoft Azure FileShare + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/azure_fileshare_to_gcs.rst + python-module: airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs + - source-integration-name: Apache Cassandra + target-integration-name: Google Cloud Storage (GCS) + python-module: airflow.providers.google.cloud.transfers.cassandra_to_gcs + - source-integration-name: Google Spreadsheet + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/sheets_to_gcs.rst + python-module: airflow.providers.google.cloud.transfers.sheets_to_gcs + - source-integration-name: Amazon Simple Storage Service (S3) + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/s3_to_gcs.rst + python-module: airflow.providers.google.cloud.transfers.s3_to_gcs + - source-integration-name: Google Cloud Storage (GCS) + target-integration-name: SSH File Transfer Protocol (SFTP) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/gcs_to_sftp.rst + python-module: airflow.providers.google.cloud.transfers.gcs_to_sftp + - source-integration-name: PostgreSQL + target-integration-name: Google Cloud Storage (GCS) + python-module: airflow.providers.google.cloud.transfers.postgres_to_gcs + - source-integration-name: Google BigQuery + target-integration-name: MySQL + python-module: airflow.providers.google.cloud.transfers.bigquery_to_mysql + - source-integration-name: Google Cloud Storage (GCS) + target-integration-name: Google BigQuery + python-module: airflow.providers.google.cloud.transfers.gcs_to_bigquery + - source-integration-name: Google Cloud Storage (GCS) + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/gcs_to_gcs.rst + python-module: airflow.providers.google.cloud.transfers.gcs_to_gcs + - source-integration-name: Facebook Ads + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/facebook_ads_to_gcs.rst + python-module: airflow.providers.google.cloud.transfers.facebook_ads_to_gcs + - source-integration-name: SSH File Transfer Protocol (SFTP) + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/sftp_to_gcs.rst + 
python-module: airflow.providers.google.cloud.transfers.sftp_to_gcs + - source-integration-name: Microsoft Azure Data Lake Storage + target-integration-name: Google Cloud Storage (GCS) + python-module: airflow.providers.google.cloud.transfers.adls_to_gcs + - source-integration-name: Google BigQuery + target-integration-name: Google BigQuery + python-module: airflow.providers.google.cloud.transfers.bigquery_to_bigquery + - source-integration-name: MySQL + target-integration-name: Google Cloud Storage (GCS) + python-module: airflow.providers.google.cloud.transfers.mysql_to_gcs + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/mysql_to_gcs.rst + - source-integration-name: Oracle + target-integration-name: Google Cloud Storage (GCS) + python-module: airflow.providers.google.cloud.transfers.oracle_to_gcs + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/oracle_to_gcs.rst + - source-integration-name: Google Cloud Storage (GCS) + target-integration-name: Google Spreadsheet + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/gcs_to_sheets.rst + python-module: airflow.providers.google.suite.transfers.gcs_to_sheets + - source-integration-name: Local + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/local_to_gcs.rst + python-module: airflow.providers.google.cloud.transfers.local_to_gcs + - source-integration-name: Google BigQuery + target-integration-name: Google Cloud Storage (GCS) + python-module: airflow.providers.google.cloud.transfers.bigquery_to_gcs + - source-integration-name: Google Cloud Storage (GCS) + target-integration-name: Local + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/gcs_to_local.rst + python-module: airflow.providers.google.cloud.transfers.gcs_to_local + - source-integration-name: Google Drive + target-integration-name: Local + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/gdrive_to_local.rst + python-module: airflow.providers.google.cloud.transfers.gdrive_to_local + - source-integration-name: Salesforce + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/salesforce_to_gcs.rst + python-module: airflow.providers.google.cloud.transfers.salesforce_to_gcs + - source-integration-name: Google Ads + target-integration-name: Google Cloud Storage (GCS) + python-module: airflow.providers.google.ads.transfers.ads_to_gcs + +hook-class-names: + - airflow.providers.google.common.hooks.base_google.GoogleBaseHook + - airflow.providers.google.cloud.hooks.dataprep.GoogleDataprepHook + - airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook + - airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook + - airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook + - airflow.providers.google.cloud.hooks.bigquery.BigQueryHook + - airflow.providers.google.common.hooks.leveldb.LevelDBHook + +extra-links: + - airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink + - airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink + - airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink diff --git a/reference/providers/google/suite/__init__.py b/reference/providers/google/suite/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/google/suite/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation 
(ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/suite/example_dags/__init__.py b/reference/providers/google/suite/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/google/suite/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/suite/example_dags/example_gcs_to_gdrive.py b/reference/providers/google/suite/example_dags/example_gcs_to_gdrive.py new file mode 100644 index 0000000..845a254 --- /dev/null +++ b/reference/providers/google/suite/example_dags/example_gcs_to_gdrive.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example DAG using GoogleCloudStorageToGoogleDriveOperator. 
+""" +import os + +from airflow import models +from airflow.providers.google.suite.transfers.gcs_to_gdrive import ( + GCSToGoogleDriveOperator, +) +from airflow.utils.dates import days_ago + +GCS_TO_GDRIVE_BUCKET = os.environ.get("GCS_TO_DRIVE_BUCKET", "example-object") + +with models.DAG( + "example_gcs_to_gdrive", + schedule_interval=None, # Override to match your needs, + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START howto_operator_gcs_to_gdrive_copy_single_file] + copy_single_file = GCSToGoogleDriveOperator( + task_id="copy_single_file", + source_bucket=GCS_TO_GDRIVE_BUCKET, + source_object="sales/january.avro", + destination_object="copied_sales/january-backup.avro", + ) + # [END howto_operator_gcs_to_gdrive_copy_single_file] + # [START howto_operator_gcs_to_gdrive_copy_files] + copy_files = GCSToGoogleDriveOperator( + task_id="copy_files", + source_bucket=GCS_TO_GDRIVE_BUCKET, + source_object="sales/*", + destination_object="copied_sales/", + ) + # [END howto_operator_gcs_to_gdrive_copy_files] + # [START howto_operator_gcs_to_gdrive_move_files] + move_files = GCSToGoogleDriveOperator( + task_id="move_files", + source_bucket=GCS_TO_GDRIVE_BUCKET, + source_object="sales/*.avro", + move_object=True, + ) + # [END howto_operator_gcs_to_gdrive_move_files] diff --git a/reference/providers/google/suite/example_dags/example_gcs_to_sheets.py b/reference/providers/google/suite/example_dags/example_gcs_to_sheets.py new file mode 100644 index 0000000..3ca4c30 --- /dev/null +++ b/reference/providers/google/suite/example_dags/example_gcs_to_sheets.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os + +from airflow import models +from airflow.providers.google.cloud.transfers.sheets_to_gcs import ( + GoogleSheetsToGCSOperator, +) +from airflow.providers.google.suite.transfers.gcs_to_sheets import ( + GCSToGoogleSheetsOperator, +) +from airflow.utils.dates import days_ago + +BUCKET = os.environ.get("GCP_GCS_BUCKET", "example-test-bucket3") +SPREADSHEET_ID = os.environ.get("SPREADSHEET_ID", "example-spreadsheetID") +NEW_SPREADSHEET_ID = os.environ.get("NEW_SPREADSHEET_ID", "1234567890qwerty") + +with models.DAG( + "example_gcs_to_sheets", + start_date=days_ago(1), + schedule_interval=None, # Override to match your needs + tags=["example"], +) as dag: + + upload_sheet_to_gcs = GoogleSheetsToGCSOperator( + task_id="upload_sheet_to_gcs", + destination_bucket=BUCKET, + spreadsheet_id=SPREADSHEET_ID, + ) + + # [START upload_gcs_to_sheets] + upload_gcs_to_sheet = GCSToGoogleSheetsOperator( + task_id="upload_gcs_to_sheet", + bucket_name=BUCKET, + object_name="{{ task_instance.xcom_pull('upload_sheet_to_gcs')[0] }}", + spreadsheet_id=NEW_SPREADSHEET_ID, + ) + # [END upload_gcs_to_sheets] + + upload_sheet_to_gcs >> upload_gcs_to_sheet diff --git a/reference/providers/google/suite/example_dags/example_sheets.py b/reference/providers/google/suite/example_dags/example_sheets.py new file mode 100644 index 0000000..9ba6245 --- /dev/null +++ b/reference/providers/google/suite/example_dags/example_sheets.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
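The templated object_name in the example_gcs_to_sheets DAG above works because GoogleSheetsToGCSOperator returns the list of uploaded GCS object names, which Airflow stores as that task's XCom. Any downstream task can reuse the same value; a hedged follow-on sketch (GCSToLocalFilesystemOperator is the gcs_to_local transfer listed in provider.yaml; the task id and local path are assumptions):

from airflow.providers.google.cloud.transfers.gcs_to_local import (
    GCSToLocalFilesystemOperator,
)

# Pulls the same XCom value as upload_gcs_to_sheet and copies that export to the worker's disk.
download_first_export = GCSToLocalFilesystemOperator(
    task_id="download_first_export",
    bucket=BUCKET,
    object_name="{{ task_instance.xcom_pull('upload_sheet_to_gcs')[0] }}",
    filename="/tmp/sheet_export.csv",  # assumed local path
)
upload_sheet_to_gcs >> download_first_export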
+ +import os + +from airflow import models +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.transfers.sheets_to_gcs import ( + GoogleSheetsToGCSOperator, +) +from airflow.providers.google.suite.operators.sheets import ( + GoogleSheetsCreateSpreadsheetOperator, +) +from airflow.providers.google.suite.transfers.gcs_to_sheets import ( + GCSToGoogleSheetsOperator, +) +from airflow.utils.dates import days_ago + +GCS_BUCKET = os.environ.get("SHEETS_GCS_BUCKET", "test28397ye") +SPREADSHEET_ID = os.environ.get("SPREADSHEET_ID", "1234567890qwerty") +NEW_SPREADSHEET_ID = os.environ.get("NEW_SPREADSHEET_ID", "1234567890qwerty") + +SPREADSHEET = { + "properties": {"title": "Test1"}, + "sheets": [{"properties": {"title": "Sheet1"}}], +} + +with models.DAG( + "example_sheets_gcs", + schedule_interval=None, # Override to match your needs, + start_date=days_ago(1), + tags=["example"], +) as dag: + # [START upload_sheet_to_gcs] + upload_sheet_to_gcs = GoogleSheetsToGCSOperator( + task_id="upload_sheet_to_gcs", + destination_bucket=GCS_BUCKET, + spreadsheet_id=SPREADSHEET_ID, + ) + # [END upload_sheet_to_gcs] + + # [START create_spreadsheet] + create_spreadsheet = GoogleSheetsCreateSpreadsheetOperator( + task_id="create_spreadsheet", spreadsheet=SPREADSHEET + ) + # [END create_spreadsheet] + + # [START print_spreadsheet_url] + print_spreadsheet_url = BashOperator( + task_id="print_spreadsheet_url", + bash_command="echo {{ task_instance.xcom_pull('create_spreadsheet', key='spreadsheet_url') }}", + ) + # [END print_spreadsheet_url] + + # [START upload_gcs_to_sheet] + upload_gcs_to_sheet = GCSToGoogleSheetsOperator( + task_id="upload_gcs_to_sheet", + bucket_name=GCS_BUCKET, + object_name="{{ task_instance.xcom_pull('upload_sheet_to_gcs')[0] }}", + spreadsheet_id=NEW_SPREADSHEET_ID, + ) + # [END upload_gcs_to_sheet] + + create_spreadsheet >> print_spreadsheet_url + upload_sheet_to_gcs >> upload_gcs_to_sheet diff --git a/reference/providers/google/suite/hooks/__init__.py b/reference/providers/google/suite/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/google/suite/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/suite/hooks/drive.py b/reference/providers/google/suite/hooks/drive.py new file mode 100644 index 0000000..1a43727 --- /dev/null +++ b/reference/providers/google/suite/hooks/drive.py @@ -0,0 +1,260 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Google Drive service""" +from io import TextIOWrapper +from typing import Any, Optional, Sequence, Union + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import Resource, build +from googleapiclient.http import HttpRequest, MediaFileUpload + + +class GoogleDriveHook(GoogleBaseHook): + """ + Hook for the Google Drive APIs. + + :param api_version: API version used (for example v3). + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + _conn = None # type: Optional[Resource] + + def __init__( + self, + api_version: str = "v3", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.api_version = api_version + + def get_conn(self) -> Any: + """ + Retrieves the connection to Google Drive. + + :return: Google Drive services object. 
+ """ + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + "drive", self.api_version, http=http_authorized, cache_discovery=False + ) + return self._conn + + def _ensure_folders_exists(self, path: str) -> str: + service = self.get_conn() + current_parent = "root" + folders = path.split("/") + depth = 0 + # First tries to enter directories + for current_folder in folders: + self.log.debug( + "Looking for %s directory with %s parent", + current_folder, + current_parent, + ) + conditions = [ + "mimeType = 'application/vnd.google-apps.folder'", + f"name='{current_folder}'", + f"'{current_parent}' in parents", + ] + result = ( + service.files() # pylint: disable=no-member + .list( + q=" and ".join(conditions), spaces="drive", fields="files(id, name)" + ) + .execute(num_retries=self.num_retries) + ) + files = result.get("files", []) + if not files: + self.log.info("Not found %s directory", current_folder) + # If the directory does not exist, break loops + break + depth += 1 + current_parent = files[0].get("id") + + # Check if there are directories to process + if depth != len(folders): + # Create missing directories + for current_folder in folders[depth:]: + file_metadata = { + "name": current_folder, + "mimeType": "application/vnd.google-apps.folder", + "parents": [current_parent], + } + file = ( + service.files() # pylint: disable=no-member + .create(body=file_metadata, fields="id") + .execute(num_retries=self.num_retries) + ) + self.log.info("Created %s directory", current_folder) + + current_parent = file.get("id") + # Return the ID of the last directory + return current_parent + + def get_media_request(self, file_id: str) -> HttpRequest: + """ + Returns a get_media http request to a Google Drive object. + + :param file_id: The Google Drive file id + :type file_id: str + :return: request + :rtype: HttpRequest + """ + service = self.get_conn() + request = service.files().get_media(fileId=file_id) # pylint: disable=no-member + return request + + def exists(self, folder_id: str, file_name: str, drive_id: Optional[str] = None): + """ + Checks to see if a file exists within a Google Drive folder + + :param folder_id: The id of the Google Drive folder in which the file resides + :type folder_id: str + :param file_name: The name of a file in Google Drive + :type file_name: str + :param drive_id: Optional. The id of the shared Google Drive in which the file resides. + :type drive_id: str + :return: True if the file exists, False otherwise + :rtype: bool + """ + return bool( + self.get_file_id( + folder_id=folder_id, file_name=file_name, drive_id=drive_id + ) + ) + + def get_file_id( + self, folder_id: str, file_name: str, drive_id: Optional[str] = None + ): + """ + Returns the file id of a Google Drive file + + :param folder_id: The id of the Google Drive folder in which the file resides + :type folder_id: str + :param file_name: The name of a file in Google Drive + :type file_name: str + :param drive_id: Optional. The id of the shared Google Drive in which the file resides. 
+ :type drive_id: str + :return: Google Drive file id if the file exists, otherwise None + :rtype: str if file exists else None + """ + query = f"name = '{file_name}'" + if folder_id: + query += f" and parents in '{folder_id}'" + service = self.get_conn() + if drive_id: + files = ( + service.files() # pylint: disable=no-member + .list( + q=query, + spaces="drive", + fields="files(id, mimeType)", + orderBy="modifiedTime desc", + driveId=drive_id, + includeItemsFromAllDrives=True, + supportsAllDrives=True, + corpora="drive", + ) + .execute(num_retries=self.num_retries) + ) + else: + files = ( + service.files() # pylint: disable=no-member + .list( + q=query, + spaces="drive", + fields="files(id, mimeType)", + orderBy="modifiedTime desc", + ) + .execute(num_retries=self.num_retries) + ) + file_metadata = {} + if files["files"]: + file_metadata = { + "id": files["files"][0]["id"], + "mime_type": files["files"][0]["mimeType"], + } + return file_metadata + + def upload_file(self, local_location: str, remote_location: str) -> str: + """ + Uploads a file that is available locally to a Google Drive service. + + :param local_location: The path where the file is available. + :type local_location: str + :param remote_location: The path where the file will be send + :type remote_location: str + :return: File ID + :rtype: str + """ + service = self.get_conn() + directory_path, _, file_name = remote_location.rpartition("/") + if directory_path: + parent = self._ensure_folders_exists(directory_path) + else: + parent = "root" + + file_metadata = {"name": file_name, "parents": [parent]} + media = MediaFileUpload(local_location) + file = ( + service.files() # pylint: disable=no-member + .create(body=file_metadata, media_body=media, fields="id") + .execute(num_retries=self.num_retries) + ) + self.log.info( + "File %s uploaded to gdrive://%s.", local_location, remote_location + ) + return file.get("id") + + def download_file( + self, file_id: str, file_handle: TextIOWrapper, chunk_size: int = 104857600 + ): + """ + Download a file from Google Drive. + + :param file_id: the id of the file + :type file_id: str + :param file_handle: file handle used to write the content to + :type file_handle: io.TextIOWrapper + """ + request = self.get_media_request(file_id=file_id) + self.download_content_from_request( + file_handle=file_handle, request=request, chunk_size=chunk_size + ) diff --git a/reference/providers/google/suite/hooks/sheets.py b/reference/providers/google/suite/hooks/sheets.py new file mode 100644 index 0000000..cd281a6 --- /dev/null +++ b/reference/providers/google/suite/hooks/sheets.py @@ -0,0 +1,477 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
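A minimal sketch of calling the GoogleDriveHook from drive.py above directly, for example inside a PythonOperator callable (the local path, Drive destination, and connection id are assumptions):

from airflow.providers.google.suite.hooks.drive import GoogleDriveHook

def upload_report_to_drive() -> str:
    hook = GoogleDriveHook(gcp_conn_id="google_cloud_default")
    # upload_file() calls _ensure_folders_exists() for the "reports/2022" prefix,
    # creating any missing Drive folders before performing the media upload.
    file_id = hook.upload_file(
        local_location="/tmp/report.csv",           # assumed local file
        remote_location="reports/2022/report.csv",  # assumed Drive destination
    )
    return file_id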
+# +"""This module contains a Google Sheets API hook""" + +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from googleapiclient.discovery import build + + +class GSheetsHook(GoogleBaseHook): + """ + Interact with Google Sheets via Google Cloud connection + Reading and writing cells in Google Sheet: + https://developers.google.com/sheets/api/guides/values + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param api_version: API Version + :type api_version: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v4", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + delegate_to=delegate_to, + impersonation_chain=impersonation_chain, + ) + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self.delegate_to = delegate_to + self._conn = None + + def get_conn(self) -> Any: + """ + Retrieves connection to Google Sheets. + + :return: Google Sheets services object. + :rtype: Any + """ + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + "sheets", self.api_version, http=http_authorized, cache_discovery=False + ) + + return self._conn + + def get_values( + self, + spreadsheet_id: str, + range_: str, + major_dimension: str = "DIMENSION_UNSPECIFIED", + value_render_option: str = "FORMATTED_VALUE", + date_time_render_option: str = "SERIAL_NUMBER", + ) -> list: + """ + Gets values from Google Sheet from a single range + https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets.values/get + + :param spreadsheet_id: The Google Sheet ID to interact with + :type spreadsheet_id: str + :param range_: The A1 notation of the values to retrieve. + :type range_: str + :param major_dimension: Indicates which dimension an operation should apply to. + DIMENSION_UNSPECIFIED, ROWS, or COLUMNS + :type major_dimension: str + :param value_render_option: Determines how values should be rendered in the output. + FORMATTED_VALUE, UNFORMATTED_VALUE, or FORMULA + :type value_render_option: str + :param date_time_render_option: Determines how dates should be rendered in the output. + SERIAL_NUMBER or FORMATTED_STRING + :type date_time_render_option: str + :return: An array of sheet values from the specified sheet. 
+ :rtype: List + """ + service = self.get_conn() + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .get( + spreadsheetId=spreadsheet_id, + range=range_, + majorDimension=major_dimension, + valueRenderOption=value_render_option, + dateTimeRenderOption=date_time_render_option, + ) + .execute(num_retries=self.num_retries) + ) + + return response["values"] + + def batch_get_values( + self, + spreadsheet_id: str, + ranges: List, + major_dimension: str = "DIMENSION_UNSPECIFIED", + value_render_option: str = "FORMATTED_VALUE", + date_time_render_option: str = "SERIAL_NUMBER", + ) -> dict: + """ + Gets values from Google Sheet from a list of ranges + https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets.values/batchGet + + :param spreadsheet_id: The Google Sheet ID to interact with + :type spreadsheet_id: str + :param ranges: The A1 notation of the values to retrieve. + :type ranges: List + :param major_dimension: Indicates which dimension an operation should apply to. + DIMENSION_UNSPECIFIED, ROWS, or COLUMNS + :type major_dimension: str + :param value_render_option: Determines how values should be rendered in the output. + FORMATTED_VALUE, UNFORMATTED_VALUE, or FORMULA + :type value_render_option: str + :param date_time_render_option: Determines how dates should be rendered in the output. + SERIAL_NUMBER or FORMATTED_STRING + :type date_time_render_option: str + :return: Google Sheets API response. + :rtype: Dict + """ + service = self.get_conn() + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .batchGet( + spreadsheetId=spreadsheet_id, + ranges=ranges, + majorDimension=major_dimension, + valueRenderOption=value_render_option, + dateTimeRenderOption=date_time_render_option, + ) + .execute(num_retries=self.num_retries) + ) + + return response + + def update_values( + self, + spreadsheet_id: str, + range_: str, + values: List, + major_dimension: str = "ROWS", + value_input_option: str = "RAW", + include_values_in_response: bool = False, + value_render_option: str = "FORMATTED_VALUE", + date_time_render_option: str = "SERIAL_NUMBER", + ) -> dict: + """ + Updates values from Google Sheet from a single range + https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets.values/update + + :param spreadsheet_id: The Google Sheet ID to interact with. + :type spreadsheet_id: str + :param range_: The A1 notation of the values to retrieve. + :type range_: str + :param values: Data within a range of the spreadsheet. + :type values: List + :param major_dimension: Indicates which dimension an operation should apply to. + DIMENSION_UNSPECIFIED, ROWS, or COLUMNS + :type major_dimension: str + :param value_input_option: Determines how input data should be interpreted. + RAW or USER_ENTERED + :type value_input_option: str + :param include_values_in_response: Determines if the update response should + include the values of the cells that were updated. + :type include_values_in_response: bool + :param value_render_option: Determines how values should be rendered in the output. + FORMATTED_VALUE, UNFORMATTED_VALUE, or FORMULA + :type value_render_option: str + :param date_time_render_option: Determines how dates should be rendered in the output. + SERIAL_NUMBER or FORMATTED_STRING + :type date_time_render_option: str + :return: Google Sheets API response. 
+ :rtype: Dict + """ + service = self.get_conn() + body = {"range": range_, "majorDimension": major_dimension, "values": values} + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .update( + spreadsheetId=spreadsheet_id, + range=range_, + valueInputOption=value_input_option, + includeValuesInResponse=include_values_in_response, + responseValueRenderOption=value_render_option, + responseDateTimeRenderOption=date_time_render_option, + body=body, + ) + .execute(num_retries=self.num_retries) + ) + + return response + + def batch_update_values( + self, + spreadsheet_id: str, + ranges: List, + values: List, + major_dimension: str = "ROWS", + value_input_option: str = "RAW", + include_values_in_response: bool = False, + value_render_option: str = "FORMATTED_VALUE", + date_time_render_option: str = "SERIAL_NUMBER", + ) -> dict: + """ + Updates values from Google Sheet for multiple ranges + https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets.values/batchUpdate + + :param spreadsheet_id: The Google Sheet ID to interact with + :type spreadsheet_id: str + :param ranges: The A1 notation of the values to retrieve. + :type ranges: List + :param values: Data within a range of the spreadsheet. + :type values: List + :param major_dimension: Indicates which dimension an operation should apply to. + DIMENSION_UNSPECIFIED, ROWS, or COLUMNS + :type major_dimension: str + :param value_input_option: Determines how input data should be interpreted. + RAW or USER_ENTERED + :type value_input_option: str + :param include_values_in_response: Determines if the update response should + include the values of the cells that were updated. + :type include_values_in_response: bool + :param value_render_option: Determines how values should be rendered in the output. + FORMATTED_VALUE, UNFORMATTED_VALUE, or FORMULA + :type value_render_option: str + :param date_time_render_option: Determines how dates should be rendered in the output. + SERIAL_NUMBER or FORMATTED_STRING + :type date_time_render_option: str + :return: Google Sheets API response. + :rtype: Dict + """ + if len(ranges) != len(values): + raise AirflowException( + "'Ranges' and 'Lists' must be of equal length. 
\n \ + 'Ranges' is of length: {} and \n \ + 'Values' is of length: {}.".format( + str(len(ranges)), str(len(values)) + ) + ) + service = self.get_conn() + data = [] + for idx, range_ in enumerate(ranges): + value_range = { + "range": range_, + "majorDimension": major_dimension, + "values": values[idx], + } + data.append(value_range) + body = { + "valueInputOption": value_input_option, + "data": data, + "includeValuesInResponse": include_values_in_response, + "responseValueRenderOption": value_render_option, + "responseDateTimeRenderOption": date_time_render_option, + } + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .batchUpdate(spreadsheetId=spreadsheet_id, body=body) + .execute(num_retries=self.num_retries) + ) + + return response + + def append_values( + self, + spreadsheet_id: str, + range_: str, + values: List, + major_dimension: str = "ROWS", + value_input_option: str = "RAW", + insert_data_option: str = "OVERWRITE", + include_values_in_response: bool = False, + value_render_option: str = "FORMATTED_VALUE", + date_time_render_option: str = "SERIAL_NUMBER", + ) -> dict: + """ + Append values from Google Sheet from a single range + https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets.values/append + + :param spreadsheet_id: The Google Sheet ID to interact with + :type spreadsheet_id: str + :param range_: The A1 notation of the values to retrieve. + :type range_: str + :param values: Data within a range of the spreadsheet. + :type values: List + :param major_dimension: Indicates which dimension an operation should apply to. + DIMENSION_UNSPECIFIED, ROWS, or COLUMNS + :type major_dimension: str + :param value_input_option: Determines how input data should be interpreted. + RAW or USER_ENTERED + :type value_input_option: str + :param insert_data_option: Determines how existing data is changed when new data is input. + OVERWRITE or INSERT_ROWS + :type insert_data_option: str + :param include_values_in_response: Determines if the update response should + include the values of the cells that were updated. + :type include_values_in_response: bool + :param value_render_option: Determines how values should be rendered in the output. + FORMATTED_VALUE, UNFORMATTED_VALUE, or FORMULA + :type value_render_option: str + :param date_time_render_option: Determines how dates should be rendered in the output. + SERIAL_NUMBER or FORMATTED_STRING + :type date_time_render_option: str + :return: Google Sheets API response. + :rtype: Dict + """ + service = self.get_conn() + body = {"range": range_, "majorDimension": major_dimension, "values": values} + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .append( + spreadsheetId=spreadsheet_id, + range=range_, + valueInputOption=value_input_option, + insertDataOption=insert_data_option, + includeValuesInResponse=include_values_in_response, + responseValueRenderOption=value_render_option, + responseDateTimeRenderOption=date_time_render_option, + body=body, + ) + .execute(num_retries=self.num_retries) + ) + + return response + + def clear(self, spreadsheet_id: str, range_: str) -> dict: + """ + Clear values from Google Sheet from a single range + https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets.values/clear + + :param spreadsheet_id: The Google Sheet ID to interact with + :type spreadsheet_id: str + :param range_: The A1 notation of the values to retrieve. + :type range_: str + :return: Google Sheets API response. 
+ :rtype: Dict + """ + service = self.get_conn() + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .clear(spreadsheetId=spreadsheet_id, range=range_) + .execute(num_retries=self.num_retries) + ) + + return response + + def batch_clear(self, spreadsheet_id: str, ranges: list) -> dict: + """ + Clear values from Google Sheet from a list of ranges + https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets.values/batchClear + + :param spreadsheet_id: The Google Sheet ID to interact with + :type spreadsheet_id: str + :param ranges: The A1 notation of the values to retrieve. + :type ranges: List + :return: Google Sheets API response. + :rtype: Dict + """ + service = self.get_conn() + body = {"ranges": ranges} + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .batchClear(spreadsheetId=spreadsheet_id, body=body) + .execute(num_retries=self.num_retries) + ) + + return response + + def get_spreadsheet(self, spreadsheet_id: str): + """ + Retrieves spreadsheet matching the given id. + + :param spreadsheet_id: The spreadsheet id. + :type spreadsheet_id: str + :return: An spreadsheet that matches the sheet filter. + """ + response = ( + self.get_conn() # pylint: disable=no-member + .spreadsheets() + .get(spreadsheetId=spreadsheet_id) + .execute(num_retries=self.num_retries) + ) + return response + + def get_sheet_titles( + self, spreadsheet_id: str, sheet_filter: Optional[List[str]] = None + ): + """ + Retrieves the sheet titles from a spreadsheet matching the given id and sheet filter. + + :param spreadsheet_id: The spreadsheet id. + :type spreadsheet_id: str + :param sheet_filter: List of sheet title to retrieve from sheet. + :type sheet_filter: List[str] + :return: An list of sheet titles from the specified sheet that match + the sheet filter. + """ + response = self.get_spreadsheet(spreadsheet_id=spreadsheet_id) + + if sheet_filter: + titles = [ + sh["properties"]["title"] + for sh in response["sheets"] + if sh["properties"]["title"] in sheet_filter + ] + else: + titles = [sh["properties"]["title"] for sh in response["sheets"]] + return titles + + def create_spreadsheet(self, spreadsheet: Dict[str, Any]) -> Dict[str, Any]: + """ + Creates a spreadsheet, returning the newly created spreadsheet. + + :param spreadsheet: an instance of Spreadsheet + https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets#Spreadsheet + :type spreadsheet: Dict[str, Any] + :return: An spreadsheet object. + """ + self.log.info("Creating spreadsheet: %s", spreadsheet["properties"]["title"]) + # pylint: disable=no-member + response = ( + self.get_conn() + .spreadsheets() + .create(body=spreadsheet) + .execute(num_retries=self.num_retries) + ) + self.log.info("Spreadsheet: %s created", spreadsheet["properties"]["title"]) + return response diff --git a/reference/providers/google/suite/operators/__init__.py b/reference/providers/google/suite/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/google/suite/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/suite/operators/sheets.py b/reference/providers/google/suite/operators/sheets.py new file mode 100644 index 0000000..007ff43 --- /dev/null +++ b/reference/providers/google/suite/operators/sheets.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.suite.hooks.sheets import GSheetsHook +from airflow.utils.decorators import apply_defaults + + +class GoogleSheetsCreateSpreadsheetOperator(BaseOperator): + """ + Creates a new spreadsheet. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleSheetsCreateSpreadsheetOperator` + + :param spreadsheet: an instance of Spreadsheet + https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets#Spreadsheet + :type spreadsheet: Dict[str, Any] + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "spreadsheet", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + spreadsheet: Dict[str, Any], + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.gcp_conn_id = gcp_conn_id + self.spreadsheet = spreadsheet + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Any) -> Dict[str, Any]: + hook = GSheetsHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + spreadsheet = hook.create_spreadsheet(spreadsheet=self.spreadsheet) + self.xcom_push(context, "spreadsheet_id", spreadsheet["spreadsheetId"]) + self.xcom_push(context, "spreadsheet_url", spreadsheet["spreadsheetUrl"]) + return spreadsheet diff --git a/reference/providers/google/suite/sensors/__init__.py b/reference/providers/google/suite/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/suite/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/google/suite/sensors/drive.py b/reference/providers/google/suite/sensors/drive.py new file mode 100644 index 0000000..c0fc6b6 --- /dev/null +++ b/reference/providers/google/suite/sensors/drive.py @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Drive sensors.""" + +from typing import Optional, Sequence, Union + +from airflow.providers.google.suite.hooks.drive import GoogleDriveHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class GoogleDriveFileExistenceSensor(BaseSensorOperator): + """ + Checks for the existence of a file in Google Cloud Storage. 
+ + :param folder_id: The Google drive folder where the file is. + :type folder_id: str + :param file_name: The name of the file to check in Google Drive + :type file_name: str + :param drive_id: Optional. The id of the shared Google Drive in which the file resides. + :type drive_id: str + :param gcp_conn_id: The connection ID to use when + connecting to Google Cloud Storage. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "folder_id", + "file_name", + "drive_id", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, + folder_id: str, + file_name: str, + drive_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.folder_id = folder_id + self.file_name = file_name + self.drive_id = drive_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def poke(self, context: dict) -> bool: + self.log.info( + "Sensor is checking for the file %s in the folder %s", + self.file_name, + self.folder_id, + ) + hook = GoogleDriveHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + return hook.exists( + folder_id=self.folder_id, file_name=self.file_name, drive_id=self.drive_id + ) diff --git a/reference/providers/google/suite/transfers/__init__.py b/reference/providers/google/suite/transfers/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/google/suite/transfers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
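The Sheets hook above is a thin wrapper over spreadsheets().values(); a minimal usage sketch follows, assuming the upstream airflow.providers.google.suite import path (rather than the vendored reference/ copy), a configured "google_cloud_default" connection, and placeholder spreadsheet id and A1 ranges.

# Hedged sketch: reading a range with GSheetsHook.get_values and appending a summary row.
# The spreadsheet id and ranges are illustrative placeholders, not values from this repository.
from airflow.providers.google.suite.hooks.sheets import GSheetsHook


def summarize_sheet(spreadsheet_id: str) -> int:
    hook = GSheetsHook(gcp_conn_id="google_cloud_default")
    # get_values returns response["values"]: a list of rows for the requested range.
    rows = hook.get_values(spreadsheet_id=spreadsheet_id, range_="Sheet1!A1:C100")
    # Append one summary row below the existing data in Sheet1.
    hook.append_values(
        spreadsheet_id=spreadsheet_id,
        range_="Sheet1!A1",
        values=[["row_count", len(rows)]],
        value_input_option="USER_ENTERED",
    )
    return len(rows)

Note that get_values hands back response["values"] directly, so a range with no data would surface as a KeyError rather than an empty list.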
diff --git a/reference/providers/google/suite/transfers/gcs_to_gdrive.py b/reference/providers/google/suite/transfers/gcs_to_gdrive.py new file mode 100644 index 0000000..cb12269 --- /dev/null +++ b/reference/providers/google/suite/transfers/gcs_to_gdrive.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Storage to Google Drive transfer operator.""" +import tempfile +from typing import Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.suite.hooks.drive import GoogleDriveHook +from airflow.utils.decorators import apply_defaults + +WILDCARD = "*" + + +class GCSToGoogleDriveOperator(BaseOperator): + """ + Copies objects from a Google Cloud Storage service to a Google Drive service, with renaming + if requested. + + Using this operator requires the following OAuth 2.0 scope: + + .. code-block:: none + + https://www.googleapis.com/auth/drive + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSToGoogleDriveOperator` + + :param source_bucket: The source Google Cloud Storage bucket where the object is. (templated) + :type source_bucket: str + :param source_object: The source name of the object to copy in the Google cloud + storage bucket. (templated) + You can use only one wildcard for objects (filenames) within your bucket. The wildcard can appear + inside the object name or at the end of the object name. Appending a wildcard to the bucket name + is unsupported. + :type source_object: str + :param destination_object: The destination name of the object in the destination Google Drive + service. (templated) + If a wildcard is supplied in the source_object argument, this is the prefix that will be prepended + to the final destination objects' paths. + Note that the source path's part before the wildcard will be removed; + if it needs to be retained it should be appended to destination_object. + For example, with prefix ``foo/*`` and destination_object ``blah/``, the file ``foo/baz`` will be + copied to ``blah/baz``; to retain the prefix write the destination_object as e.g. ``blah/foo``, in + which case the copied file will be named ``blah/foo/baz``. + :type destination_object: str + :param move_object: When move object is True, the object is moved instead of copied to the new location. + This is the equivalent of a mv command as opposed to a cp command. + :type move_object: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. 
+ :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = ( + "source_bucket", + "source_object", + "destination_object", + "impersonation_chain", + ) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + *, + source_bucket: str, + source_object: str, + destination_object: Optional[str] = None, + move_object: bool = False, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.source_bucket = source_bucket + self.source_object = source_object + self.destination_object = destination_object + self.move_object = move_object + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.gcs_hook = None # type: Optional[GCSHook] + self.gdrive_hook = None # type: Optional[GoogleDriveHook] + + def execute(self, context): + + self.gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + self.gdrive_hook = GoogleDriveHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + if WILDCARD in self.source_object: + total_wildcards = self.source_object.count(WILDCARD) + if total_wildcards > 1: + error_msg = ( + "Only one wildcard '*' is allowed in source_object parameter. 
" + "Found {} in {}.".format(total_wildcards, self.source_object) + ) + + raise AirflowException(error_msg) + + prefix, delimiter = self.source_object.split(WILDCARD, 1) + objects = self.gcs_hook.list( + self.source_bucket, prefix=prefix, delimiter=delimiter + ) + + for source_object in objects: + if self.destination_object is None: + destination_object = source_object + else: + destination_object = source_object.replace( + prefix, self.destination_object, 1 + ) + + self._copy_single_object( + source_object=source_object, destination_object=destination_object + ) + else: + self._copy_single_object( + source_object=self.source_object, + destination_object=self.destination_object, + ) + + def _copy_single_object(self, source_object, destination_object): + self.log.info( + "Executing copy of gs://%s/%s to gdrive://%s", + self.source_bucket, + source_object, + destination_object, + ) + + with tempfile.NamedTemporaryFile() as file: + filename = file.name + self.gcs_hook.download( + bucket_name=self.source_bucket, + object_name=source_object, + filename=filename, + ) + self.gdrive_hook.upload_file( + local_location=filename, remote_location=destination_object + ) + + if self.move_object: + self.gcs_hook.delete(self.source_bucket, source_object) diff --git a/reference/providers/google/suite/transfers/gcs_to_sheets.py b/reference/providers/google/suite/transfers/gcs_to_sheets.py new file mode 100644 index 0000000..b74b6e3 --- /dev/null +++ b/reference/providers/google/suite/transfers/gcs_to_sheets.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import csv +from tempfile import NamedTemporaryFile +from typing import Any, Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.suite.hooks.sheets import GSheetsHook +from airflow.utils.decorators import apply_defaults + + +class GCSToGoogleSheetsOperator(BaseOperator): + """ + Uploads .csv file from Google Cloud Storage to provided Google Spreadsheet. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GCSToGoogleSheets` + + :param spreadsheet_id: The Google Sheet ID to interact with. + :type spreadsheet_id: str + :param bucket_name: Name of GCS bucket.: + :type bucket_name: str + :param object_name: Path to the .csv file on the GCS bucket. + :type object_name: str + :param spreadsheet_range: The A1 notation of the values to retrieve. + :type spreadsheet_range: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. 
For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :type impersonation_chain: Union[str, Sequence[str]] + """ + + template_fields = [ + "spreadsheet_id", + "bucket_name", + "object_name", + "spreadsheet_range", + "impersonation_chain", + ] + + @apply_defaults + def __init__( + self, + *, + spreadsheet_id: str, + bucket_name: str, + object_name: str, + spreadsheet_range: str = "Sheet1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.gcp_conn_id = gcp_conn_id + self.spreadsheet_id = spreadsheet_id + self.spreadsheet_range = spreadsheet_range + self.bucket_name = bucket_name + self.object_name = object_name + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Any) -> None: + sheet_hook = GSheetsHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + with NamedTemporaryFile("w+") as temp_file: + # Download data + gcs_hook.download( + bucket_name=self.bucket_name, + object_name=self.object_name, + filename=temp_file.name, + ) + + # Upload data + values = list(csv.reader(temp_file)) + sheet_hook.update_values( + spreadsheet_id=self.spreadsheet_id, + range_=self.spreadsheet_range, + values=values, + ) diff --git a/reference/providers/grpc/CHANGELOG.rst b/reference/providers/grpc/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/grpc/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. 
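A hedged sketch of wiring the two transfer operators above into a DAG follows. The bucket, object paths, and spreadsheet id are placeholders, and the imports assume the upstream provider packages rather than the vendored reference/ copies.

# Hedged sketch: copy CSVs from GCS to Drive, then load one of them into a sheet.
from airflow import DAG
from airflow.providers.google.suite.transfers.gcs_to_gdrive import GCSToGoogleDriveOperator
from airflow.providers.google.suite.transfers.gcs_to_sheets import GCSToGoogleSheetsOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_gcs_suite_transfers",
    start_date=days_ago(1),
    schedule_interval=None,
) as dag:
    copy_reports = GCSToGoogleDriveOperator(
        task_id="copy_reports_to_drive",
        source_bucket="my-source-bucket",       # placeholder bucket
        source_object="reports/*.csv",          # a single '*' wildcard is allowed
        destination_object="backups/reports/",  # prefix applied to each matched object
        move_object=False,
    )
    load_summary = GCSToGoogleSheetsOperator(
        task_id="load_summary_to_sheet",
        bucket_name="my-source-bucket",         # placeholder bucket
        object_name="reports/summary.csv",      # placeholder object
        spreadsheet_id="placeholder-spreadsheet-id",
        spreadsheet_range="Sheet1",
    )
    copy_reports >> load_summary

With a single wildcard, GCSToGoogleDriveOperator lists objects under the prefix and copies each one through a NamedTemporaryFile; setting move_object=True would also delete the GCS source after upload.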
diff --git a/reference/providers/grpc/__init__.py b/reference/providers/grpc/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/grpc/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/grpc/hooks/__init__.py b/reference/providers/grpc/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/grpc/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/grpc/hooks/grpc.py b/reference/providers/grpc/hooks/grpc.py new file mode 100644 index 0000000..61ffdbb --- /dev/null +++ b/reference/providers/grpc/hooks/grpc.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""GRPC Hook""" +from typing import Any, Callable, Dict, Generator, List, Optional + +import grpc +from airflow.exceptions import AirflowConfigException +from airflow.hooks.base import BaseHook +from google import auth as google_auth +from google.auth import jwt as google_auth_jwt +from google.auth.transport import grpc as google_auth_transport_grpc +from google.auth.transport import requests as google_auth_transport_requests + + +class GrpcHook(BaseHook): + """ + General interaction with gRPC servers. 
+ + :param grpc_conn_id: The connection ID to use when fetching connection info. + :type grpc_conn_id: str + :param interceptors: a list of gRPC interceptor objects which would be applied + to the connected gRPC channel. None by default. + :type interceptors: a list of gRPC interceptors based on or extends the four + official gRPC interceptors, eg, UnaryUnaryClientInterceptor, + UnaryStreamClientInterceptor, StreamUnaryClientInterceptor, + StreamStreamClientInterceptor. + :param custom_connection_func: The customized connection function to return gRPC channel. + :type custom_connection_func: python callable objects that accept the connection as + its only arg. Could be partial or lambda. + """ + + conn_name_attr = "grpc_conn_id" + default_conn_name = "grpc_default" + conn_type = "grpc" + hook_name = "GRPC Connection" + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "extra__grpc__auth_type": StringField( + lazy_gettext("Grpc Auth Type"), widget=BS3TextFieldWidget() + ), + "extra__grpc__credential_pem_file": StringField( + lazy_gettext("Credential Keyfile Path"), widget=BS3TextFieldWidget() + ), + "extra__grpc__scopes": StringField( + lazy_gettext("Scopes (comma separated)"), widget=BS3TextFieldWidget() + ), + } + + def __init__( + self, + grpc_conn_id: str = default_conn_name, + interceptors: Optional[List[Callable]] = None, + custom_connection_func: Optional[Callable] = None, + ) -> None: + super().__init__() + self.grpc_conn_id = grpc_conn_id + self.conn = self.get_connection(self.grpc_conn_id) + self.extras = self.conn.extra_dejson + self.interceptors = interceptors if interceptors else [] + self.custom_connection_func = custom_connection_func + + def get_conn(self) -> grpc.Channel: + base_url = self.conn.host + + if self.conn.port: + base_url = base_url + ":" + str(self.conn.port) + + auth_type = self._get_field("auth_type") + + if auth_type == "NO_AUTH": + channel = grpc.insecure_channel(base_url) + elif auth_type in {"SSL", "TLS"}: + credential_file_name = self._get_field("credential_pem_file") + with open(credential_file_name, "rb") as credential_file: + creds = grpc.ssl_channel_credentials(credential_file.read()) + channel = grpc.secure_channel(base_url, creds) + elif auth_type == "JWT_GOOGLE": + credentials, _ = google_auth.default() + jwt_creds = google_auth_jwt.OnDemandCredentials.from_signing_credentials( + credentials + ) + channel = google_auth_transport_grpc.secure_authorized_channel( + jwt_creds, None, base_url + ) + elif auth_type == "OATH_GOOGLE": + scopes = self._get_field("scopes").split(",") + credentials, _ = google_auth.default(scopes=scopes) + request = google_auth_transport_requests.Request() + channel = google_auth_transport_grpc.secure_authorized_channel( + credentials, request, base_url + ) + elif auth_type == "CUSTOM": + if not self.custom_connection_func: + raise AirflowConfigException( + "Customized connection function not set, not able to establish a channel" + ) + channel = self.custom_connection_func(self.conn) + else: + raise AirflowConfigException( + "auth_type not supported or not provided, channel cannot be established,\ + given value: %s" + % str(auth_type) + ) + + if self.interceptors: + for interceptor in self.interceptors: + channel = grpc.intercept_channel(channel, interceptor) + + return channel + + def run( + self, 
+ stub_class: Callable, + call_func: str, + streaming: bool = False, + data: Optional[dict] = None, + ) -> Generator: + """Call gRPC function and yield response to caller""" + if data is None: + data = {} + with self.get_conn() as channel: + stub = stub_class(channel) + try: + rpc_func = getattr(stub, call_func) + response = rpc_func(**data) + if not streaming: + yield response + else: + yield from response + except grpc.RpcError as ex: + self.log.exception( + "Error occurred when calling the grpc service: %s, method: %s \ + status code: %s, error details: %s", + stub.__class__.__name__, + call_func, + ex.code(), # pylint: disable=no-member + ex.details(), # pylint: disable=no-member + ) + raise ex + + def _get_field(self, field_name: str) -> str: + """ + Fetches a field from extras, and returns it. This is some Airflow + magic. The grpc hook type adds custom UI elements + to the hook page, which allow admins to specify scopes, credential pem files, etc. + They get formatted as shown below. + """ + full_field_name = f"extra__grpc__{field_name}" + return self.extras[full_field_name] diff --git a/reference/providers/grpc/operators/__init__.py b/reference/providers/grpc/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/grpc/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/grpc/operators/grpc.py b/reference/providers/grpc/operators/grpc.py new file mode 100644 index 0000000..afbf51a --- /dev/null +++ b/reference/providers/grpc/operators/grpc.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Callable, Dict, List, Optional + +from airflow.models import BaseOperator +from airflow.providers.grpc.hooks.grpc import GrpcHook +from airflow.utils.decorators import apply_defaults + + +class GrpcOperator(BaseOperator): + """ + Calls a gRPC endpoint to execute an action + + :param stub_class: The stub client to use for this gRPC call + :type stub_class: gRPC stub class generated from proto file + :param call_func: The client function name to call the gRPC endpoint + :type call_func: gRPC client function name for the endpoint generated from proto file, str + :param grpc_conn_id: The connection to run the operator against + :type grpc_conn_id: str + :param data: The data to pass to the rpc call + :type data: A dict with key value pairs as kwargs of the call_func + :param interceptors: A list of gRPC interceptor objects to be used on the channel + :type interceptors: A list of gRPC interceptor objects, has to be initialized + :param custom_connection_func: The customized connection function to return channel object + :type custom_connection_func: A python function that returns channel object, take in + a connection object, can be a partial function + :param streaming: A flag to indicate if the call is a streaming call + :type streaming: boolean + :param response_callback: The callback function to process the response from gRPC call + :type response_callback: A python function that process the response from gRPC call, + takes in response object and context object, context object can be used to perform + push xcom or other after task actions + :param log_response: A flag to indicate if we need to log the response + :type log_response: boolean + """ + + template_fields = ("stub_class", "call_func", "data") + + @apply_defaults + def __init__( + self, + *, + stub_class: Callable, + call_func: str, + grpc_conn_id: str = "grpc_default", + data: Optional[dict] = None, + interceptors: Optional[List[Callable]] = None, + custom_connection_func: Optional[Callable] = None, + streaming: bool = False, + response_callback: Optional[Callable] = None, + log_response: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.stub_class = stub_class + self.call_func = call_func + self.grpc_conn_id = grpc_conn_id + self.data = data or {} + self.interceptors = interceptors + self.custom_connection_func = custom_connection_func + self.streaming = streaming + self.log_response = log_response + self.response_callback = response_callback + + def _get_grpc_hook(self) -> GrpcHook: + return GrpcHook( + self.grpc_conn_id, + interceptors=self.interceptors, + custom_connection_func=self.custom_connection_func, + ) + + def execute(self, context: Dict) -> None: + hook = self._get_grpc_hook() + self.log.info("Calling gRPC service") + + # grpc hook always yield + responses = hook.run( + self.stub_class, self.call_func, streaming=self.streaming, data=self.data + ) + + for response in responses: + self._handle_response(response, context) + + def _handle_response(self, response: Any, context: Dict) -> None: + if self.log_response: + self.log.info(repr(response)) + if self.response_callback: + self.response_callback(response, context) diff --git a/reference/providers/grpc/provider.yaml b/reference/providers/grpc/provider.yaml new file mode 100644 index 0000000..9edf4eb --- /dev/null +++ b/reference/providers/grpc/provider.yaml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-grpc +name: gRPC +description: | + `gRPC `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: gRPC + external-doc-url: https://grpc.io/ + tags: [protocol] + +operators: + - integration-name: gRPC + python-modules: + - airflow.providers.grpc.operators.grpc + +hooks: + - integration-name: gRPC + python-modules: + - airflow.providers.grpc.hooks.grpc + +hook-class-names: + - airflow.providers.grpc.hooks.grpc.GrpcHook diff --git a/reference/providers/hashicorp/CHANGELOG.rst b/reference/providers/hashicorp/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/hashicorp/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/hashicorp/__init__.py b/reference/providers/hashicorp/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/hashicorp/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
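The gRPC operator above resolves call_func with getattr on an instance of stub_class and forwards data as keyword arguments, so a sketch only needs a stub-shaped callable. MyServiceStub and GetStatus below are hypothetical stand-ins for a generated *_pb2_grpc stub, and the import path assumes the upstream apache-airflow-providers-grpc package.

# Hedged sketch: calling an RPC through GrpcOperator and pushing the reply to XCom.
from airflow import DAG
from airflow.providers.grpc.operators.grpc import GrpcOperator
from airflow.utils.dates import days_ago


class MyServiceStub:  # hypothetical placeholder for a generated gRPC stub
    def __init__(self, channel):
        self.channel = channel

    def GetStatus(self, request=None):  # hypothetical RPC method
        return "OK"


def handle_response(response, context):
    # Push the reply onto XCom so downstream tasks can read it.
    context["ti"].xcom_push(key="grpc_reply", value=str(response))


with DAG(dag_id="example_grpc_call", start_date=days_ago(1), schedule_interval=None) as dag:
    call_service = GrpcOperator(
        task_id="call_service",
        stub_class=MyServiceStub,
        call_func="GetStatus",            # looked up via getattr on the stub in GrpcHook.run
        grpc_conn_id="grpc_default",
        data={"request": None},           # forwarded as **kwargs to the stub method
        response_callback=handle_response,
        log_response=True,
    )

For streaming RPCs, streaming=True makes the hook yield each message from the response iterator, and response_callback runs once per yielded message.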
diff --git a/reference/providers/hashicorp/_internal_client/__init__.py b/reference/providers/hashicorp/_internal_client/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/hashicorp/_internal_client/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/hashicorp/_internal_client/vault_client.py b/reference/providers/hashicorp/_internal_client/vault_client.py new file mode 100644 index 0000000..6422a6b --- /dev/null +++ b/reference/providers/hashicorp/_internal_client/vault_client.py @@ -0,0 +1,527 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional + +import hvac + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.utils.log.logging_mixin import LoggingMixin +from hvac.exceptions import InvalidPath, VaultError +from requests import Response + +DEFAULT_KUBERNETES_JWT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" +DEFAULT_KV_ENGINE_VERSION = 2 + + +VALID_KV_VERSIONS: List[int] = [1, 2] +VALID_AUTH_TYPES: List[str] = [ + "approle", + "aws_iam", + "azure", + "github", + "gcp", + "kubernetes", + "ldap", + "radius", + "token", + "userpass", +] + + +class _VaultClient(LoggingMixin): # pylint: disable=too-many-instance-attributes + """ + Retrieves Authenticated client from Hashicorp Vault. This is purely internal class promoting + authentication code reuse between the Hook and the SecretBackend, it should not be used directly in + Airflow DAGs. Use VaultBackend for backend integration and Hook in case you want to communicate + with VaultHook using standard Airflow Connection definition. + + :param url: Base URL for the Vault instance being addressed. + :type url: str + :param auth_type: Authentication Type for Vault. Default is ``token``. 
Available values are in
+ ('approle', 'aws_iam', 'azure', 'github', 'gcp', 'kubernetes', 'ldap', 'radius', 'token', 'userpass')
+ :type auth_type: str
+ :param auth_mount_point: It can be used to define mount_point for authentication chosen
+ Default depends on the authentication method used.
+ :type auth_mount_point: str
+ :param mount_point: The "path" the secret engine was mounted on. Default is "secret". Note that
+ this mount_point is not used for authentication if authentication is done via a
+ different engine. For authentication mount_points see, auth_mount_point.
+ :type mount_point: str
+ :param kv_engine_version: Selects the version of the engine to run (``1`` or ``2``, default: ``2``).
+ :type kv_engine_version: int
+ :param token: Authentication token to include in requests sent to Vault
+ (for ``token`` and ``github`` auth_type).
+ :type token: str
+ :param token_path: path to file containing authentication token to include in requests sent to Vault
+ (for ``token`` and ``github`` auth_type).
+ :type token_path: str
+ :param username: Username for Authentication (for ``ldap`` and ``userpass`` auth_types).
+ :type username: str
+ :param password: Password for Authentication (for ``ldap`` and ``userpass`` auth_types).
+ :type password: str
+ :param key_id: Key ID for Authentication (for ``aws_iam`` and ``azure`` auth_type).
+ :type key_id: str
+ :param secret_id: Secret ID for Authentication (for ``approle``, ``aws_iam`` and ``azure`` auth_types).
+ :type secret_id: str
+ :param role_id: Role ID for Authentication (for ``approle``, ``aws_iam`` auth_types).
+ :type role_id: str
+ :param kubernetes_role: Role for Authentication (for ``kubernetes`` auth_type).
+ :type kubernetes_role: str
+ :param kubernetes_jwt_path: Path for kubernetes jwt token (for ``kubernetes`` auth_type, default:
+ ``/var/run/secrets/kubernetes.io/serviceaccount/token``).
+ :type kubernetes_jwt_path: str
+ :param gcp_key_path: Path to Google Cloud Service Account key file (JSON) (for ``gcp`` auth_type).
+ Mutually exclusive with gcp_keyfile_dict
+ :type gcp_key_path: str
+ :param gcp_keyfile_dict: Dictionary of keyfile parameters. (for ``gcp`` auth_type).
+ Mutually exclusive with gcp_key_path
+ :type gcp_keyfile_dict: dict
+ :param gcp_scopes: Comma-separated string containing OAuth2 scopes (for ``gcp`` auth_type).
+ :type gcp_scopes: str
+ :param azure_tenant_id: The tenant id for the Azure Active Directory (for ``azure`` auth_type).
+ :type azure_tenant_id: str
+ :param azure_resource: The configured URL for the application registered in Azure Active Directory
+ (for ``azure`` auth_type).
+ :type azure_resource: str
+ :param radius_host: Host for radius (for ``radius`` auth_type).
+ :type radius_host: str
+ :param radius_secret: Secret for radius (for ``radius`` auth_type).
+ :type radius_secret: str
+ :param radius_port: Port for radius (for ``radius`` auth_type).
+ :type radius_port: int
+ """
+
+ def __init__( # pylint: disable=too-many-arguments
+ self,
+ url: Optional[str] = None,
+ auth_type: str = "token",
+ auth_mount_point: Optional[str] = None,
+ mount_point: str = "secret",
+ kv_engine_version: Optional[int] = None,
+ token: Optional[str] = None,
+ token_path: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ key_id: Optional[str] = None,
+ secret_id: Optional[str] = None,
+ role_id: Optional[str] = None,
+ kubernetes_role: Optional[str] = None,
+ kubernetes_jwt_path: Optional[
+ str
+ ] = "/var/run/secrets/kubernetes.io/serviceaccount/token",
+ gcp_key_path: Optional[str] = None,
+ gcp_keyfile_dict: Optional[dict] = None,
+ gcp_scopes: Optional[str] = None,
+ azure_tenant_id: Optional[str] = None,
+ azure_resource: Optional[str] = None,
+ radius_host: Optional[str] = None,
+ radius_secret: Optional[str] = None,
+ radius_port: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__()
+ if kv_engine_version and kv_engine_version not in VALID_KV_VERSIONS:
+ raise VaultError(
+ f"The version is not supported: {kv_engine_version}. "
+ f"It should be one of {VALID_KV_VERSIONS}"
+ )
+ if auth_type not in VALID_AUTH_TYPES:
+ raise VaultError(
+ f"The auth_type is not supported: {auth_type}. "
+ f"It should be one of {VALID_AUTH_TYPES}"
+ )
+ if auth_type == "token" and not token and not token_path:
+ raise VaultError(
+ "The 'token' authentication type requires 'token' or 'token_path'"
+ )
+ if auth_type == "github" and not token and not token_path:
+ raise VaultError(
+ "The 'github' authentication type requires 'token' or 'token_path'"
+ )
+ if auth_type == "approle" and not role_id:
+ raise VaultError("The 'approle' authentication type requires 'role_id'")
+ if auth_type == "kubernetes":
+ if not kubernetes_role:
+ raise VaultError(
+ "The 'kubernetes' authentication type requires 'kubernetes_role'"
+ )
+ if not kubernetes_jwt_path:
+ raise VaultError(
+ "The 'kubernetes' authentication type requires 'kubernetes_jwt_path'"
+ )
+ if auth_type == "azure":
+ if not azure_resource:
+ raise VaultError(
+ "The 'azure' authentication type requires 'azure_resource'"
+ )
+ if not azure_tenant_id:
+ raise VaultError(
+ "The 'azure' authentication type requires 'azure_tenant_id'"
+ )
+ if auth_type == "radius":
+ if not radius_host:
+ raise VaultError(
+ "The 'radius' authentication type requires 'radius_host'"
+ )
+ if not radius_secret:
+ raise VaultError(
+ "The 'radius' authentication type requires 'radius_secret'"
+ )
+
+ self.kv_engine_version = kv_engine_version if kv_engine_version else 2
+ self.url = url
+ self.auth_type = auth_type
+ self.kwargs = kwargs
+ self.token = token
+ self.token_path = token_path
+ self.auth_mount_point = auth_mount_point
+ self.mount_point = mount_point
+ self.username = username
+ self.password = password
+ self.key_id = key_id
+ self.secret_id = secret_id
+ self.role_id = role_id
+ self.kubernetes_role = kubernetes_role
+ self.kubernetes_jwt_path = kubernetes_jwt_path
+ self.gcp_key_path = gcp_key_path
+ self.gcp_keyfile_dict = gcp_keyfile_dict
+ self.gcp_scopes = gcp_scopes
+ self.azure_tenant_id = azure_tenant_id
+ self.azure_resource = azure_resource
+ self.radius_host = radius_host
+ self.radius_secret = radius_secret
+ self.radius_port = radius_port
+
+ @cached_property
+ def client(self) -> hvac.Client:
+ """
+ Return an authenticated Hashicorp Vault client.
+ + :rtype: hvac.Client + :return: Vault Client + + """ + _client = hvac.Client(url=self.url, **self.kwargs) + if self.auth_type == "approle": + self._auth_approle(_client) + elif self.auth_type == "aws_iam": + self._auth_aws_iam(_client) + elif self.auth_type == "azure": + self._auth_azure(_client) + elif self.auth_type == "gcp": + self._auth_gcp(_client) + elif self.auth_type == "github": + self._auth_github(_client) + elif self.auth_type == "kubernetes": + self._auth_kubernetes(_client) + elif self.auth_type == "ldap": + self._auth_ldap(_client) + elif self.auth_type == "radius": + self._auth_radius(_client) + elif self.auth_type == "token": + self._set_token(_client) + elif self.auth_type == "userpass": + self._auth_userpass(_client) + else: + raise VaultError(f"Authentication type '{self.auth_type}' not supported") + + if _client.is_authenticated(): + return _client + else: + raise VaultError("Vault Authentication Error!") + + def _auth_userpass(self, _client: hvac.Client) -> None: + if self.auth_mount_point: + _client.auth_userpass( + username=self.username, + password=self.password, + mount_point=self.auth_mount_point, + ) + else: + _client.auth_userpass(username=self.username, password=self.password) + + def _auth_radius(self, _client: hvac.Client) -> None: + if self.auth_mount_point: + _client.auth.radius.configure( + host=self.radius_host, + secret=self.radius_secret, + port=self.radius_port, + mount_point=self.auth_mount_point, + ) + else: + _client.auth.radius.configure( + host=self.radius_host, secret=self.radius_secret, port=self.radius_port + ) + + def _auth_ldap(self, _client: hvac.Client) -> None: + if self.auth_mount_point: + _client.auth.ldap.login( + username=self.username, + password=self.password, + mount_point=self.auth_mount_point, + ) + else: + _client.auth.ldap.login(username=self.username, password=self.password) + + def _auth_kubernetes(self, _client: hvac.Client) -> None: + if not self.kubernetes_jwt_path: + raise VaultError( + "The kubernetes_jwt_path should be set here. This should not happen." 
+ ) + with open(self.kubernetes_jwt_path) as f: + jwt = f.read() + if self.auth_mount_point: + _client.auth_kubernetes( + role=self.kubernetes_role, + jwt=jwt, + mount_point=self.auth_mount_point, + ) + else: + _client.auth_kubernetes(role=self.kubernetes_role, jwt=jwt) + + def _auth_github(self, _client: hvac.Client) -> None: + if self.auth_mount_point: + _client.auth.github.login( + token=self.token, mount_point=self.auth_mount_point + ) + else: + _client.auth.github.login(token=self.token) + + def _auth_gcp(self, _client: hvac.Client) -> None: + from airflow.providers.google.cloud.utils.credentials_provider import ( # noqa + _get_scopes, + get_credentials_and_project_id, + ) + + scopes = _get_scopes(self.gcp_scopes) + credentials, _ = get_credentials_and_project_id( + key_path=self.gcp_key_path, + keyfile_dict=self.gcp_keyfile_dict, + scopes=scopes, + ) + if self.auth_mount_point: + _client.auth.gcp.configure( + credentials=credentials, mount_point=self.auth_mount_point + ) + else: + _client.auth.gcp.configure(credentials=credentials) + + def _auth_azure(self, _client: hvac.Client) -> None: + if self.auth_mount_point: + _client.auth.azure.configure( + tenant_id=self.azure_tenant_id, + resource=self.azure_resource, + client_id=self.key_id, + client_secret=self.secret_id, + mount_point=self.auth_mount_point, + ) + else: + _client.auth.azure.configure( + tenant_id=self.azure_tenant_id, + resource=self.azure_resource, + client_id=self.key_id, + client_secret=self.secret_id, + ) + + def _auth_aws_iam(self, _client: hvac.Client) -> None: + if self.auth_mount_point: + _client.auth_aws_iam( + access_key=self.key_id, + secret_key=self.secret_id, + role=self.role_id, + mount_point=self.auth_mount_point, + ) + else: + _client.auth_aws_iam( + access_key=self.key_id, secret_key=self.secret_id, role=self.role_id + ) + + def _auth_approle(self, _client: hvac.Client) -> None: + if self.auth_mount_point: + _client.auth_approle( + role_id=self.role_id, + secret_id=self.secret_id, + mount_point=self.auth_mount_point, + ) + else: + _client.auth_approle(role_id=self.role_id, secret_id=self.secret_id) + + def _set_token(self, _client: hvac.Client) -> None: + if self.token_path: + with open(self.token_path) as f: + _client.token = f.read() + else: + _client.token = self.token + + def get_secret( + self, secret_path: str, secret_version: Optional[int] = None + ) -> Optional[dict]: + """ + Get secret value from the KV engine. + + :param secret_path: The path of the secret. + :type secret_path: str + :param secret_version: Specifies the version of Secret to return. If not set, the latest + version is returned. (Can only be used in case of version 2 of KV). + :type secret_version: int + + See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v1.html + and https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details. 
+ + :return: secret stored in the vault as a dictionary + """ + try: + if self.kv_engine_version == 1: + if secret_version: + raise VaultError( + "Secret version can only be used with version 2 of the KV engine" + ) + response = self.client.secrets.kv.v1.read_secret( + path=secret_path, mount_point=self.mount_point + ) + else: + response = self.client.secrets.kv.v2.read_secret_version( + path=secret_path, + mount_point=self.mount_point, + version=secret_version, + ) + except InvalidPath: + self.log.debug( + "Secret not found %s with mount point %s", secret_path, self.mount_point + ) + return None + + return_data = ( + response["data"] + if self.kv_engine_version == 1 + else response["data"]["data"] + ) + return return_data + + def get_secret_metadata(self, secret_path: str) -> Optional[dict]: + """ + Reads secret metadata (including versions) from the engine. It is only valid for KV version 2. + + :param secret_path: The path of the secret. + :type secret_path: str + :rtype: dict + :return: secret metadata. This is a Dict containing metadata for the secret. + + See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details. + + """ + if self.kv_engine_version == 1: + raise VaultError( + "Metadata might only be used with version 2 of the KV engine." + ) + try: + return self.client.secrets.kv.v2.read_secret_metadata( + path=secret_path, mount_point=self.mount_point + ) + except InvalidPath: + self.log.debug( + "Secret not found %s with mount point %s", secret_path, self.mount_point + ) + return None + + def get_secret_including_metadata( + self, secret_path: str, secret_version: Optional[int] = None + ) -> Optional[dict]: + """ + Reads secret including metadata. It is only valid for KV version 2. + + See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details. + + :param secret_path: The path of the secret. + :type secret_path: str + :param secret_version: Specifies the version of Secret to return. If not set, the latest + version is returned. (Can only be used in case of version 2 of KV). + :type secret_version: int + :rtype: dict + :return: The key info. This is a Dict with "data" mapping keeping secret + and "metadata" mapping keeping metadata of the secret. + """ + if self.kv_engine_version == 1: + raise VaultError( + "Metadata might only be used with version 2 of the KV engine." + ) + try: + return self.client.secrets.kv.v2.read_secret_version( + path=secret_path, mount_point=self.mount_point, version=secret_version + ) + except InvalidPath: + self.log.debug( + "Secret not found %s with mount point %s and version %s", + secret_path, + self.mount_point, + secret_version, + ) + return None + + def create_or_update_secret( + self, + secret_path: str, + secret: dict, + method: Optional[str] = None, + cas: Optional[int] = None, + ) -> Response: + """ + Creates or updates secret. + + :param secret_path: The path of the secret. + :type secret_path: str + :param secret: Secret to create or update for the path specified + :type secret: dict + :param method: Optional parameter to explicitly request a POST (create) or PUT (update) request to + the selected kv secret engine. If no argument is provided for this parameter, hvac attempts to + intelligently determine which method is appropriate. Only valid for KV engine version 1 + :type method: str + :param cas: Set the "cas" value to use a Check-And-Set operation. If not set the write will be + allowed. If set to 0 a write will only be allowed if the key doesn't exist. 
+ If the index is non-zero the write will only be allowed if the key's current version + matches the version specified in the cas parameter. Only valid for KV engine version 2. + :type cas: int + :rtype: requests.Response + :return: The response of the create_or_update_secret request. + + See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v1.html + and https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details. + + """ + if self.kv_engine_version == 2 and method: + raise VaultError("The method parameter is only valid for version 1") + if self.kv_engine_version == 1 and cas: + raise VaultError("The cas parameter is only valid for version 2") + if self.kv_engine_version == 1: + response = self.client.secrets.kv.v1.create_or_update_secret( + secret_path=secret_path, + secret=secret, + mount_point=self.mount_point, + method=method, + ) + else: + response = self.client.secrets.kv.v2.create_or_update_secret( + secret_path=secret_path, + secret=secret, + mount_point=self.mount_point, + cas=cas, + ) + return response diff --git a/reference/providers/hashicorp/hooks/__init__.py b/reference/providers/hashicorp/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/hashicorp/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/hashicorp/hooks/vault.py b/reference/providers/hashicorp/hooks/vault.py new file mode 100644 index 0000000..2a39868 --- /dev/null +++ b/reference/providers/hashicorp/hooks/vault.py @@ -0,0 +1,374 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
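The ``_VaultClient`` above wraps hvac and carries all of the KV read/write logic reused by the hook and secrets backend later in this commit. As a quick orientation, here is a minimal sketch of exercising it directly; the URL, token, and secret path are assumptions for a local Vault dev server, not values taken from this repository:

# Illustrative sketch only -- url, token, and paths below are assumed, not part of this commit.
from airflow.providers.hashicorp._internal_client.vault_client import _VaultClient

client = _VaultClient(
    url="http://127.0.0.1:8200",   # assumed local dev server
    auth_type="token",
    token="dev-only-token",        # hypothetical dev token
    mount_point="secret",
    kv_engine_version=2,
)
# Write a secret, then read it back through the same KV v2 engine.
client.create_or_update_secret(
    secret_path="connections/smtp_default",
    secret={"conn_uri": "smtps://user:pass@smtp.example.com:465"},
)
print(client.get_secret(secret_path="connections/smtp_default"))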
+
+"""Hook for HashiCorp Vault"""
+import json
+from typing import Optional, Tuple
+
+import hvac
+from airflow.hooks.base import BaseHook
+from airflow.providers.hashicorp._internal_client.vault_client import (  # noqa
+    DEFAULT_KUBERNETES_JWT_PATH,
+    DEFAULT_KV_ENGINE_VERSION,
+    _VaultClient,
+)
+from hvac.exceptions import VaultError
+from requests import Response
+
+
+class VaultHook(BaseHook):
+    """
+    Hook to Interact with HashiCorp Vault KeyValue Secret engine.
+
+    HashiCorp hvac documentation:
+        * https://hvac.readthedocs.io/en/stable/
+
+    You connect to the host specified as host in the connection. The login/password from the
+    connection are usually used as credentials, and you can specify different authentication
+    parameters via init params or via corresponding extras in the connection.
+
+    The mount point should be placed as a path in the URL - similarly to Vault's URL schema:
+    This indicates the "path" the secret engine is mounted on. Default if not specified is "secret".
+    Note that this ``mount_point`` is not used for authentication if authentication is done via a
+    different engine. Each engine uses its own engine-specific authentication mount_point.
+
+    The extras in the connection are named the same as the parameters ('kv_engine_version', 'auth_type', ...).
+
+    You can also use gcp_keyfile_dict extra to pass json-formatted dict in case of 'gcp' authentication.
+
+    The URL schemas supported are "vault", "http" (using http to connect to the vault) or
+    "vaults" and "https" (using https to connect to the vault).
+
+    Example URL:
+
+    .. code-block::
+
+        vault://user:password@host:port/mount_point?kv_engine_version=1&auth_type=github
+
+
+    Login/Password are used as credentials:
+
+        * approle: password -> secret_id
+        * github: password -> token
+        * token: password -> token
+        * aws_iam: login -> key_id, password -> secret_id
+        * azure: login -> client_id, password -> client_secret
+        * ldap: login -> username, password -> password
+        * userpass: login -> username, password -> password
+        * radius: password -> radius_secret
+
+    :param vault_conn_id: The id of the connection to use
+    :type vault_conn_id: str
+    :param auth_type: Authentication Type for the Vault. Default is ``token``. Available values are:
+        ('approle', 'github', 'gcp', 'kubernetes', 'ldap', 'token', 'userpass')
+    :type auth_type: str
+    :param auth_mount_point: It can be used to define mount_point for authentication chosen.
+        Default depends on the authentication method used.
+    :type auth_mount_point: str
+    :param kv_engine_version: Select the version of the engine to run (``1`` or ``2``). Defaults to
+        version defined in connection or ``2`` if not defined in connection.
+    :type kv_engine_version: int
+    :param role_id: Role ID for Authentication (for ``approle``, ``aws_iam`` auth_types)
+    :type role_id: str
+    :param kubernetes_role: Role for Authentication (for ``kubernetes`` auth_type)
+    :type kubernetes_role: str
+    :param kubernetes_jwt_path: Path for kubernetes jwt token (for ``kubernetes`` auth_type, default:
+        ``/var/run/secrets/kubernetes.io/serviceaccount/token``)
+    :type kubernetes_jwt_path: str
+    :param token_path: path to file containing authentication token to include in requests sent to Vault
+        (for ``token`` and ``github`` auth_type).
+    :type token_path: str
+    :param gcp_key_path: Path to Google Cloud Service Account key file (JSON) (for ``gcp`` auth_type)
+        Mutually exclusive with gcp_keyfile_dict
+    :type gcp_key_path: str
+    :param gcp_scopes: Comma-separated string containing OAuth2 scopes (for ``gcp`` auth_type)
+    :type gcp_scopes: str
+    :param azure_tenant_id: The tenant id for the Azure Active Directory (for ``azure`` auth_type)
+    :type azure_tenant_id: str
+    :param azure_resource: The configured URL for the application registered in Azure Active Directory
+        (for ``azure`` auth_type)
+    :type azure_resource: str
+    :param radius_host: Host for radius (for ``radius`` auth_type)
+    :type radius_host: str
+    :param radius_port: Port for radius (for ``radius`` auth_type)
+    :type radius_port: int
+
+    """
+
+    conn_name_attr = "vault_conn_id"
+    default_conn_name = "vault_default"
+    conn_type = "vault"
+    hook_name = "Hashicorp Vault"
+
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        vault_conn_id: str = default_conn_name,
+        auth_type: Optional[str] = None,
+        auth_mount_point: Optional[str] = None,
+        kv_engine_version: Optional[int] = None,
+        role_id: Optional[str] = None,
+        kubernetes_role: Optional[str] = None,
+        kubernetes_jwt_path: Optional[str] = None,
+        token_path: Optional[str] = None,
+        gcp_key_path: Optional[str] = None,
+        gcp_scopes: Optional[str] = None,
+        azure_tenant_id: Optional[str] = None,
+        azure_resource: Optional[str] = None,
+        radius_host: Optional[str] = None,
+        radius_port: Optional[int] = None,
+    ):
+        super().__init__()
+        self.connection = self.get_connection(vault_conn_id)
+
+        if not auth_type:
+            auth_type = self.connection.extra_dejson.get("auth_type") or "token"
+
+        if not auth_mount_point:
+            auth_mount_point = self.connection.extra_dejson.get("auth_mount_point")
+
+        if not kv_engine_version:
+            conn_version = self.connection.extra_dejson.get("kv_engine_version")
+            try:
+                kv_engine_version = (
+                    int(conn_version) if conn_version else DEFAULT_KV_ENGINE_VERSION
+                )
+            except ValueError:
+                raise VaultError(f"The version is not an int: {conn_version}. ")
+
+        if auth_type in ["approle", "aws_iam"]:
+            if not role_id:
+                role_id = self.connection.extra_dejson.get("role_id")
+
+        azure_resource, azure_tenant_id = (
+            self._get_azure_parameters_from_connection(azure_resource, azure_tenant_id)
+            if auth_type == "azure"
+            else (None, None)
+        )
+        gcp_key_path, gcp_keyfile_dict, gcp_scopes = (
+            self._get_gcp_parameters_from_connection(gcp_key_path, gcp_scopes)
+            if auth_type == "gcp"
+            else (None, None, None)
+        )
+        kubernetes_jwt_path, kubernetes_role = (
+            self._get_kubernetes_parameters_from_connection(
+                kubernetes_jwt_path, kubernetes_role
+            )
+            if auth_type == "kubernetes"
+            else (None, None)
+        )
+        radius_host, radius_port = (
+            self._get_radius_parameters_from_connection(radius_host, radius_port)
+            if auth_type == "radius"
+            else (None, None)
+        )
+
+        if self.connection.conn_type == "vault":
+            conn_protocol = "http"
+        elif self.connection.conn_type == "vaults":
+            conn_protocol = "https"
+        elif self.connection.conn_type == "http":
+            conn_protocol = "http"
+        elif self.connection.conn_type == "https":
+            conn_protocol = "https"
+        else:
+            raise VaultError(
+                "The url schema must be one of ['http', 'https', 'vault', 'vaults' ]"
+            )
+
+        url = f"{conn_protocol}://{self.connection.host}"
+        if self.connection.port:
+            url += f":{self.connection.port}"
+
+        # Schema is really path in the Connection definition. This is pretty confusing because of URL schema
+        mount_point = self.connection.schema if self.connection.schema else "secret"
+
+        self.vault_client = _VaultClient(
+            url=url,
+            auth_type=auth_type,
+            auth_mount_point=auth_mount_point,
+            mount_point=mount_point,
+            kv_engine_version=kv_engine_version,
+            token=self.connection.password,
+            token_path=token_path,
+            username=self.connection.login,
+            password=self.connection.password,
+            key_id=self.connection.login,
+            secret_id=self.connection.password,
+            role_id=role_id,
+            kubernetes_role=kubernetes_role,
+            kubernetes_jwt_path=kubernetes_jwt_path,
+            gcp_key_path=gcp_key_path,
+            gcp_keyfile_dict=gcp_keyfile_dict,
+            gcp_scopes=gcp_scopes,
+            azure_tenant_id=azure_tenant_id,
+            azure_resource=azure_resource,
+            radius_host=radius_host,
+            radius_secret=self.connection.password,
+            radius_port=radius_port,
+        )
+
+    def _get_kubernetes_parameters_from_connection(
+        self, kubernetes_jwt_path: Optional[str], kubernetes_role: Optional[str]
+    ) -> Tuple[str, Optional[str]]:
+        if not kubernetes_jwt_path:
+            kubernetes_jwt_path = self.connection.extra_dejson.get(
+                "kubernetes_jwt_path"
+            )
+            if not kubernetes_jwt_path:
+                kubernetes_jwt_path = DEFAULT_KUBERNETES_JWT_PATH
+        if not kubernetes_role:
+            kubernetes_role = self.connection.extra_dejson.get("kubernetes_role")
+        return kubernetes_jwt_path, kubernetes_role
+
+    def _get_gcp_parameters_from_connection(
+        self,
+        gcp_key_path: Optional[str],
+        gcp_scopes: Optional[str],
+    ) -> Tuple[Optional[str], Optional[dict], Optional[str]]:
+        if not gcp_scopes:
+            gcp_scopes = self.connection.extra_dejson.get("gcp_scopes")
+        if not gcp_key_path:
+            gcp_key_path = self.connection.extra_dejson.get("gcp_key_path")
+        string_keyfile_dict = self.connection.extra_dejson.get("gcp_keyfile_dict")
+        gcp_keyfile_dict = (
+            json.loads(string_keyfile_dict) if string_keyfile_dict else None
+        )
+        return gcp_key_path, gcp_keyfile_dict, gcp_scopes
+
+    def _get_azure_parameters_from_connection(
+        self, azure_resource: Optional[str], azure_tenant_id: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        if not azure_tenant_id:
+            azure_tenant_id = self.connection.extra_dejson.get("azure_tenant_id")
+        if not azure_resource:
+            azure_resource = self.connection.extra_dejson.get("azure_resource")
+        return azure_resource, azure_tenant_id
+
+    def _get_radius_parameters_from_connection(
+        self, radius_host: Optional[str], radius_port: Optional[int]
+    ) -> Tuple[Optional[str], Optional[int]]:
+        if not radius_port:
+            radius_port_str = self.connection.extra_dejson.get("radius_port")
+            if radius_port_str:
+                try:
+                    radius_port = int(radius_port_str)
+                except ValueError:
+                    raise VaultError(f"Radius port was wrong: {radius_port_str}")
+        if not radius_host:
+            radius_host = self.connection.extra_dejson.get("radius_host")
+        return radius_host, radius_port
+
+    def get_conn(self) -> hvac.Client:
+        """
+        Retrieves connection to Vault.
+
+        :rtype: hvac.Client
+        :return: connection used.
+        """
+        return self.vault_client.client
+
+    def get_secret(
+        self, secret_path: str, secret_version: Optional[int] = None
+    ) -> Optional[dict]:
+        """
+        Get secret value from the engine.
+
+        :param secret_path: Path of the secret
+        :type secret_path: str
+        :param secret_version: Optional version of key to read - can only be used in case of version 2 of KV
+        :type secret_version: int
+
+        See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v1.html
+        and https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details.
+ + :param secret_path: Path of the secret + :type secret_path: str + :rtype: dict + :return: secret stored in the vault as a dictionary + """ + return self.vault_client.get_secret( + secret_path=secret_path, secret_version=secret_version + ) + + def get_secret_metadata(self, secret_path: str) -> Optional[dict]: + """ + Reads secret metadata (including versions) from the engine. It is only valid for KV version 2. + + :param secret_path: Path to read from + :type secret_path: str + :rtype: dict + :return: secret metadata. This is a Dict containing metadata for the secret. + + See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details. + + """ + return self.vault_client.get_secret_metadata(secret_path=secret_path) + + def get_secret_including_metadata( + self, secret_path: str, secret_version: Optional[int] = None + ) -> Optional[dict]: + """ + Reads secret including metadata. It is only valid for KV version 2. + + See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details. + + :param secret_path: Path of the secret + :type secret_path: str + :param secret_version: Optional version of key to read - can only be used in case of version 2 of KV + :type secret_version: int + :rtype: dict + :return: key info. This is a Dict with "data" mapping keeping secret + and "metadata" mapping keeping metadata of the secret. + + """ + return self.vault_client.get_secret_including_metadata( + secret_path=secret_path, secret_version=secret_version + ) + + def create_or_update_secret( + self, + secret_path: str, + secret: dict, + method: Optional[str] = None, + cas: Optional[int] = None, + ) -> Response: + """ + Creates or updates secret. + + :param secret_path: Path to read from + :type secret_path: str + :param secret: Secret to create or update for the path specified + :type secret: dict + :param method: Optional parameter to explicitly request a POST (create) or PUT (update) request to + the selected kv secret engine. If no argument is provided for this parameter, hvac attempts to + intelligently determine which method is appropriate. Only valid for KV engine version 1 + :type method: str + :param cas: Set the "cas" value to use a Check-And-Set operation. If not set the write will be + allowed. If set to 0 a write will only be allowed if the key doesn't exist. + If the index is non-zero the write will only be allowed if the key's current version + matches the version specified in the cas parameter. Only valid for KV engine version 2. + :type cas: int + :rtype: requests.Response + :return: The response of the create_or_update_secret request. + + See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v1.html + and https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details. + + """ + return self.vault_client.create_or_update_secret( + secret_path=secret_path, secret=secret, method=method, cas=cas + ) diff --git a/reference/providers/hashicorp/provider.yaml b/reference/providers/hashicorp/provider.yaml new file mode 100644 index 0000000..04b91f5 --- /dev/null +++ b/reference/providers/hashicorp/provider.yaml @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-hashicorp +name: Hashicorp +description: | + Hashicorp including `Hashicorp Vault `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Hashicorp Vault + external-doc-url: https://www.vaultproject.io/ + logo: /integration-logos/hashicorp/Hashicorp-Vault.png + tags: [software] + +hooks: + - integration-name: Hashicorp Vault + python-modules: + - airflow.providers.hashicorp.hooks.vault + +hook-class-names: + - airflow.providers.hashicorp.hooks.vault.VaultHook diff --git a/reference/providers/hashicorp/secrets/__init__.py b/reference/providers/hashicorp/secrets/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/hashicorp/secrets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/hashicorp/secrets/vault.py b/reference/providers/hashicorp/secrets/vault.py new file mode 100644 index 0000000..5098a52 --- /dev/null +++ b/reference/providers/hashicorp/secrets/vault.py @@ -0,0 +1,230 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Objects relating to sourcing connections & variables from Hashicorp Vault""" +from typing import Optional + +from airflow.providers.hashicorp._internal_client.vault_client import ( # noqa + _VaultClient, +) +from airflow.secrets import BaseSecretsBackend +from airflow.utils.log.logging_mixin import LoggingMixin + + +# pylint: disable=too-many-instance-attributes,too-many-locals +class VaultBackend(BaseSecretsBackend, LoggingMixin): + """ + Retrieves Connections and Variables from Hashicorp Vault. 
+ + Configurable via ``airflow.cfg`` as follows: + + .. code-block:: ini + + [secrets] + backend = airflow.providers.hashicorp.secrets.vault.VaultBackend + backend_kwargs = { + "connections_path": "connections", + "url": "http://127.0.0.1:8200", + "mount_point": "airflow" + } + + For example, if your keys are under ``connections`` path in ``airflow`` mount_point, this + would be accessible if you provide ``{"connections_path": "connections"}`` and request + conn_id ``smtp_default``. + + :param connections_path: Specifies the path of the secret to read to get Connections. + (default: 'connections'). If set to None (null), requests for connections will not be sent to Vault. + :type connections_path: str + :param variables_path: Specifies the path of the secret to read to get Variable. + (default: 'variables'). If set to None (null), requests for variables will not be sent to Vault. + :type variables_path: str + :param config_path: Specifies the path of the secret to read Airflow Configurations + (default: 'config'). If set to None (null), requests for configurations will not be sent to Vault. + :type config_path: str + :param url: Base URL for the Vault instance being addressed. + :type url: str + :param auth_type: Authentication Type for Vault. Default is ``token``. Available values are: + ('approle', 'aws_iam', 'azure', 'github', 'gcp', 'kubernetes', 'ldap', 'radius', 'token', 'userpass') + :type auth_type: str + :param auth_mount_point: It can be used to define mount_point for authentication chosen + Default depends on the authentication method used. + :type auth_mount_point: str + :param mount_point: The "path" the secret engine was mounted on. Default is "secret". Note that + this mount_point is not used for authentication if authentication is done via a + different engine. For authentication mount_points see, auth_mount_point. + :type mount_point: str + :param kv_engine_version: Select the version of the engine to run (``1`` or ``2``, default: ``2``). + :type kv_engine_version: int + :param token: Authentication token to include in requests sent to Vault. + (for ``token`` and ``github`` auth_type) + :type token: str + :param token_path: path to file containing authentication token to include in requests sent to Vault + (for ``token`` and ``github`` auth_type). + :type token_path: str + :param username: Username for Authentication (for ``ldap`` and ``userpass`` auth_type). + :type username: str + :param password: Password for Authentication (for ``ldap`` and ``userpass`` auth_type). + :type password: str + :param key_id: Key ID for Authentication (for ``aws_iam`` and ''azure`` auth_type). + :type key_id: str + :param secret_id: Secret ID for Authentication (for ``approle``, ``aws_iam`` and ``azure`` auth_types). + :type secret_id: str + :param role_id: Role ID for Authentication (for ``approle``, ``aws_iam`` auth_types). + :type role_id: str + :param kubernetes_role: Role for Authentication (for ``kubernetes`` auth_type). + :type kubernetes_role: str + :param kubernetes_jwt_path: Path for kubernetes jwt token (for ``kubernetes`` auth_type, default: + ``/var/run/secrets/kubernetes.io/serviceaccount/token``). + :type kubernetes_jwt_path: str + :param gcp_key_path: Path to Google Cloud Service Account key file (JSON) (for ``gcp`` auth_type). + Mutually exclusive with gcp_keyfile_dict. + :type gcp_key_path: str + :param gcp_keyfile_dict: Dictionary of keyfile parameters. (for ``gcp`` auth_type). + Mutually exclusive with gcp_key_path. 
+    :type gcp_keyfile_dict: dict
+    :param gcp_scopes: Comma-separated string containing OAuth2 scopes (for ``gcp`` auth_type).
+    :type gcp_scopes: str
+    :param azure_tenant_id: The tenant id for the Azure Active Directory (for ``azure`` auth_type).
+    :type azure_tenant_id: str
+    :param azure_resource: The configured URL for the application registered in Azure Active Directory
+        (for ``azure`` auth_type).
+    :type azure_resource: str
+    :param radius_host: Host for radius (for ``radius`` auth_type).
+    :type radius_host: str
+    :param radius_secret: Secret for radius (for ``radius`` auth_type).
+    :type radius_secret: str
+    :param radius_port: Port for radius (for ``radius`` auth_type).
+    :type radius_port: int
+    """
+
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        connections_path: str = "connections",
+        variables_path: str = "variables",
+        config_path: str = "config",
+        url: Optional[str] = None,
+        auth_type: str = "token",
+        auth_mount_point: Optional[str] = None,
+        mount_point: str = "secret",
+        kv_engine_version: int = 2,
+        token: Optional[str] = None,
+        token_path: Optional[str] = None,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        key_id: Optional[str] = None,
+        secret_id: Optional[str] = None,
+        role_id: Optional[str] = None,
+        kubernetes_role: Optional[str] = None,
+        kubernetes_jwt_path: str = "/var/run/secrets/kubernetes.io/serviceaccount/token",
+        gcp_key_path: Optional[str] = None,
+        gcp_keyfile_dict: Optional[dict] = None,
+        gcp_scopes: Optional[str] = None,
+        azure_tenant_id: Optional[str] = None,
+        azure_resource: Optional[str] = None,
+        radius_host: Optional[str] = None,
+        radius_secret: Optional[str] = None,
+        radius_port: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__()
+        if connections_path is not None:
+            self.connections_path = connections_path.rstrip("/")
+        else:
+            self.connections_path = connections_path
+        if variables_path is not None:
+            self.variables_path = variables_path.rstrip("/")
+        else:
+            self.variables_path = variables_path
+        if config_path is not None:
+            self.config_path = config_path.rstrip("/")
+        else:
+            self.config_path = config_path
+        self.mount_point = mount_point
+        self.kv_engine_version = kv_engine_version
+        self.vault_client = _VaultClient(
+            url=url,
+            auth_type=auth_type,
+            auth_mount_point=auth_mount_point,
+            mount_point=mount_point,
+            kv_engine_version=kv_engine_version,
+            token=token,
+            token_path=token_path,
+            username=username,
+            password=password,
+            key_id=key_id,
+            secret_id=secret_id,
+            role_id=role_id,
+            kubernetes_role=kubernetes_role,
+            kubernetes_jwt_path=kubernetes_jwt_path,
+            gcp_key_path=gcp_key_path,
+            gcp_keyfile_dict=gcp_keyfile_dict,
+            gcp_scopes=gcp_scopes,
+            azure_tenant_id=azure_tenant_id,
+            azure_resource=azure_resource,
+            radius_host=radius_host,
+            radius_secret=radius_secret,
+            radius_port=radius_port,
+            **kwargs,
+        )
+
+    def get_conn_uri(self, conn_id: str) -> Optional[str]:
+        """
+        Get secret value from Vault.
Store the secret in the form of URI + + :param conn_id: The connection id + :type conn_id: str + :rtype: str + :return: The connection uri retrieved from the secret + """ + if self.connections_path is None: + return None + else: + secret_path = self.build_path(self.connections_path, conn_id) + response = self.vault_client.get_secret(secret_path=secret_path) + return response.get("conn_uri") if response else None + + def get_variable(self, key: str) -> Optional[str]: + """ + Get Airflow Variable + + :param key: Variable Key + :type key: str + :rtype: str + :return: Variable Value retrieved from the vault + """ + if self.variables_path is None: + return None + else: + secret_path = self.build_path(self.variables_path, key) + response = self.vault_client.get_secret(secret_path=secret_path) + return response.get("value") if response else None + + def get_config(self, key: str) -> Optional[str]: + """ + Get Airflow Configuration + + :param key: Configuration Option Key + :type key: str + :rtype: str + :return: Configuration Option Value retrieved from the vault + """ + if self.config_path is None: + return None + else: + secret_path = self.build_path(self.config_path, key) + response = self.vault_client.get_secret(secret_path=secret_path) + return response.get("value") if response else None diff --git a/reference/providers/http/CHANGELOG.rst b/reference/providers/http/CHANGELOG.rst new file mode 100644 index 0000000..b4a0313 --- /dev/null +++ b/reference/providers/http/CHANGELOG.rst @@ -0,0 +1,44 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.1.1 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + + +1.1.0 +..... + +Updated documentation and readme files. + +Features +~~~~~~~~ + +* ``Add a new argument for HttpSensor to accept a list of http status code`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/http/__init__.py b/reference/providers/http/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/http/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/http/example_dags/__init__.py b/reference/providers/http/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/http/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/http/example_dags/example_http.py b/reference/providers/http/example_dags/example_http.py new file mode 100644 index 0000000..79c4d4a --- /dev/null +++ b/reference/providers/http/example_dags/example_http.py @@ -0,0 +1,119 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
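For context on the ``VaultBackend`` defined above, here is a short sketch of instantiating it directly and resolving a connection and a variable. The URL, token, mount point, and stored keys are assumptions for a local test setup, not values from this commit:

# Illustrative sketch only -- assumes a KV v2 engine mounted at "airflow" that holds
# connections/smtp_default -> {"conn_uri": "..."} and variables/hello -> {"value": "..."}.
from airflow.providers.hashicorp.secrets.vault import VaultBackend

backend = VaultBackend(
    connections_path="connections",
    variables_path="variables",
    mount_point="airflow",
    url="http://127.0.0.1:8200",   # assumed local dev server
    token="dev-only-token",        # hypothetical dev token
)
print(backend.get_conn_uri("smtp_default"))   # reads the "conn_uri" key of the secret
print(backend.get_variable("hello"))          # reads the "value" key of the secret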
+ +"""Example HTTP operator and sensor""" + +import json +from datetime import timedelta + +from airflow import DAG +from airflow.providers.http.operators.http import SimpleHttpOperator +from airflow.providers.http.sensors.http import HttpSensor +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), +} + +dag = DAG( + "example_http_operator", + default_args=default_args, + tags=["example"], + start_date=days_ago(2), +) + +dag.doc_md = __doc__ + +# task_post_op, task_get_op and task_put_op are examples of tasks created by instantiating operators +# [START howto_operator_http_task_post_op] +task_post_op = SimpleHttpOperator( + task_id="post_op", + endpoint="post", + data=json.dumps({"priority": 5}), + headers={"Content-Type": "application/json"}, + response_check=lambda response: response.json()["json"]["priority"] == 5, + dag=dag, +) +# [END howto_operator_http_task_post_op] +# [START howto_operator_http_task_post_op_formenc] +task_post_op_formenc = SimpleHttpOperator( + task_id="post_op_formenc", + endpoint="post", + data="name=Joe", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + dag=dag, +) +# [END howto_operator_http_task_post_op_formenc] +# [START howto_operator_http_task_get_op] +task_get_op = SimpleHttpOperator( + task_id="get_op", + method="GET", + endpoint="get", + data={"param1": "value1", "param2": "value2"}, + headers={}, + dag=dag, +) +# [END howto_operator_http_task_get_op] +# [START howto_operator_http_task_get_op_response_filter] +task_get_op_response_filter = SimpleHttpOperator( + task_id="get_op_response_filter", + method="GET", + endpoint="get", + response_filter=lambda response: response.json()["nested"]["property"], + dag=dag, +) +# [END howto_operator_http_task_get_op_response_filter] +# [START howto_operator_http_task_put_op] +task_put_op = SimpleHttpOperator( + task_id="put_op", + method="PUT", + endpoint="put", + data=json.dumps({"priority": 5}), + headers={"Content-Type": "application/json"}, + dag=dag, +) +# [END howto_operator_http_task_put_op] +# [START howto_operator_http_task_del_op] +task_del_op = SimpleHttpOperator( + task_id="del_op", + method="DELETE", + endpoint="delete", + data="some=data", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + dag=dag, +) +# [END howto_operator_http_task_del_op] +# [START howto_operator_http_http_sensor_check] +task_http_sensor_check = HttpSensor( + task_id="http_sensor_check", + http_conn_id="http_default", + endpoint="", + request_params={}, + response_check=lambda response: "httpbin" in response.text, + poke_interval=5, + dag=dag, +) +# [END howto_operator_http_http_sensor_check] +task_http_sensor_check >> task_post_op >> task_get_op >> task_get_op_response_filter +task_get_op_response_filter >> task_put_op >> task_del_op >> task_post_op_formenc diff --git a/reference/providers/http/hooks/__init__.py b/reference/providers/http/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/http/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/http/hooks/http.py b/reference/providers/http/hooks/http.py new file mode 100644 index 0000000..78a1fd1 --- /dev/null +++ b/reference/providers/http/hooks/http.py @@ -0,0 +1,237 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Callable, Dict, Optional, Union + +import requests +import tenacity +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from requests.auth import HTTPBasicAuth + + +class HttpHook(BaseHook): + """ + Interact with HTTP servers. + + :param method: the API method to be called + :type method: str + :param http_conn_id: connection that has the base API url i.e https://www.google.com/ + and optional authentication credentials. Default headers can also be specified in + the Extra field in json format. 
+ :type http_conn_id: str + :param auth_type: The auth type for the service + :type auth_type: AuthBase of python requests lib + """ + + conn_name_attr = "http_conn_id" + default_conn_name = "http_default" + conn_type = "http" + hook_name = "HTTP" + + def __init__( + self, + method: str = "POST", + http_conn_id: str = default_conn_name, + auth_type: Any = HTTPBasicAuth, + ) -> None: + super().__init__() + self.http_conn_id = http_conn_id + self.method = method.upper() + self.base_url: str = "" + self._retry_obj: Callable[..., Any] + self.auth_type: Any = auth_type + + # headers may be passed through directly or in the "extra" field in the connection + # definition + def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> requests.Session: + """ + Returns http session for use with requests + + :param headers: additional headers to be passed through as a dictionary + :type headers: dict + """ + session = requests.Session() + + if self.http_conn_id: + conn = self.get_connection(self.http_conn_id) + + if conn.host and "://" in conn.host: + self.base_url = conn.host + else: + # schema defaults to HTTP + schema = conn.schema if conn.schema else "http" + host = conn.host if conn.host else "" + self.base_url = schema + "://" + host + + if conn.port: + self.base_url = self.base_url + ":" + str(conn.port) + if conn.login: + session.auth = self.auth_type(conn.login, conn.password) + if conn.extra: + try: + session.headers.update(conn.extra_dejson) + except TypeError: + self.log.warning( + "Connection to %s has invalid extra field.", conn.host + ) + if headers: + session.headers.update(headers) + + return session + + def run( + self, + endpoint: Optional[str], + data: Optional[Union[Dict[str, Any], str]] = None, + headers: Optional[Dict[str, Any]] = None, + extra_options: Optional[Dict[str, Any]] = None, + **request_kwargs: Any, + ) -> Any: + r""" + Performs the request + + :param endpoint: the endpoint to be called i.e. resource/v1/query? + :type endpoint: str + :param data: payload to be uploaded or request parameters + :type data: dict + :param headers: additional headers to be passed through as a dictionary + :type headers: dict + :param extra_options: additional options to be used when executing the request + i.e. {'check_response': False} to avoid checking raising exceptions on non + 2XX or 3XX status codes + :type extra_options: dict + :param request_kwargs: Additional kwargs to pass when creating a request. 
+ For example, ``run(json=obj)`` is passed as ``requests.Request(json=obj)`` + """ + extra_options = extra_options or {} + + session = self.get_conn(headers) + + if ( + self.base_url + and not self.base_url.endswith("/") + and endpoint + and not endpoint.startswith("/") + ): + url = self.base_url + "/" + endpoint + else: + url = (self.base_url or "") + (endpoint or "") + + if self.method == "GET": + # GET uses params + req = requests.Request( + self.method, url, params=data, headers=headers, **request_kwargs + ) + elif self.method == "HEAD": + # HEAD doesn't use params + req = requests.Request(self.method, url, headers=headers, **request_kwargs) + else: + # Others use data + req = requests.Request( + self.method, url, data=data, headers=headers, **request_kwargs + ) + + prepped_request = session.prepare_request(req) + self.log.info("Sending '%s' to url: %s", self.method, url) + return self.run_and_check(session, prepped_request, extra_options) + + def check_response(self, response: requests.Response) -> None: + """ + Checks the status code and raise an AirflowException exception on non 2XX or 3XX + status codes + + :param response: A requests response object + :type response: requests.response + """ + try: + response.raise_for_status() + except requests.exceptions.HTTPError: + self.log.error("HTTP error: %s", response.reason) + self.log.error(response.text) + raise AirflowException(str(response.status_code) + ":" + response.reason) + + def run_and_check( + self, + session: requests.Session, + prepped_request: requests.PreparedRequest, + extra_options: Dict[Any, Any], + ) -> Any: + """ + Grabs extra options like timeout and actually runs the request, + checking for the result + + :param session: the session to be used to execute the request + :type session: requests.Session + :param prepped_request: the prepared request generated in run() + :type prepped_request: session.prepare_request + :param extra_options: additional options to be used when executing the request + i.e. {'check_response': False} to avoid checking raising exceptions on non 2XX + or 3XX status codes + :type extra_options: dict + """ + extra_options = extra_options or {} + + try: + response = session.send( + prepped_request, + stream=extra_options.get("stream", False), + verify=extra_options.get("verify", True), + proxies=extra_options.get("proxies", {}), + cert=extra_options.get("cert"), + timeout=extra_options.get("timeout"), + allow_redirects=extra_options.get("allow_redirects", True), + ) + + if extra_options.get("check_response", True): + self.check_response(response) + return response + + except requests.exceptions.ConnectionError as ex: + self.log.warning("%s Tenacity will retry to execute the operation", ex) + raise ex + + def run_with_advanced_retry( + self, _retry_args: Dict[Any, Any], *args: Any, **kwargs: Any + ) -> Any: + """ + Runs Hook.run() with a Tenacity decorator attached to it. This is useful for + connectors which might be disturbed by intermittent issues and should not + instantly fail. + + :param _retry_args: Arguments which define the retry behaviour. + See Tenacity documentation at https://github.com/jd/tenacity + :type _retry_args: dict + + + .. 
code-block:: python + + hook = HttpHook(http_conn_id='my_conn',method='GET') + retry_args = dict( + wait=tenacity.wait_exponential(), + stop=tenacity.stop_after_attempt(10), + retry=requests.exceptions.ConnectionError + ) + hook.run_with_advanced_retry( + endpoint='v1/test', + _retry_args=retry_args + ) + + """ + self._retry_obj = tenacity.Retrying(**_retry_args) + + return self._retry_obj(self.run, *args, **kwargs) diff --git a/reference/providers/http/operators/__init__.py b/reference/providers/http/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/http/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/http/operators/http.py b/reference/providers/http/operators/http.py new file mode 100644 index 0000000..b5a1158 --- /dev/null +++ b/reference/providers/http/operators/http.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Callable, Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.http.hooks.http import HttpHook +from airflow.utils.decorators import apply_defaults + + +class SimpleHttpOperator(BaseOperator): + """ + Calls an endpoint on an HTTP system to execute an action + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SimpleHttpOperator` + + :param http_conn_id: The connection to run the operator against + :type http_conn_id: str + :param endpoint: The relative part of the full url. (templated) + :type endpoint: str + :param method: The HTTP method to use, default = "POST" + :type method: str + :param data: The data to pass. POST-data in POST/PUT and params + in the URL for a GET request. 
(templated) + :type data: For POST/PUT, depends on the content-type parameter, + for GET a dictionary of key/value string pairs + :param headers: The HTTP headers to be added to the GET request + :type headers: a dictionary of string key/value pairs + :param response_check: A check against the 'requests' response object. + The callable takes the response object as the first positional argument + and optionally any number of keyword arguments available in the context dictionary. + It should return True for 'pass' and False otherwise. + :type response_check: A lambda or defined function. + :param response_filter: A function allowing you to manipulate the response + text. e.g response_filter=lambda response: json.loads(response.text). + The callable takes the response object as the first positional argument + and optionally any number of keyword arguments available in the context dictionary. + :type response_filter: A lambda or defined function. + :param extra_options: Extra options for the 'requests' library, see the + 'requests' documentation (options to modify timeout, ssl, etc.) + :type extra_options: A dictionary of options, where key is string and value + depends on the option that's being modified. + :param log_response: Log the response (default: False) + :type log_response: bool + """ + + template_fields = [ + "endpoint", + "data", + "headers", + ] + template_fields_renderers = {"headers": "json", "data": "py"} + template_ext = () + ui_color = "#f4a460" + + @apply_defaults + def __init__( + self, + *, + endpoint: Optional[str] = None, + method: str = "POST", + data: Any = None, + headers: Optional[Dict[str, str]] = None, + response_check: Optional[Callable[..., bool]] = None, + response_filter: Optional[Callable[..., Any]] = None, + extra_options: Optional[Dict[str, Any]] = None, + http_conn_id: str = "http_default", + log_response: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.http_conn_id = http_conn_id + self.method = method + self.endpoint = endpoint + self.headers = headers or {} + self.data = data or {} + self.response_check = response_check + self.response_filter = response_filter + self.extra_options = extra_options or {} + self.log_response = log_response + if kwargs.get("xcom_push") is not None: + raise AirflowException( + "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead" + ) + + def execute(self, context: Dict[str, Any]) -> Any: + from airflow.utils.operator_helpers import make_kwargs_callable + + http = HttpHook(self.method, http_conn_id=self.http_conn_id) + + self.log.info("Calling HTTP method") + + response = http.run(self.endpoint, self.data, self.headers, self.extra_options) + if self.log_response: + self.log.info(response.text) + if self.response_check: + kwargs_callable = make_kwargs_callable(self.response_check) + if not kwargs_callable(response, **context): + raise AirflowException("Response check returned False.") + if self.response_filter: + kwargs_callable = make_kwargs_callable(self.response_filter) + return kwargs_callable(response, **context) + return response.text diff --git a/reference/providers/http/provider.yaml b/reference/providers/http/provider.yaml new file mode 100644 index 0000000..79a095b --- /dev/null +++ b/reference/providers/http/provider.yaml @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-http +name: Hypertext Transfer Protocol (HTTP) +description: | + `Hypertext Transfer Protocol (HTTP) `__ + +versions: + - 1.1.1 + - 1.1.0 + - 1.0.0 + +integrations: + - integration-name: Hypertext Transfer Protocol (HTTP) + external-doc-url: https://www.w3.org/Protocols/ + how-to-guide: + - /docs/apache-airflow-providers-http/operators.rst + logo: /integration-logos/http/HTTP.png + tags: [protocol] + +operators: + - integration-name: Hypertext Transfer Protocol (HTTP) + python-modules: + - airflow.providers.http.operators.http +sensors: + - integration-name: Hypertext Transfer Protocol (HTTP) + python-modules: + - airflow.providers.http.sensors.http + +hooks: + - integration-name: Hypertext Transfer Protocol (HTTP) + python-modules: + - airflow.providers.http.hooks.http + +hook-class-names: + - airflow.providers.http.hooks.http.HttpHook diff --git a/reference/providers/http/sensors/__init__.py b/reference/providers/http/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/http/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/http/sensors/http.py b/reference/providers/http/sensors/http.py new file mode 100644 index 0000000..0e8b6fd --- /dev/null +++ b/reference/providers/http/sensors/http.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
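Editorial note: a minimal, hedged usage sketch of the HTTP pieces above -- the SimpleHttpOperator inside a DAG and HttpHook.run_with_advanced_retry with a tenacity predicate. The connection id, endpoint, payload and dag_id are hypothetical placeholders, and the retry argument is spelled with tenacity.retry_if_exception_type rather than the bare exception class shown in the hook's docstring example:

    # Hedged usage sketch -- conn id, endpoint, payload and dag_id are hypothetical.
    import requests
    import tenacity

    from airflow import DAG
    from airflow.providers.http.hooks.http import HttpHook
    from airflow.providers.http.operators.http import SimpleHttpOperator
    from airflow.utils.dates import days_ago


    def call_with_retries() -> str:
        # Hook form: retry transient connection errors before giving up.
        hook = HttpHook(method="GET", http_conn_id="http_default")
        retry_args = dict(
            wait=tenacity.wait_exponential(),
            stop=tenacity.stop_after_attempt(10),
            retry=tenacity.retry_if_exception_type(requests.exceptions.ConnectionError),
        )
        return hook.run_with_advanced_retry(endpoint="v1/test", _retry_args=retry_args).text


    with DAG(dag_id="example_http_usage", start_date=days_ago(1), schedule_interval=None) as dag:
        # Operator form: POST a JSON payload and push the response text to XCom.
        post_op = SimpleHttpOperator(
            task_id="post_op",
            http_conn_id="http_default",
            endpoint="v1/test",
            method="POST",
            data='{"key": "value"}',
            headers={"Content-Type": "application/json"},
            response_check=lambda response: response.status_code == 200,
            log_response=True,
        )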
+from typing import Any, Callable, Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.http.hooks.http import HttpHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class HttpSensor(BaseSensorOperator): + """ + Executes a HTTP GET statement and returns False on failure caused by + 404 Not Found or `response_check` returning False. + + HTTP Error codes other than 404 (like 403) or Connection Refused Error + would raise an exception and fail the sensor itself directly (no more poking). + To avoid failing the task for other codes than 404, the argument ``extra_option`` + can be passed with the value ``{'check_response': False}``. It will make the ``response_check`` + be execute for any http status code. + + The response check can access the template context to the operator: + + def response_check(response, task_instance): + # The task_instance is injected, so you can pull data form xcom + # Other context variables such as dag, ds, execution_date are also available. + xcom_data = task_instance.xcom_pull(task_ids='pushing_task') + # In practice you would do something more sensible with this data.. + print(xcom_data) + return True + + HttpSensor(task_id='my_http_sensor', ..., response_check=response_check) + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:HttpSensor` + + :param http_conn_id: The connection to run the sensor against + :type http_conn_id: str + :param method: The HTTP request method to use + :type method: str + :param endpoint: The relative part of the full url + :type endpoint: str + :param request_params: The parameters to be added to the GET url + :type request_params: a dictionary of string key/value pairs + :param headers: The HTTP headers to be added to the GET request + :type headers: a dictionary of string key/value pairs + :param response_check: A check against the 'requests' response object. + The callable takes the response object as the first positional argument + and optionally any number of keyword arguments available in the context dictionary. + It should return True for 'pass' and False otherwise. + :type response_check: A lambda or defined function. + :param extra_options: Extra options for the 'requests' library, see the + 'requests' documentation (options to modify timeout, ssl, etc.) + :type extra_options: A dictionary of options, where key is string and value + depends on the option that's being modified. 
+ """ + + template_fields = ("endpoint", "request_params", "headers") + + @apply_defaults + def __init__( + self, + *, + endpoint: str, + http_conn_id: str = "http_default", + method: str = "GET", + request_params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + response_check: Optional[Callable[..., bool]] = None, + extra_options: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.endpoint = endpoint + self.http_conn_id = http_conn_id + self.request_params = request_params or {} + self.headers = headers or {} + self.extra_options = extra_options or {} + self.response_check = response_check + + self.hook = HttpHook(method=method, http_conn_id=http_conn_id) + + def poke(self, context: Dict[Any, Any]) -> bool: + from airflow.utils.operator_helpers import make_kwargs_callable + + self.log.info("Poking: %s", self.endpoint) + try: + response = self.hook.run( + self.endpoint, + data=self.request_params, + headers=self.headers, + extra_options=self.extra_options, + ) + if self.response_check: + kwargs_callable = make_kwargs_callable(self.response_check) + return kwargs_callable(response, **context) + + except AirflowException as exc: + if str(exc).startswith("404"): + return False + + raise exc + + return True diff --git a/reference/providers/imap/CHANGELOG.rst b/reference/providers/imap/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/imap/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/imap/__init__.py b/reference/providers/imap/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/imap/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/reference/providers/imap/hooks/__init__.py b/reference/providers/imap/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/imap/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/imap/hooks/imap.py b/reference/providers/imap/hooks/imap.py new file mode 100644 index 0000000..b784737 --- /dev/null +++ b/reference/providers/imap/hooks/imap.py @@ -0,0 +1,406 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module provides everything to be able to search in mails for a specific attachment +and also to download it. +It uses the imaplib library that is already integrated in python 2 and 3. +""" +import email +import imaplib +import os +import re +from typing import Any, Iterable, List, Optional, Tuple + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.utils.log.logging_mixin import LoggingMixin + + +class ImapHook(BaseHook): + """ + This hook connects to a mail server by using the imap protocol. + + .. note:: Please call this Hook as context manager via `with` + to automatically open and close the connection to the mail server. + + :param imap_conn_id: The connection id that contains the information used to authenticate the client. + :type imap_conn_id: str + """ + + conn_name_attr = "imap_conn_id" + default_conn_name = "imap_default" + conn_type = "imap" + hook_name = "IMAP" + + def __init__(self, imap_conn_id: str = default_conn_name) -> None: + super().__init__() + self.imap_conn_id = imap_conn_id + self.mail_client: Optional[imaplib.IMAP4_SSL] = None + + def __enter__(self) -> "ImapHook": + return self.get_conn() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.mail_client.logout() + + def get_conn(self) -> "ImapHook": + """ + Login to the mail server. + + .. note:: Please call this Hook as context manager via `with` + to automatically open and close the connection to the mail server. 
+ + :return: an authorized ImapHook object. + :rtype: ImapHook + """ + if not self.mail_client: + conn = self.get_connection(self.imap_conn_id) + self.mail_client = imaplib.IMAP4_SSL(conn.host) + self.mail_client.login(conn.login, conn.password) + + return self + + def has_mail_attachment( + self, + name: str, + *, + check_regex: bool = False, + mail_folder: str = "INBOX", + mail_filter: str = "All", + ) -> bool: + """ + Checks the mail folder for mails containing attachments with the given name. + + :param name: The name of the attachment that will be searched for. + :type name: str + :param check_regex: Checks the name for a regular expression. + :type check_regex: bool + :param mail_folder: The mail folder where to look at. + :type mail_folder: str + :param mail_filter: If set other than 'All' only specific mails will be checked. + See :py:meth:`imaplib.IMAP4.search` for details. + :type mail_filter: str + :returns: True if there is an attachment with the given name and False if not. + :rtype: bool + """ + mail_attachments = self._retrieve_mails_attachments_by_name( + name, check_regex, True, mail_folder, mail_filter + ) + return len(mail_attachments) > 0 + + def retrieve_mail_attachments( + self, + name: str, + *, + check_regex: bool = False, + latest_only: bool = False, + mail_folder: str = "INBOX", + mail_filter: str = "All", + not_found_mode: str = "raise", + ) -> List[Tuple]: + """ + Retrieves mail's attachments in the mail folder by its name. + + :param name: The name of the attachment that will be downloaded. + :type name: str + :param check_regex: Checks the name for a regular expression. + :type check_regex: bool + :param latest_only: If set to True it will only retrieve the first matched attachment. + :type latest_only: bool + :param mail_folder: The mail folder where to look at. + :type mail_folder: str + :param mail_filter: If set other than 'All' only specific mails will be checked. + See :py:meth:`imaplib.IMAP4.search` for details. + :type mail_filter: str + :param not_found_mode: Specify what should happen if no attachment has been found. + Supported values are 'raise', 'warn' and 'ignore'. + If it is set to 'raise' it will raise an exception, + if set to 'warn' it will only print a warning and + if set to 'ignore' it won't notify you at all. + :type not_found_mode: str + :returns: a list of tuple each containing the attachment filename and its payload. + :rtype: a list of tuple + """ + mail_attachments = self._retrieve_mails_attachments_by_name( + name, check_regex, latest_only, mail_folder, mail_filter + ) + + if not mail_attachments: + self._handle_not_found_mode(not_found_mode) + + return mail_attachments + + def download_mail_attachments( + self, + name: str, + local_output_directory: str, + *, + check_regex: bool = False, + latest_only: bool = False, + mail_folder: str = "INBOX", + mail_filter: str = "All", + not_found_mode: str = "raise", + ) -> None: + """ + Downloads mail's attachments in the mail folder by its name to the local directory. + + :param name: The name of the attachment that will be downloaded. + :type name: str + :param local_output_directory: The output directory on the local machine + where the files will be downloaded to. + :type local_output_directory: str + :param check_regex: Checks the name for a regular expression. + :type check_regex: bool + :param latest_only: If set to True it will only download the first matched attachment. + :type latest_only: bool + :param mail_folder: The mail folder where to look at. 
+ :type mail_folder: str + :param mail_filter: If set other than 'All' only specific mails will be checked. + See :py:meth:`imaplib.IMAP4.search` for details. + :type mail_filter: str + :param not_found_mode: Specify what should happen if no attachment has been found. + Supported values are 'raise', 'warn' and 'ignore'. + If it is set to 'raise' it will raise an exception, + if set to 'warn' it will only print a warning and + if set to 'ignore' it won't notify you at all. + :type not_found_mode: str + """ + mail_attachments = self._retrieve_mails_attachments_by_name( + name, check_regex, latest_only, mail_folder, mail_filter + ) + + if not mail_attachments: + self._handle_not_found_mode(not_found_mode) + + self._create_files(mail_attachments, local_output_directory) + + def _handle_not_found_mode(self, not_found_mode: str) -> None: + if not_found_mode == "raise": + raise AirflowException("No mail attachments found!") + if not_found_mode == "warn": + self.log.warning("No mail attachments found!") + elif not_found_mode == "ignore": + pass # Do not notify if the attachment has not been found. + else: + self.log.error('Invalid "not_found_mode" %s', not_found_mode) + + def _retrieve_mails_attachments_by_name( + self, + name: str, + check_regex: bool, + latest_only: bool, + mail_folder: str, + mail_filter: str, + ) -> List: + if not self.mail_client: + raise Exception("The 'mail_client' should be initialized before!") + + all_matching_attachments = [] + + self.mail_client.select(mail_folder) + + for mail_id in self._list_mail_ids_desc(mail_filter): + response_mail_body = self._fetch_mail_body(mail_id) + matching_attachments = self._check_mail_body( + response_mail_body, name, check_regex, latest_only + ) + + if matching_attachments: + all_matching_attachments.extend(matching_attachments) + if latest_only: + break + + self.mail_client.close() + + return all_matching_attachments + + def _list_mail_ids_desc(self, mail_filter: str) -> Iterable[str]: + if not self.mail_client: + raise Exception("The 'mail_client' should be initialized before!") + _, data = self.mail_client.search(None, mail_filter) + mail_ids = data[0].split() + return reversed(mail_ids) + + def _fetch_mail_body(self, mail_id: str) -> str: + if not self.mail_client: + raise Exception("The 'mail_client' should be initialized before!") + _, data = self.mail_client.fetch(mail_id, "(RFC822)") + mail_body = data[0][1] # type: ignore # The mail body is always in this specific location + mail_body_str = mail_body.decode("utf-8") # type: ignore + return mail_body_str + + def _check_mail_body( + self, response_mail_body: str, name: str, check_regex: bool, latest_only: bool + ) -> List[Tuple[Any, Any]]: + mail = Mail(response_mail_body) + if mail.has_attachments(): + return mail.get_attachments_by_name( + name, check_regex, find_first=latest_only + ) + return [] + + def _create_files( + self, mail_attachments: List, local_output_directory: str + ) -> None: + for name, payload in mail_attachments: + if self._is_symlink(name): + self.log.error("Can not create file because it is a symlink!") + elif self._is_escaping_current_directory(name): + self.log.error( + "Can not create file because it is escaping the current directory!" 
+ ) + else: + self._create_file(name, payload, local_output_directory) + + def _is_symlink(self, name: str) -> bool: + # IMPORTANT NOTE: os.path.islink is not working for windows symlinks + # See: https://stackoverflow.com/a/11068434 + return os.path.islink(name) + + def _is_escaping_current_directory(self, name: str) -> bool: + return "../" in name + + def _correct_path(self, name: str, local_output_directory: str) -> str: + return ( + local_output_directory + name + if local_output_directory.endswith("/") + else local_output_directory + "/" + name + ) + + def _create_file( + self, name: str, payload: Any, local_output_directory: str + ) -> None: + file_path = self._correct_path(name, local_output_directory) + + with open(file_path, "wb") as file: + file.write(payload) + + +class Mail(LoggingMixin): + """ + This class simplifies working with mails returned by the imaplib client. + + :param mail_body: The mail body of a mail received from imaplib client. + :type mail_body: str + """ + + def __init__(self, mail_body: str) -> None: + super().__init__() + self.mail = email.message_from_string(mail_body) + + def has_attachments(self) -> bool: + """ + Checks the mail for a attachments. + + :returns: True if it has attachments and False if not. + :rtype: bool + """ + return self.mail.get_content_maintype() == "multipart" + + def get_attachments_by_name( + self, name: str, check_regex: bool, find_first: bool = False + ) -> List[Tuple[Any, Any]]: + """ + Gets all attachments by name for the mail. + + :param name: The name of the attachment to look for. + :type name: str + :param check_regex: Checks the name for a regular expression. + :type check_regex: bool + :param find_first: If set to True it will only find the first match and then quit. + :type find_first: bool + :returns: a list of tuples each containing name and payload + where the attachments name matches the given name. + :rtype: list(tuple) + """ + attachments = [] + + for attachment in self._iterate_attachments(): + found_attachment = ( + attachment.has_matching_name(name) + if check_regex + else attachment.has_equal_name(name) + ) + if found_attachment: + file_name, file_payload = attachment.get_file() + self.log.info("Found attachment: %s", file_name) + attachments.append((file_name, file_payload)) + if find_first: + break + + return attachments + + def _iterate_attachments(self) -> Iterable["MailPart"]: + for part in self.mail.walk(): + mail_part = MailPart(part) + if mail_part.is_attachment(): + yield mail_part + + +class MailPart: + """ + This class is a wrapper for a Mail object's part and gives it more features. + + :param part: The mail part in a Mail object. + :type part: any + """ + + def __init__(self, part: Any) -> None: + self.part = part + + def is_attachment(self) -> bool: + """ + Checks if the part is a valid mail attachment. + + :returns: True if it is an attachment and False if not. + :rtype: bool + """ + return self.part.get_content_maintype() != "multipart" and self.part.get( + "Content-Disposition" + ) + + def has_matching_name(self, name: str) -> Optional[Tuple[Any, Any]]: + """ + Checks if the given name matches the part's name. + + :param name: The name to look for. + :type name: str + :returns: True if it matches the name (including regular expression). + :rtype: tuple + """ + return re.match(name, self.part.get_filename()) # type: ignore + + def has_equal_name(self, name: str) -> bool: + """ + Checks if the given name is equal to the part's name. + + :param name: The name to look for. 
+ :type name: str + :returns: True if it is equal to the given name. + :rtype: bool + """ + return self.part.get_filename() == name + + def get_file(self) -> Tuple: + """ + Gets the file including name and payload. + + :returns: the part's name and payload. + :rtype: tuple + """ + return self.part.get_filename(), self.part.get_payload(decode=True) diff --git a/reference/providers/imap/provider.yaml b/reference/providers/imap/provider.yaml new file mode 100644 index 0000000..21cd1b8 --- /dev/null +++ b/reference/providers/imap/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-imap +name: Internet Message Access Protocol (IMAP) +description: | + `Internet Message Access Protocol (IMAP) `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Internet Message Access Protocol (IMAP) + external-doc-url: https://tools.ietf.org/html/rfc3501 + logo: /integration-logos/imap/IMAP.png + tags: [protocol] + +sensors: + - integration-name: Internet Message Access Protocol (IMAP) + python-modules: + - airflow.providers.imap.sensors.imap_attachment + +hooks: + - integration-name: Internet Message Access Protocol (IMAP) + python-modules: + - airflow.providers.imap.hooks.imap + +hook-class-names: + - airflow.providers.imap.hooks.imap.ImapHook diff --git a/reference/providers/imap/sensors/__init__.py b/reference/providers/imap/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/imap/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/imap/sensors/imap_attachment.py b/reference/providers/imap/sensors/imap_attachment.py new file mode 100644 index 0000000..8830f34 --- /dev/null +++ b/reference/providers/imap/sensors/imap_attachment.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
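Editorial note: a hedged usage sketch of the ImapHook above, used as a context manager so login and logout are handled automatically; the connection id, attachment pattern and output directory are hypothetical placeholders:

    # Hedged usage sketch -- conn id, attachment name and output path are hypothetical.
    from airflow.providers.imap.hooks.imap import ImapHook

    with ImapHook(imap_conn_id="imap_default") as imap_hook:
        if imap_hook.has_mail_attachment(name=r"report_.*\.csv", check_regex=True):
            imap_hook.download_mail_attachments(
                name=r"report_.*\.csv",
                local_output_directory="/tmp/attachments",
                check_regex=True,
                latest_only=True,
                mail_folder="INBOX",
                mail_filter="All",
                not_found_mode="warn",
            )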
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module allows you to poke for attachments on a mail server.""" +from airflow.providers.imap.hooks.imap import ImapHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class ImapAttachmentSensor(BaseSensorOperator): + """ + Waits for a specific attachment on a mail server. + + :param attachment_name: The name of the attachment that will be checked. + :type attachment_name: str + :param check_regex: If set to True the attachment's name will be parsed as regular expression. + Through this you can get a broader set of attachments + that it will look for than just only the equality of the attachment name. + :type check_regex: bool + :param mail_folder: The mail folder in where to search for the attachment. + :type mail_folder: str + :param mail_filter: If set other than 'All' only specific mails will be checked. + See :py:meth:`imaplib.IMAP4.search` for details. + :type mail_filter: str + :param conn_id: The connection to run the sensor against. + :type conn_id: str + """ + + template_fields = ("attachment_name", "mail_filter") + + @apply_defaults + def __init__( + self, + *, + attachment_name, + check_regex=False, + mail_folder="INBOX", + mail_filter="All", + conn_id="imap_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.attachment_name = attachment_name + self.check_regex = check_regex + self.mail_folder = mail_folder + self.mail_filter = mail_filter + self.conn_id = conn_id + + def poke(self, context: dict) -> bool: + """ + Pokes for a mail attachment on the mail server. + + :param context: The context that is being provided when poking. + :type context: dict + :return: True if attachment with the given name is present and False if not. + :rtype: bool + """ + self.log.info("Poking for %s", self.attachment_name) + + with ImapHook(imap_conn_id=self.conn_id) as imap_hook: + return imap_hook.has_mail_attachment( + name=self.attachment_name, + check_regex=self.check_regex, + mail_folder=self.mail_folder, + mail_filter=self.mail_filter, + ) diff --git a/reference/providers/jdbc/CHANGELOG.rst b/reference/providers/jdbc/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/jdbc/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/jdbc/__init__.py b/reference/providers/jdbc/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/jdbc/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/jdbc/example_dags/__init__.py b/reference/providers/jdbc/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/jdbc/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/jdbc/example_dags/example_jdbc_queries.py b/reference/providers/jdbc/example_dags/example_jdbc_queries.py new file mode 100644 index 0000000..2f284ab --- /dev/null +++ b/reference/providers/jdbc/example_dags/example_jdbc_queries.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Example DAG demonstrating the usage of the JdbcOperator.""" + +from datetime import timedelta + +from airflow import DAG +from airflow.operators.dummy import DummyOperator +from airflow.providers.jdbc.operators.jdbc import JdbcOperator +from airflow.utils.dates import days_ago + +args = { + "owner": "airflow", +} + +with DAG( + dag_id="example_jdbc_operator", + default_args=args, + schedule_interval="0 0 * * *", + start_date=days_ago(2), + dagrun_timeout=timedelta(minutes=60), + tags=["example"], +) as dag: + + run_this_last = DummyOperator( + task_id="run_this_last", + dag=dag, + ) + + # [START howto_operator_jdbc_template] + delete_data = JdbcOperator( + task_id="delete_data", + sql="delete from my_schema.my_table where dt = {{ ds }}", + jdbc_conn_id="my_jdbc_connection", + autocommit=True, + dag=dag, + ) + # [END howto_operator_jdbc_template] + + # [START howto_operator_jdbc] + insert_data = JdbcOperator( + task_id="insert_data", + sql="insert into my_schema.my_table select dt, value from my_schema.source_data", + jdbc_conn_id="my_jdbc_connection", + autocommit=True, + dag=dag, + ) + # [END howto_operator_jdbc] + + delete_data >> insert_data >> run_this_last diff --git a/reference/providers/jdbc/hooks/__init__.py b/reference/providers/jdbc/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/jdbc/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/jdbc/hooks/jdbc.py b/reference/providers/jdbc/hooks/jdbc.py new file mode 100644 index 0000000..7c513e7 --- /dev/null +++ b/reference/providers/jdbc/hooks/jdbc.py @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Optional + +import jaydebeapi +from airflow.hooks.dbapi import DbApiHook +from airflow.models.connection import Connection + + +class JdbcHook(DbApiHook): + """ + General hook for jdbc db access. 
+ + JDBC URL, username and password will be taken from the predefined connection. + Note that the whole JDBC URL must be specified in the "host" field in the DB. + Raises an airflow error if the given connection id doesn't exist. + """ + + conn_name_attr = "jdbc_conn_id" + default_conn_name = "jdbc_default" + conn_type = "jdbc" + hook_name = "JDBC Connection" + supports_autocommit = True + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "extra__jdbc__drv_path": StringField( + lazy_gettext("Driver Path"), widget=BS3TextFieldWidget() + ), + "extra__jdbc__drv_clsname": StringField( + lazy_gettext("Driver Class"), widget=BS3TextFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["port", "schema", "extra"], + "relabeling": {"host": "Connection URL"}, + } + + def get_conn(self) -> jaydebeapi.Connection: + conn: Connection = self.get_connection(getattr(self, self.conn_name_attr)) + host: str = conn.host + login: str = conn.login + psw: str = conn.password + jdbc_driver_loc: Optional[str] = conn.extra_dejson.get("extra__jdbc__drv_path") + jdbc_driver_name: Optional[str] = conn.extra_dejson.get( + "extra__jdbc__drv_clsname" + ) + + conn = jaydebeapi.connect( + jclassname=jdbc_driver_name, + url=str(host), + driver_args=[str(login), str(psw)], + jars=jdbc_driver_loc.split(",") if jdbc_driver_loc else None, + ) + return conn + + def set_autocommit(self, conn: jaydebeapi.Connection, autocommit: bool) -> None: + """ + Enable or disable autocommit for the given connection. + + :param conn: The connection. + :type conn: connection object + :param autocommit: The connection's autocommit setting. + :type autocommit: bool + """ + conn.jconn.setAutoCommit(autocommit) + + def get_autocommit(self, conn: jaydebeapi.Connection) -> bool: + """ + Get autocommit setting for the provided connection. + Return True if conn.autocommit is set to True. + Return False if conn.autocommit is not set or set to False + + :param conn: The connection. + :type conn: connection object + :return: connection autocommit setting. + :rtype: bool + """ + return conn.jconn.getAutoCommit() diff --git a/reference/providers/jdbc/operators/__init__.py b/reference/providers/jdbc/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/jdbc/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
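Editorial note: a hedged sketch of how the JdbcHook above is typically wired up; the connection values, driver jar path, driver class and SQL are hypothetical placeholders. As the docstring notes, the full JDBC URL goes in the connection's "host" field:

    # Hedged usage sketch -- connection values, jar path, class name and SQL are hypothetical.
    from airflow.providers.jdbc.hooks.jdbc import JdbcHook

    # Expected connection shape for jdbc_conn_id="jdbc_default":
    #   host:     jdbc:postgresql://db.example.com:5432/mydb      (full JDBC URL)
    #   login:    my_user
    #   password: my_password
    #   extra:    {"extra__jdbc__drv_path": "/opt/jars/postgresql.jar",
    #              "extra__jdbc__drv_clsname": "org.postgresql.Driver"}

    hook = JdbcHook(jdbc_conn_id="jdbc_default")
    rows = hook.get_records("SELECT 1")  # inherited DbApiHook helper, runs via jaydebeapi
    print(rows)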
diff --git a/reference/providers/jdbc/operators/jdbc.py b/reference/providers/jdbc/operators/jdbc.py new file mode 100644 index 0000000..dd81b31 --- /dev/null +++ b/reference/providers/jdbc/operators/jdbc.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Iterable, Mapping, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.jdbc.hooks.jdbc import JdbcHook +from airflow.utils.decorators import apply_defaults + + +class JdbcOperator(BaseOperator): + """ + Executes sql code in a database using jdbc driver. + + Requires jaydebeapi. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:JdbcOperator` + + :param sql: the sql code to be executed. (templated) + :type sql: Can receive a str representing a sql statement, + a list of str (sql statements), or reference to a template file. + Template reference are recognized by str ending in '.sql' + :param jdbc_conn_id: reference to a predefined database + :type jdbc_conn_id: str + :param autocommit: if True, each command is automatically committed. + (default value: False) + :type autocommit: bool + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: dict or iterable + """ + + template_fields = ("sql",) + template_ext = (".sql",) + ui_color = "#ededed" + + @apply_defaults + def __init__( + self, + *, + sql: str, + jdbc_conn_id: str = "jdbc_default", + autocommit: bool = False, + parameters: Optional[Union[Mapping, Iterable]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.parameters = parameters + self.sql = sql + self.jdbc_conn_id = jdbc_conn_id + self.autocommit = autocommit + self.hook = None + + def execute(self, context) -> None: + self.log.info("Executing: %s", self.sql) + hook = JdbcHook(jdbc_conn_id=self.jdbc_conn_id) + hook.run(self.sql, self.autocommit, parameters=self.parameters) diff --git a/reference/providers/jdbc/provider.yaml b/reference/providers/jdbc/provider.yaml new file mode 100644 index 0000000..8fc8807 --- /dev/null +++ b/reference/providers/jdbc/provider.yaml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-jdbc +name: Java Database Connectivity (JDBC) +description: | + `Java Database Connectivity (JDBC) `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Java Database Connectivity (JDBC) + external-doc-url: https://docs.oracle.com/javase/8/docs/technotes/guides/jdbc/ + how-to-guide: + - /docs/apache-airflow-providers-jdbc/operators.rst + logo: /integration-logos/jdbc/JDBC.png + tags: [protocol] + +operators: + - integration-name: Java Database Connectivity (JDBC) + python-modules: + - airflow.providers.jdbc.operators.jdbc +hooks: + - integration-name: Java Database Connectivity (JDBC) + python-modules: + - airflow.providers.jdbc.hooks.jdbc + +hook-class-names: + - airflow.providers.jdbc.hooks.jdbc.JdbcHook diff --git a/reference/providers/jenkins/CHANGELOG.rst b/reference/providers/jenkins/CHANGELOG.rst new file mode 100644 index 0000000..26dc24b --- /dev/null +++ b/reference/providers/jenkins/CHANGELOG.rst @@ -0,0 +1,38 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.1.0 +..... + +Features +~~~~~~~~ + +* ``Add allowed_jenkins_states to JenkinsJobTriggerOperator (#14131)`` + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/jenkins/__init__.py b/reference/providers/jenkins/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/jenkins/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
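Editorial note: because the JdbcOperator above lists ".sql" in template_ext, its sql argument may also reference a templated file instead of an inline statement; a hedged sketch with a hypothetical file path and connection id:

    # Hedged sketch -- file path and conn id are hypothetical; the file is Jinja-rendered
    # because its name ends in '.sql' (see template_ext on JdbcOperator).
    from airflow.providers.jdbc.operators.jdbc import JdbcOperator

    cleanup = JdbcOperator(
        task_id="cleanup",
        jdbc_conn_id="my_jdbc_connection",
        sql="queries/cleanup.sql",  # resolved against the DAG folder / template_searchpath
        autocommit=True,
    )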
diff --git a/reference/providers/jenkins/example_dags/__init__.py b/reference/providers/jenkins/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/jenkins/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/jenkins/example_dags/example_jenkins_job_trigger.py b/reference/providers/jenkins/example_dags/example_jenkins_job_trigger.py new file mode 100644 index 0000000..e66615c --- /dev/null +++ b/reference/providers/jenkins/example_dags/example_jenkins_job_trigger.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from datetime import datetime, timedelta + +from airflow import DAG +from airflow.operators.python import PythonOperator +from airflow.providers.jenkins.hooks.jenkins import JenkinsHook +from airflow.providers.jenkins.operators.jenkins_job_trigger import ( + JenkinsJobTriggerOperator, +) +from six.moves.urllib.request import Request + +default_args = { + "owner": "airflow", + "retries": 1, + "retry_delay": timedelta(minutes=5), + "depends_on_past": False, + "concurrency": 8, + "max_active_runs": 8, +} + + +with DAG( + "test_jenkins", + default_args=default_args, + start_date=datetime(2017, 6, 1), + schedule_interval=None, +) as dag: + job_trigger = JenkinsJobTriggerOperator( + task_id="trigger_job", + job_name="generate-merlin-config", + parameters={"first_parameter": "a_value", "second_parameter": "18"}, + # parameters="resources/parameter.json", You can also pass a path to a json file containing your param + jenkins_connection_id="your_jenkins_connection", # T he connection must be configured first + ) + + def grab_artifact_from_jenkins(**context): + """ + Grab an artifact from the previous job + The python-jenkins library doesn't expose a method for that + But it's totally possible to build manually the request for that + """ + hook = JenkinsHook("your_jenkins_connection") + jenkins_server = hook.get_jenkins_server() + url = context["task_instance"].xcom_pull(task_ids="trigger_job") + # The JenkinsJobTriggerOperator store the job url in the xcom variable corresponding to the task + # You can then use it to access things or to get the job number + # This url looks like : http://jenkins_url/job/job_name/job_number/ + url += "artifact/myartifact.xml" # Or any other artifact name + request = Request(url) + response = jenkins_server.jenkins_open(request) + return ( + response # We store the artifact content in a xcom variable for later use + ) + + artifact_grabber = PythonOperator( + task_id="artifact_grabber", python_callable=grab_artifact_from_jenkins + ) + + job_trigger >> artifact_grabber diff --git a/reference/providers/jenkins/hooks/__init__.py b/reference/providers/jenkins/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/jenkins/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/jenkins/hooks/jenkins.py b/reference/providers/jenkins/hooks/jenkins.py new file mode 100644 index 0000000..71ebeb9 --- /dev/null +++ b/reference/providers/jenkins/hooks/jenkins.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from distutils.util import strtobool + +import jenkins +from airflow.hooks.base import BaseHook + + +class JenkinsHook(BaseHook): + """Hook to manage connection to jenkins server""" + + conn_name_attr = "conn_id" + default_conn_name = "jenkins_default" + conn_type = "jenkins" + hook_name = "Jenkins" + + def __init__(self, conn_id: str = default_conn_name) -> None: + super().__init__() + connection = self.get_connection(conn_id) + self.connection = connection + connection_prefix = "http" + # connection.extra contains info about using https (true) or http (false) + if connection.extra is None or connection.extra == "": + connection.extra = "false" + # set a default value to connection.extra + # to avoid rising ValueError in strtobool + if strtobool(connection.extra): + connection_prefix = "https" + url = f"{connection_prefix}://{connection.host}:{connection.port}" + self.log.info("Trying to connect to %s", url) + self.jenkins_server = jenkins.Jenkins( + url, connection.login, connection.password + ) + + def get_jenkins_server(self) -> jenkins.Jenkins: + """Get jenkins server""" + return self.jenkins_server diff --git a/reference/providers/jenkins/operators/__init__.py b/reference/providers/jenkins/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/jenkins/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/jenkins/operators/jenkins_job_trigger.py b/reference/providers/jenkins/operators/jenkins_job_trigger.py new file mode 100644 index 0000000..6d9bd93 --- /dev/null +++ b/reference/providers/jenkins/operators/jenkins_job_trigger.py @@ -0,0 +1,275 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
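Editorial note: a hedged sketch of the connection shape the JenkinsHook above expects, plus a simple connectivity check; host, port and credentials are hypothetical placeholders, and "extra" set to "true" switches the scheme to https as the constructor shows:

    # Hedged usage sketch -- connection values are hypothetical.
    from airflow.providers.jenkins.hooks.jenkins import JenkinsHook

    # Expected connection shape for conn_id="jenkins_default":
    #   host:     jenkins.example.com
    #   port:     8080
    #   login:    api_user
    #   password: api_token
    #   extra:    "true"   -> https; "false" or empty -> http

    hook = JenkinsHook(conn_id="jenkins_default")
    server = hook.get_jenkins_server()
    print(server.get_whoami())  # python-jenkins call, verifies credentials and URL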
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import ast +import json +import socket +import time +from typing import Any, Dict, Iterable, List, Mapping, Optional, Union +from urllib.error import HTTPError, URLError + +import jenkins +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.jenkins.hooks.jenkins import JenkinsHook +from airflow.utils.decorators import apply_defaults +from jenkins import Jenkins, JenkinsException +from requests import Request + +JenkinsRequest = Mapping[str, Any] +ParamType = Optional[Union[str, Dict, List]] + + +def jenkins_request_with_headers( + jenkins_server: Jenkins, req: Request +) -> Optional[JenkinsRequest]: + """ + We need to get the headers in addition to the body answer + to get the location from them + This function uses jenkins_request method from python-jenkins library + with just the return call changed + + :param jenkins_server: The server to query + :param req: The request to execute + :return: Dict containing the response body (key body) + and the headers coming along (headers) + """ + try: + response = jenkins_server.jenkins_request(req) + response_body = response.content + response_headers = response.headers + if response_body is None: + raise jenkins.EmptyResponseException( + f"Error communicating with server[{jenkins_server.server}]: empty response" + ) + return {"body": response_body.decode("utf-8"), "headers": response_headers} + except HTTPError as e: + # Jenkins's funky authentication means its nigh impossible to distinguish errors. + if e.code in [401, 403, 500]: + raise JenkinsException( + f"Error in request. Possibly authentication failed [{e.code}]: {e.reason}" + ) + elif e.code == 404: + raise jenkins.NotFoundException("Requested item could not be found") + else: + raise + except socket.timeout as e: + raise jenkins.TimeoutException(f"Error in request: {e}") + except URLError as e: + raise JenkinsException(f"Error in request: {e.reason}") + return None + + +class JenkinsJobTriggerOperator(BaseOperator): + """ + Trigger a Jenkins Job and monitor it's execution. + This operator depend on python-jenkins library, + version >= 0.4.15 to communicate with jenkins server. + You'll also need to configure a Jenkins connection in the connections screen. + + :param jenkins_connection_id: The jenkins connection to use for this job + :type jenkins_connection_id: str + :param job_name: The name of the job to trigger + :type job_name: str + :param parameters: The parameters block provided to jenkins for use in + the API call when triggering a build. 
(templated) + :type parameters: str, Dict, or List + :param sleep_time: How long will the operator sleep between each status + request for the job (min 1, default 10) + :type sleep_time: int + :param max_try_before_job_appears: The maximum number of requests to make + while waiting for the job to appear on the jenkins server (default 10) + :type max_try_before_job_appears: int + :param allowed_jenkins_states: Iterable of allowed result jenkins states, default is ``['SUCCESS']`` + :type allowed_jenkins_states: Optional[Iterable[str]] + """ + + template_fields = ("parameters",) + template_ext = (".json",) + ui_color = "#f9ec86" + + @apply_defaults + def __init__( + self, + *, + jenkins_connection_id: str, + job_name: str, + parameters: ParamType = "", + sleep_time: int = 10, + max_try_before_job_appears: int = 10, + allowed_jenkins_states: Optional[Iterable[str]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.job_name = job_name + self.parameters = parameters + self.sleep_time = max(sleep_time, 1) + self.jenkins_connection_id = jenkins_connection_id + self.max_try_before_job_appears = max_try_before_job_appears + self.allowed_jenkins_states = ( + list(allowed_jenkins_states) if allowed_jenkins_states else ["SUCCESS"] + ) + + def build_job( + self, jenkins_server: Jenkins, params: ParamType = "" + ) -> Optional[JenkinsRequest]: + """ + This function makes an API call to Jenkins to trigger a build for 'job_name'. + It returns a dict with 2 keys: body and headers. + headers also contains a dict-like object which can be queried to get + the location to poll in the queue. + + :param jenkins_server: The jenkins server where the job should be triggered + :param params: The parameters block to provide to the jenkins API call. + :return: Dict containing the response body (key body) + and the headers coming along (headers) + """ + # Since params can be either JSON string, dictionary, or list, + # check type and pass to build_job_url + if params and isinstance(params, str): + params = ast.literal_eval(params) + + # We need a None to call the non-parametrized jenkins API endpoint + if not params: + params = None + + request = Request( + method="POST", url=jenkins_server.build_job_url(self.job_name, params, None) + ) + return jenkins_request_with_headers(jenkins_server, request) + + def poll_job_in_queue(self, location: str, jenkins_server: Jenkins) -> int: + """ + This method polls the jenkins queue until the job is executed. + When we trigger a job through an API call, + the job is first put in the queue without having a build number assigned. + Thus we have to wait for the job to exit the queue to know its build number. + To do so, we have to add /api/json (or /api/xml) to the location + returned by the build_job call and poll this file. + When an 'executable' block appears in the json, it means the job execution started + and the field 'number' then contains the build number. 
+ + :param location: Location to poll, returned in the header of the build_job call + :param jenkins_server: The jenkins server to poll + :return: The build_number corresponding to the triggered job + """ + try_count = 0 + location += "/api/json" + # TODO Use get_queue_info instead + # once it will be available in python-jenkins (v > 0.4.15) + self.log.info("Polling jenkins queue at the url %s", location) + while try_count < self.max_try_before_job_appears: + location_answer = jenkins_request_with_headers( + jenkins_server, Request(method="POST", url=location) + ) + if location_answer is not None: + json_response = json.loads(location_answer["body"]) + if "executable" in json_response: + build_number = json_response["executable"]["number"] + self.log.info( + "Job executed on Jenkins side with the build number %s", + build_number, + ) + return build_number + try_count += 1 + time.sleep(self.sleep_time) + raise AirflowException( + "The job hasn't been executed after polling " + f"the queue {self.max_try_before_job_appears} times" + ) + + def get_hook(self) -> JenkinsHook: + """Instantiate jenkins hook""" + return JenkinsHook(self.jenkins_connection_id) + + def execute(self, context: Mapping[Any, Any]) -> Optional[str]: + if not self.jenkins_connection_id: + self.log.error( + "Please specify the jenkins connection id to use." + "You must create a Jenkins connection before" + " being able to use this operator" + ) + raise AirflowException( + "The jenkins_connection_id parameter is missing, impossible to trigger the job" + ) + + if not self.job_name: + self.log.error( + "Please specify the job name to use in the job_name parameter" + ) + raise AirflowException( + "The job_name parameter is missing,impossible to trigger the job" + ) + + self.log.info( + "Triggering the job %s on the jenkins : %s with the parameters : %s", + self.job_name, + self.jenkins_connection_id, + self.parameters, + ) + jenkins_server = self.get_hook().get_jenkins_server() + jenkins_response = self.build_job(jenkins_server, self.parameters) + if jenkins_response: + build_number = self.poll_job_in_queue( + jenkins_response["headers"]["Location"], jenkins_server + ) + + time.sleep(self.sleep_time) + keep_polling_job = True + build_info = None + # pylint: disable=too-many-nested-blocks + while keep_polling_job: + try: + build_info = jenkins_server.get_build_info( + name=self.job_name, number=build_number + ) + if build_info["result"] is not None: + keep_polling_job = False + # Check if job ended with not allowed state. + if build_info["result"] not in self.allowed_jenkins_states: + raise AirflowException( + "Jenkins job failed, final state : %s." + "Find more information on job url : %s" + % (build_info["result"], build_info["url"]) + ) + else: + self.log.info( + "Waiting for job to complete : %s , build %s", + self.job_name, + build_number, + ) + time.sleep(self.sleep_time) + except jenkins.NotFoundException as err: + # pylint: disable=no-member + raise AirflowException( + "Jenkins job status check failed. 
Final error was: " + f"{err.resp.status}" + ) + except jenkins.JenkinsException as err: + raise AirflowException( + f"Jenkins call failed with error : {err}, if you have parameters " + "double check them, jenkins sends back " + "this exception for unknown parameters" + "You can also check logs for more details on this exception " + "(jenkins_url/log/rss)" + ) + if build_info: + # If we can we return the url of the job + # for later use (like retrieving an artifact) + return build_info["url"] + return None diff --git a/reference/providers/jenkins/provider.yaml b/reference/providers/jenkins/provider.yaml new file mode 100644 index 0000000..47d10e6 --- /dev/null +++ b/reference/providers/jenkins/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-jenkins +name: Jenkins +description: | + `Jenkins `__ + +versions: + - 1.1.0 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Jenkins + external-doc-url: https://jenkins.io/ + logo: /integration-logos/jenkins/Jenkins.png + tags: [software] + +operators: + - integration-name: Jenkins + python-modules: + - airflow.providers.jenkins.operators.jenkins_job_trigger +hooks: + - integration-name: Jenkins + python-modules: + - airflow.providers.jenkins.hooks.jenkins + +hook-class-names: + - airflow.providers.jenkins.hooks.jenkins.JenkinsHook diff --git a/reference/providers/jira/CHANGELOG.rst b/reference/providers/jira/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/jira/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. 
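# A minimal usage sketch for the operator above: it only needs a Jenkins connection id,
# a job name and (optionally) a parameters block, then blocks until the triggered build
# leaves the queue and finishes in one of allowed_jenkins_states. The DAG id, job name
# and parameters below are hypothetical.
from airflow import DAG
from airflow.providers.jenkins.operators.jenkins_job_trigger import (
    JenkinsJobTriggerOperator,
)
from airflow.utils.dates import days_ago

with DAG(
    "example_trigger_jenkins_job",
    schedule_interval=None,
    start_date=days_ago(1),
    tags=["example"],
) as dag:
    trigger_build = JenkinsJobTriggerOperator(
        task_id="trigger_build",
        jenkins_connection_id="jenkins_default",
        job_name="my-jenkins-job",                      # assumed job name
        parameters={"BRANCH": "main"},                  # forwarded to the parametrized build
        sleep_time=30,                                  # seconds between status polls
        max_try_before_job_appears=10,
        allowed_jenkins_states=["SUCCESS", "UNSTABLE"],
    )
    # execute() returns the build URL, so a downstream task could pull it from XCom
    # (key "return_value") to fetch artifacts, as the operator's closing comment suggests.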
diff --git a/reference/providers/jira/__init__.py b/reference/providers/jira/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/jira/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/jira/hooks/__init__.py b/reference/providers/jira/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/jira/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/jira/hooks/jira.py b/reference/providers/jira/hooks/jira.py new file mode 100644 index 0000000..1cb8f3c --- /dev/null +++ b/reference/providers/jira/hooks/jira.py @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for JIRA""" +from typing import Any, Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from jira import JIRA +from jira.exceptions import JIRAError + + +class JiraHook(BaseHook): + """ + Jira interaction hook, a Wrapper around JIRA Python SDK. 
+ + :param jira_conn_id: reference to a pre-defined Jira Connection + :type jira_conn_id: str + """ + + default_conn_name = "jira_default" + conn_type = "jira" + conn_name_attr = "jira_conn_id" + hook_name = "JIRA" + + def __init__( + self, jira_conn_id: str = default_conn_name, proxies: Optional[Any] = None + ) -> None: + super().__init__() + self.jira_conn_id = jira_conn_id + self.proxies = proxies + self.client = None + self.get_conn() + + def get_conn(self) -> JIRA: + if not self.client: + self.log.debug("Creating Jira client for conn_id: %s", self.jira_conn_id) + + get_server_info = True + validate = True + extra_options = {} + if not self.jira_conn_id: + raise AirflowException( + "Failed to create jira client. no jira_conn_id provided" + ) + + conn = self.get_connection(self.jira_conn_id) + if conn.extra is not None: + extra_options = conn.extra_dejson + # only required attributes are taken for now, + # more can be added ex: async, logging, max_retries + + # verify + if ( + "verify" in extra_options + and extra_options["verify"].lower() == "false" + ): + extra_options["verify"] = False + + # validate + if ( + "validate" in extra_options + and extra_options["validate"].lower() == "false" + ): + validate = False + + if ( + "get_server_info" in extra_options + and extra_options["get_server_info"].lower() == "false" + ): + get_server_info = False + + try: + self.client = JIRA( + conn.host, + options=extra_options, + basic_auth=(conn.login, conn.password), + get_server_info=get_server_info, + validate=validate, + proxies=self.proxies, + ) + except JIRAError as jira_error: + raise AirflowException( + f"Failed to create jira client, jira error: {str(jira_error)}" + ) + except Exception as e: + raise AirflowException(f"Failed to create jira client, error: {str(e)}") + + return self.client diff --git a/reference/providers/jira/operators/__init__.py b/reference/providers/jira/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/jira/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/jira/operators/jira.py b/reference/providers/jira/operators/jira.py new file mode 100644 index 0000000..e7ebbe4 --- /dev/null +++ b/reference/providers/jira/operators/jira.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
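# A minimal sketch of the connection shape JiraHook above expects: host/login/password
# feed basic auth, while the extra JSON may carry the string flags "verify", "validate"
# and "get_server_info" (each compared against "false"). All values below are hypothetical.
import json

from airflow.models import Connection
from airflow.providers.jira.hooks.jira import JiraHook

example_jira_conn = Connection(
    conn_id="jira_default",
    conn_type="jira",
    host="https://jira.example.com",                  # passed straight to jira.JIRA(...)
    login="airflow-bot",
    password="api-token",
    extra=json.dumps({"verify": "false", "get_server_info": "false"}),
)
# Once such a connection is registered, the client is built eagerly in __init__, so any
# python-jira call is available directly:
# issue = JiraHook(jira_conn_id="jira_default").client.issue("PROJ-123")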
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Callable, Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.jira.hooks.jira import JIRAError, JiraHook +from airflow.utils.decorators import apply_defaults + + +class JiraOperator(BaseOperator): + """ + JiraOperator to interact and perform action on Jira issue tracking system. + This operator is designed to use Jira Python SDK: http://jira.readthedocs.io + + :param jira_conn_id: reference to a pre-defined Jira Connection + :type jira_conn_id: str + :param jira_method: method name from Jira Python SDK to be called + :type jira_method: str + :param jira_method_args: required method parameters for the jira_method. (templated) + :type jira_method_args: dict + :param result_processor: function to further process the response from Jira + :type result_processor: function + :param get_jira_resource_method: function or operator to get jira resource + on which the provided jira_method will be executed + :type get_jira_resource_method: function + """ + + template_fields = ("jira_method_args",) + + @apply_defaults + def __init__( + self, + *, + jira_method: str, + jira_conn_id: str = "jira_default", + jira_method_args: Optional[dict] = None, + result_processor: Optional[Callable] = None, + get_jira_resource_method: Optional[Callable] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.jira_conn_id = jira_conn_id + self.method_name = jira_method + self.jira_method_args = jira_method_args + self.result_processor = result_processor + self.get_jira_resource_method = get_jira_resource_method + + def execute(self, context: Dict) -> Any: + try: + if self.get_jira_resource_method is not None: + # if get_jira_resource_method is provided, jira_method will be executed on + # resource returned by executing the get_jira_resource_method. + # This makes all the provided methods of JIRA sdk accessible and usable + # directly at the JiraOperator without additional wrappers. + # ref: http://jira.readthedocs.io/en/latest/api.html + if isinstance(self.get_jira_resource_method, JiraOperator): + resource = self.get_jira_resource_method.execute(**context) + else: + resource = self.get_jira_resource_method(**context) + else: + # Default method execution is on the top level jira client resource + hook = JiraHook(jira_conn_id=self.jira_conn_id) + resource = hook.client + + # Current Jira-Python SDK (1.0.7) has issue with pickling the jira response. 
+ # ex: self.xcom_push(context, key='operator_response', value=jira_response) + # This could potentially throw error if jira_result is not picklable + jira_result = getattr(resource, self.method_name)(**self.jira_method_args) + if self.result_processor: + return self.result_processor(context, jira_result) + + return jira_result + + except JIRAError as jira_error: + raise AirflowException( + f"Failed to execute jiraOperator, error: {str(jira_error)}" + ) + except Exception as e: + raise AirflowException(f"Jira operator error: {str(e)}") diff --git a/reference/providers/jira/provider.yaml b/reference/providers/jira/provider.yaml new file mode 100644 index 0000000..4ad1136 --- /dev/null +++ b/reference/providers/jira/provider.yaml @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-jira +name: Jira +description: | + `Atlassian Jira `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Atlassian Jira + external-doc-url: https://www.atlassian.com/pl/software/jira + logo: /integration-logos/jira/Jira.png + tags: [software] + +operators: + - integration-name: Atlassian Jira + python-modules: + - airflow.providers.jira.operators.jira + +sensors: + - integration-name: Atlassian Jira + python-modules: + - airflow.providers.jira.sensors.jira + +hooks: + - integration-name: Atlassian Jira + python-modules: + - airflow.providers.jira.hooks.jira + +hook-class-names: + - airflow.providers.jira.hooks.jira.JiraHook diff --git a/reference/providers/jira/sensors/__init__.py b/reference/providers/jira/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/jira/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
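# A minimal usage sketch for the operator above: jira_method names any python-jira
# client method and jira_method_args is unpacked into that call, so e.g. creating an
# issue is pure configuration. The project key and summary below are hypothetical.
from airflow import DAG
from airflow.providers.jira.operators.jira import JiraOperator
from airflow.utils.dates import days_ago

with DAG(
    "example_jira_create_issue",
    schedule_interval=None,
    start_date=days_ago(1),
    tags=["example"],
) as dag:
    create_issue = JiraOperator(
        task_id="create_issue",
        jira_conn_id="jira_default",
        jira_method="create_issue",                     # any jira.JIRA method name
        jira_method_args={
            "fields": {
                "project": {"key": "PROJ"},
                "summary": "Nightly build failed",
                "issuetype": {"name": "Task"},
            }
        },
        # result_processor can shrink the (not always picklable) response, e.g. keep
        # only the new issue key for XCom:
        result_processor=lambda context, issue: issue.key,
    )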
diff --git a/reference/providers/jira/sensors/jira.py b/reference/providers/jira/sensors/jira.py new file mode 100644 index 0000000..e3569a1 --- /dev/null +++ b/reference/providers/jira/sensors/jira.py @@ -0,0 +1,166 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Callable, Dict, Optional + +from airflow.providers.jira.operators.jira import JIRAError, JiraOperator +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from jira.resources import Issue, Resource + + +class JiraSensor(BaseSensorOperator): + """ + Monitors a jira ticket for any change. + + :param jira_conn_id: reference to a pre-defined Jira Connection + :type jira_conn_id: str + :param method_name: method name from jira-python-sdk to be execute + :type method_name: str + :param method_params: parameters for the method method_name + :type method_params: dict + :param result_processor: function that return boolean and act as a sensor response + :type result_processor: function + """ + + @apply_defaults + def __init__( + self, + *, + method_name: str, + jira_conn_id: str = "jira_default", + method_params: Optional[dict] = None, + result_processor: Optional[Callable] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.jira_conn_id = jira_conn_id + self.result_processor = None + if result_processor is not None: + self.result_processor = result_processor + self.method_name = method_name + self.method_params = method_params + self.jira_operator = JiraOperator( + task_id=self.task_id, + jira_conn_id=self.jira_conn_id, + jira_method=self.method_name, + jira_method_args=self.method_params, + result_processor=self.result_processor, + ) + + def poke(self, context: Dict) -> Any: + return self.jira_operator.execute(context=context) + + +class JiraTicketSensor(JiraSensor): + """ + Monitors a jira ticket for given change in terms of function. 
+ + :param jira_conn_id: reference to a pre-defined Jira Connection + :type jira_conn_id: str + :param ticket_id: id of the ticket to be monitored + :type ticket_id: str + :param field: field of the ticket to be monitored + :type field: str + :param expected_value: expected value of the field + :type expected_value: str + :param result_processor: function that return boolean and act as a sensor response + :type result_processor: function + """ + + template_fields = ("ticket_id",) + + @apply_defaults + def __init__( + self, + *, + jira_conn_id: str = "jira_default", + ticket_id: Optional[str] = None, + field: Optional[str] = None, + expected_value: Optional[str] = None, + field_checker_func: Optional[Callable] = None, + **kwargs, + ) -> None: + + self.jira_conn_id = jira_conn_id + self.ticket_id = ticket_id + self.field = field + self.expected_value = expected_value + if field_checker_func is None: + field_checker_func = self.issue_field_checker + + super().__init__( + jira_conn_id=jira_conn_id, result_processor=field_checker_func, **kwargs + ) + + def poke(self, context: Dict) -> Any: + self.log.info("Jira Sensor checking for change in ticket: %s", self.ticket_id) + + self.jira_operator.method_name = "issue" + self.jira_operator.jira_method_args = { + "id": self.ticket_id, + "fields": self.field, + } + return JiraSensor.poke(self, context=context) + + def issue_field_checker(self, issue: Issue) -> Optional[bool]: + """Check issue using different conditions to prepare to evaluate sensor.""" + result = None + try: # pylint: disable=too-many-nested-blocks + if ( + issue is not None + and self.field is not None + and self.expected_value is not None + ): + + field_val = getattr(issue.fields, self.field) + if field_val is not None: + if isinstance(field_val, list): + result = self.expected_value in field_val + elif isinstance(field_val, str): + result = self.expected_value.lower() == field_val.lower() + elif isinstance(field_val, Resource) and getattr(field_val, "name"): + result = self.expected_value.lower() == field_val.name.lower() + else: + self.log.warning( + "Not implemented checker for issue field %s which " + "is neither string nor list nor Jira Resource", + self.field, + ) + + except JIRAError as jira_error: + self.log.error( + "Jira error while checking with expected value: %s", jira_error + ) + except Exception as e: # pylint: disable=broad-except + self.log.error( + "Error while checking with expected value %s:", self.expected_value + ) + self.log.exception(e) + if result is True: + self.log.info( + "Issue field %s has expected value %s, returning success", + self.field, + self.expected_value, + ) + else: + self.log.info( + "Issue field %s don't have expected value %s yet.", + self.field, + self.expected_value, + ) + return result diff --git a/reference/providers/microsoft/__init__.py b/reference/providers/microsoft/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
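# A minimal usage sketch for the sensor above: JiraTicketSensor calls the "issue" method
# for ticket_id and keeps poking until the chosen field matches expected_value (string,
# list and Jira Resource fields are handled). The ticket id and field below are hypothetical.
from airflow import DAG
from airflow.providers.jira.sensors.jira import JiraTicketSensor
from airflow.utils.dates import days_ago

with DAG(
    "example_wait_for_jira_status",
    schedule_interval=None,
    start_date=days_ago(1),
    tags=["example"],
) as dag:
    wait_for_done = JiraTicketSensor(
        task_id="wait_for_done",
        jira_conn_id="jira_default",
        ticket_id="PROJ-123",          # templated field
        field="status",                # issue.fields.status is a Jira Resource
        expected_value="Done",         # compared case-insensitively against its name
        poke_interval=300,             # regular BaseSensorOperator arguments still apply
        timeout=6 * 60 * 60,
    )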
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/azure/CHANGELOG.rst b/reference/providers/microsoft/azure/CHANGELOG.rst new file mode 100644 index 0000000..403dbbf --- /dev/null +++ b/reference/providers/microsoft/azure/CHANGELOG.rst @@ -0,0 +1,51 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.2.0 +..... + +Features +~~~~~~~~ + +* ``Add Azure Data Factory hook (#11015)`` + +Bug fixes +~~~~~~~~~ + +* ``BugFix: Fix remote log in azure storage blob displays in one line (#14313)`` +* ``Fix AzureDataFactoryHook failing to instantiate its connection (#14565)`` + +1.1.0 +..... + +Updated documentation and readme files. + +Features +~~~~~~~~ + +* ``Upgrade azure blob to v12 (#12188)`` +* ``Fix Azure Data Explorer Operator (#13520)`` +* ``add AzureDatalakeStorageDeleteOperator (#13206)`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/microsoft/azure/__init__.py b/reference/providers/microsoft/azure/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/azure/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/azure/example_dags/__init__.py b/reference/providers/microsoft/azure/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/azure/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/azure/example_dags/example_adls_delete.py b/reference/providers/microsoft/azure/example_dags/example_adls_delete.py new file mode 100644 index 0000000..717ed0a --- /dev/null +++ b/reference/providers/microsoft/azure/example_dags/example_adls_delete.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow import models +from airflow.providers.microsoft.azure.operators.adls_delete import ( + AzureDataLakeStorageDeleteOperator, +) +from airflow.providers.microsoft.azure.transfers.local_to_adls import ( + LocalToAzureDataLakeStorageOperator, +) +from airflow.utils.dates import days_ago + +LOCAL_FILE_PATH = os.environ.get("LOCAL_FILE_PATH", "localfile.txt") +REMOTE_FILE_PATH = os.environ.get("REMOTE_LOCAL_PATH", "remote.txt") + + +with models.DAG( + "example_adls_delete", + start_date=days_ago(1), + schedule_interval=None, + tags=["example"], +) as dag: + + upload_file = LocalToAzureDataLakeStorageOperator( + task_id="upload_task", + local_path=LOCAL_FILE_PATH, + remote_path=REMOTE_FILE_PATH, + ) + # [START howto_operator_adls_delete] + remove_file = AzureDataLakeStorageDeleteOperator( + task_id="delete_task", path=REMOTE_FILE_PATH, recursive=True + ) + # [END howto_operator_adls_delete] + + upload_file >> remove_file diff --git a/reference/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py b/reference/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py new file mode 100644 index 0000000..7183972 --- /dev/null +++ b/reference/providers/microsoft/azure/example_dags/example_azure_blob_to_gcs.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os + +from airflow import DAG +from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor +from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import ( + AzureBlobStorageToGCSOperator, +) +from airflow.utils.dates import days_ago + +BLOB_NAME = os.environ.get("AZURE_BLOB_NAME", "file.txt") +AZURE_CONTAINER_NAME = os.environ.get("AZURE_CONTAINER_NAME", "airflow") +GCP_BUCKET_FILE_PATH = os.environ.get("GCP_BUCKET_FILE_PATH", "file.txt") +GCP_BUCKET_NAME = os.environ.get("GCP_BUCKET_NAME", "azure_bucket") +GCP_OBJECT_NAME = os.environ.get("GCP_OBJECT_NAME", "file.txt") + + +with DAG( + "example_azure_blob_to_gcs", + schedule_interval=None, + start_date=days_ago(1), # Override to match your needs +) as dag: + + # [START how_to_wait_for_blob] + wait_for_blob = WasbBlobSensor( + task_id="wait_for_blob", + wasb_conn_id="wasb_default", + container_name=AZURE_CONTAINER_NAME, + blob_name=BLOB_NAME, + ) + # [END how_to_wait_for_blob] + + # [START how_to_azure_blob_to_gcs] + transfer_files_to_gcs = AzureBlobStorageToGCSOperator( + task_id="transfer_files_to_gcs", + # AZURE args + wasb_conn_id="wasb_default", + container_name=AZURE_CONTAINER_NAME, + blob_name=BLOB_NAME, + file_path=GCP_OBJECT_NAME, + # GCP args + gcp_conn_id="google_cloud_default", + bucket_name=GCP_BUCKET_NAME, + object_name=GCP_OBJECT_NAME, + filename=GCP_BUCKET_FILE_PATH, + gzip=False, + delegate_to=None, + impersonation_chain=None, + ) + # [END how_to_azure_blob_to_gcs] + + wait_for_blob >> transfer_files_to_gcs diff --git a/reference/providers/microsoft/azure/example_dags/example_azure_container_instances.py b/reference/providers/microsoft/azure/example_dags/example_azure_container_instances.py new file mode 100644 index 0000000..2e661b1 --- /dev/null +++ b/reference/providers/microsoft/azure/example_dags/example_azure_container_instances.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example dag for using the AzureContainerInstancesOperator. 
+""" +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.providers.microsoft.azure.operators.azure_container_instances import ( + AzureContainerInstancesOperator, +) + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), +} + +with DAG( + dag_id="aci_example", + default_args=default_args, + schedule_interval=timedelta(1), + start_date=datetime(2018, 11, 1), + tags=["example"], +) as dag: + + t1 = AzureContainerInstancesOperator( + ci_conn_id="azure_container_instances_default", + registry_conn_id=None, + resource_group="resource-group", + name="aci-test-{{ ds }}", + image="hello-world", + region="WestUS2", + environment_variables={}, + volumes=[], + memory_in_gb=4.0, + cpu=1.0, + task_id="start_container", + ) diff --git a/reference/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py b/reference/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py new file mode 100644 index 0000000..97a42bd --- /dev/null +++ b/reference/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This is only an example DAG to highlight usage of AzureCosmosDocumentSensor to detect +if a document now exists. + +You can trigger this manually with `airflow dags trigger example_cosmosdb_sensor`. 
+ +*Note: Make sure that connection `azure_cosmos_default` is properly set before running +this example.* +""" + +from airflow import DAG +from airflow.providers.microsoft.azure.operators.azure_cosmos import ( + AzureCosmosInsertDocumentOperator, +) +from airflow.providers.microsoft.azure.sensors.azure_cosmos import ( + AzureCosmosDocumentSensor, +) +from airflow.utils import dates + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, +} + +with DAG( + dag_id="example_azure_cosmosdb_sensor", + default_args=default_args, + start_date=dates.days_ago(2), + doc_md=__doc__, + tags=["example"], +) as dag: + + t1 = AzureCosmosDocumentSensor( + task_id="check_cosmos_file", + database_name="airflow_example_db", + collection_name="airflow_example_coll", + document_id="airflow_checkid", + azure_cosmos_conn_id="azure_cosmos_default", + ) + + t2 = AzureCosmosInsertDocumentOperator( + task_id="insert_cosmos_file", + database_name="airflow_example_db", + collection_name="new-collection", + document={"id": "someuniqueid", "param1": "value1", "param2": "value2"}, + azure_cosmos_conn_id="azure_cosmos_default", + ) + + t1 >> t2 diff --git a/reference/providers/microsoft/azure/example_dags/example_file_to_wasb.py b/reference/providers/microsoft/azure/example_dags/example_file_to_wasb.py new file mode 100644 index 0000000..d69c3c6 --- /dev/null +++ b/reference/providers/microsoft/azure/example_dags/example_file_to_wasb.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow.models import DAG +from airflow.providers.microsoft.azure.operators.wasb_delete_blob import ( + WasbDeleteBlobOperator, +) +from airflow.providers.microsoft.azure.transfers.file_to_wasb import FileToWasbOperator +from airflow.utils.dates import days_ago + +PATH_TO_UPLOAD_FILE = os.environ.get("AZURE_PATH_TO_UPLOAD_FILE", "example-text.txt") + +with DAG( + "example_file_to_wasb", schedule_interval="@once", start_date=days_ago(2) +) as dag: + upload = FileToWasbOperator( + task_id="upload_file", + file_path=PATH_TO_UPLOAD_FILE, + container_name="mycontainer", + blob_name="myblob", + ) + delete = WasbDeleteBlobOperator( + task_id="delete_file", container_name="mycontainer", blob_name="myblob" + ) + upload >> delete diff --git a/reference/providers/microsoft/azure/example_dags/example_fileshare.py b/reference/providers/microsoft/azure/example_dags/example_fileshare.py new file mode 100644 index 0000000..f4e6e61 --- /dev/null +++ b/reference/providers/microsoft/azure/example_dags/example_fileshare.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.models import DAG +from airflow.operators.python import PythonOperator +from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook +from airflow.utils.dates import days_ago + +NAME = "myfileshare" +DIRECTORY = "mydirectory" + + +def create_fileshare(): + """Create a fileshare with directory""" + hook = AzureFileShareHook() + hook.create_share(NAME) + hook.create_directory(share_name=NAME, directory_name=DIRECTORY) + exists = hook.check_for_directory(share_name=NAME, directory_name=DIRECTORY) + if not exists: + raise Exception + + +def delete_fileshare(): + """Delete a fileshare""" + hook = AzureFileShareHook() + hook.delete_share(NAME) + + +with DAG("example_fileshare", schedule_interval="@once", start_date=days_ago(2)) as dag: + create = PythonOperator(task_id="create-share", python_callable=create_fileshare) + delete = PythonOperator(task_id="delete-share", python_callable=delete_fileshare) + + create >> delete diff --git a/reference/providers/microsoft/azure/example_dags/example_local_to_adls.py b/reference/providers/microsoft/azure/example_dags/example_local_to_adls.py new file mode 100644 index 0000000..1a1b5aa --- /dev/null +++ b/reference/providers/microsoft/azure/example_dags/example_local_to_adls.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os + +from airflow import models +from airflow.providers.microsoft.azure.operators.adls_delete import ( + AzureDataLakeStorageDeleteOperator, +) +from airflow.providers.microsoft.azure.transfers.local_to_adls import ( + LocalToAzureDataLakeStorageOperator, +) +from airflow.utils.dates import days_ago + +LOCAL_FILE_PATH = os.environ.get("LOCAL_FILE_PATH", "localfile.txt") +REMOTE_FILE_PATH = os.environ.get("REMOTE_LOCAL_PATH", "remote.txt") + +with models.DAG( + "example_local_to_adls", + start_date=days_ago(1), + schedule_interval=None, + tags=["example"], +) as dag: + # [START howto_operator_local_to_adls] + upload_file = LocalToAzureDataLakeStorageOperator( + task_id="upload_task", + local_path=LOCAL_FILE_PATH, + remote_path=REMOTE_FILE_PATH, + ) + # [END howto_operator_local_to_adls] + + delete_file = AzureDataLakeStorageDeleteOperator( + task_id="remove_task", path=REMOTE_FILE_PATH, recursive=True + ) + + upload_file >> delete_file diff --git a/reference/providers/microsoft/azure/hooks/__init__.py b/reference/providers/microsoft/azure/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/azure/hooks/adx.py b/reference/providers/microsoft/azure/hooks/adx.py new file mode 100644 index 0000000..fbdb89d --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/adx.py @@ -0,0 +1,159 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +"""This module contains Azure Data Explorer hook""" +from typing import Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from azure.kusto.data.exceptions import KustoServiceError +from azure.kusto.data.request import ( + ClientRequestProperties, + KustoClient, + KustoConnectionStringBuilder, +) +from azure.kusto.data.response import KustoResponseDataSetV2 + + +class AzureDataExplorerHook(BaseHook): + """ + Interacts with Azure Data Explorer (Kusto). + + Extra JSON field contains the following parameters: + + .. code-block:: json + + { + "tenant": "", + "auth_method": "", + "certificate": "", + "thumbprint": "" + } + + **Cluster**: + + Azure Data Explorer cluster is specified by a URL, for example: "https://help.kusto.windows.net". + The parameter must be provided through `Host` connection detail. + + **Tenant ID**: + + To learn about tenants refer to: https://docs.microsoft.com/en-us/onedrive/find-your-office-365-tenant-id + + **Authentication methods**: + + Authentication method must be provided through "auth_method" extra parameter. + Available authentication methods are: + + - AAD_APP : Authentication with AAD application certificate. Extra parameters: + "tenant" is required when using this method. Provide application ID + and application key through username and password parameters. + + - AAD_APP_CERT: Authentication with AAD application certificate. Extra parameters: + "tenant", "certificate" and "thumbprint" are required + when using this method. + + - AAD_CREDS : Authentication with AAD username and password. Extra parameters: + "tenant" is required when using this method. Username and password + parameters are used for authentication with AAD. + + - AAD_DEVICE : Authenticate with AAD device code. Please note that if you choose + this option, you'll need to authenticate for every new instance + that is initialized. It is highly recommended to create one instance + and use it for all queries. + + :param azure_data_explorer_conn_id: Reference to the Azure Data Explorer connection. 
+ :type azure_data_explorer_conn_id: str + """ + + conn_name_attr = "azure_data_explorer_conn_id" + default_conn_name = "azure_data_explorer_default" + conn_type = "azure_data_explorer" + hook_name = "Azure Data Explorer" + + def __init__(self, azure_data_explorer_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_data_explorer_conn_id + self.connection = self.get_conn() + + def get_conn(self) -> KustoClient: + """Return a KustoClient object.""" + conn = self.get_connection(self.conn_id) + cluster = conn.host + if not cluster: + raise AirflowException("Host connection option is required") + + def get_required_param(name: str) -> str: + """Extract required parameter from extra JSON, raise exception if not found""" + value = conn.extra_dejson.get(name) + if not value: + raise AirflowException( + f"Extra connection option is missing required parameter: `{name}`" + ) + return value + + auth_method = get_required_param("auth_method") + + if auth_method == "AAD_APP": + kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication( + cluster, conn.login, conn.password, get_required_param("tenant") + ) + elif auth_method == "AAD_APP_CERT": + kcsb = KustoConnectionStringBuilder.with_aad_application_certificate_authentication( + cluster, + conn.login, + get_required_param("certificate"), + get_required_param("thumbprint"), + get_required_param("tenant"), + ) + elif auth_method == "AAD_CREDS": + kcsb = KustoConnectionStringBuilder.with_aad_user_password_authentication( + cluster, conn.login, conn.password, get_required_param("tenant") + ) + elif auth_method == "AAD_DEVICE": + kcsb = KustoConnectionStringBuilder.with_aad_device_authentication(cluster) + else: + raise AirflowException(f"Unknown authentication method: {auth_method}") + + return KustoClient(kcsb) + + def run_query( + self, query: str, database: str, options: Optional[Dict] = None + ) -> KustoResponseDataSetV2: + """ + Run KQL query using provided configuration, and return + `azure.kusto.data.response.KustoResponseDataSet` instance. + If query is unsuccessful AirflowException is raised. + + :param query: KQL query to run + :type query: str + :param database: Database to run the query on. + :type database: str + :param options: Optional query options. See: + https://docs.microsoft.com/en-us/azure/kusto/api/netfx/request-properties#list-of-clientrequestproperties + :type options: dict + :return: dict + """ + properties = ClientRequestProperties() + if options: + for k, v in options.items(): + properties.set_option(k, v) + try: + return self.connection.execute(database, query, properties=properties) + except KustoServiceError as error: + raise AirflowException(f"Error running Kusto query: {error}") diff --git a/reference/providers/microsoft/azure/hooks/azure_batch.py b/reference/providers/microsoft/azure/hooks/azure_batch.py new file mode 100644 index 0000000..e4f8e3e --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/azure_batch.py @@ -0,0 +1,395 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
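# A minimal usage sketch for the hook above, assuming a connection whose host points at
# the cluster URL and whose extra JSON carries {"auth_method": "AAD_APP", "tenant": "..."}.
# The database, query and option name below are illustrative assumptions.
from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook

adx_hook = AzureDataExplorerHook(
    azure_data_explorer_conn_id="azure_data_explorer_default"
)
response = adx_hook.run_query(
    query="StormEvents | take 10",
    database="Samples",
    options={"servertimeout": "00:02:00"},   # forwarded to ClientRequestProperties.set_option
)
# run_query returns a KustoResponseDataSet; rows of the first result table can be iterated:
for row in response.primary_results[0]:
    print(row)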
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import time +from datetime import timedelta +from typing import Optional, Set + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.models import Connection +from airflow.utils import timezone +from azure.batch import BatchServiceClient, batch_auth +from azure.batch import models as batch_models +from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter + + +class AzureBatchHook(BaseHook): + """ + Hook for Azure Batch APIs + + Account name and account key should be in login and password parameters. + The account url should be in extra parameter as account_url + """ + + conn_name_attr = "azure_batch_conn_id" + default_conn_name = "azure_batch_default" + conn_type = "azure_batch" + hook_name = "Azure Batch Service" + + def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_batch_conn_id + self.connection = self.get_conn() + self.extra = self._connection().extra_dejson + + def _connection(self) -> Connection: + """Get connected to azure batch service""" + conn = self.get_connection(self.conn_id) + return conn + + def get_conn(self): + """ + Get the batch client connection + + :return: Azure batch client + """ + conn = self._connection() + + def _get_required_param(name): + """Extract required parameter from extra JSON, raise exception if not found""" + value = conn.extra_dejson.get(name) + if not value: + raise AirflowException( + f"Extra connection option is missing required parameter: `{name}`" + ) + return value + + batch_account_url = _get_required_param("account_url") + credentials = batch_auth.SharedKeyCredentials(conn.login, conn.password) + batch_client = BatchServiceClient(credentials, batch_url=batch_account_url) + return batch_client + + def configure_pool( # pylint: disable=too-many-arguments + self, + pool_id: str, + vm_size: Optional[str] = None, + vm_publisher: Optional[str] = None, + vm_offer: Optional[str] = None, + sku_starts_with: Optional[str] = None, + vm_sku: Optional[str] = None, + vm_version: Optional[str] = None, + vm_node_agent_sku_id: Optional[str] = None, + os_family: Optional[str] = None, + os_version: Optional[str] = None, + display_name: Optional[str] = None, + target_dedicated_nodes: Optional[int] = None, + use_latest_image_and_sku: bool = False, + **kwargs, + ) -> PoolAddParameter: + """ + Configures a pool + + :param pool_id: A string that uniquely identifies the Pool within the Account + :type pool_id: str + + :param vm_size: The size of virtual machines in the Pool. + :type vm_size: str + + :param display_name: The display name for the Pool + :type display_name: str + + :param target_dedicated_nodes: The desired number of dedicated Compute Nodes in the Pool. + :type target_dedicated_nodes: Optional[int] + + :param use_latest_image_and_sku: Whether to use the latest verified vm image and sku + :type use_latest_image_and_sku: bool + + :param vm_publisher: The publisher of the Azure Virtual Machines Marketplace Image. + For example, Canonical or MicrosoftWindowsServer. 
+ :type vm_publisher: Optional[str] + + :param vm_offer: The offer type of the Azure Virtual Machines Marketplace Image. + For example, UbuntuServer or WindowsServer. + :type vm_offer: Optional[str] + + :param sku_starts_with: The start name of the sku to search + :type sku_starts_with: Optional[str] + + :param vm_sku: The name of the virtual machine sku to use + :type vm_sku: Optional[str] + + :param vm_version: The version of the virtual machine + :param vm_version: str + + :param vm_node_agent_sku_id: The node agent sku id of the virtual machine + :type vm_node_agent_sku_id: Optional[str] + + :param os_family: The Azure Guest OS family to be installed on the virtual machines in the Pool. + :type os_family: Optional[str] + + :param os_version: The OS family version + :type os_version: Optional[str] + + """ + if use_latest_image_and_sku: + self.log.info( + "Using latest verified virtual machine image with node agent sku" + ) + sku_to_use, image_ref_to_use = self._get_latest_verified_image_vm_and_sku( + publisher=vm_publisher, offer=vm_offer, sku_starts_with=sku_starts_with + ) + pool = batch_models.PoolAddParameter( + id=pool_id, + vm_size=vm_size, + display_name=display_name, + virtual_machine_configuration=batch_models.VirtualMachineConfiguration( + image_reference=image_ref_to_use, node_agent_sku_id=sku_to_use + ), + target_dedicated_nodes=target_dedicated_nodes, + **kwargs, + ) + + elif os_family: + self.log.info( + "Using cloud service configuration to create pool, virtual machine configuration ignored" + ) + pool = batch_models.PoolAddParameter( + id=pool_id, + vm_size=vm_size, + display_name=display_name, + cloud_service_configuration=batch_models.CloudServiceConfiguration( + os_family=os_family, os_version=os_version + ), + target_dedicated_nodes=target_dedicated_nodes, + **kwargs, + ) + + else: + self.log.info("Using virtual machine configuration to create a pool") + pool = batch_models.PoolAddParameter( + id=pool_id, + vm_size=vm_size, + display_name=display_name, + virtual_machine_configuration=batch_models.VirtualMachineConfiguration( + image_reference=batch_models.ImageReference( + publisher=vm_publisher, + offer=vm_offer, + sku=vm_sku, + version=vm_version, + ), + node_agent_sku_id=vm_node_agent_sku_id, + ), + target_dedicated_nodes=target_dedicated_nodes, + **kwargs, + ) + return pool + + def create_pool(self, pool: PoolAddParameter) -> None: + """ + Creates a pool if not already existing + + :param pool: the pool object to create + :type pool: batch_models.PoolAddParameter + + """ + try: + self.log.info("Attempting to create a pool: %s", pool.id) + self.connection.pool.add(pool) + self.log.info("Created pool: %s", pool.id) + except batch_models.BatchErrorException as e: + if e.error.code != "PoolExists": + raise + else: + self.log.info("Pool %s already exists", pool.id) + + def _get_latest_verified_image_vm_and_sku( + self, + publisher: Optional[str] = None, + offer: Optional[str] = None, + sku_starts_with: Optional[str] = None, + ) -> tuple: + """ + Get latest verified image vm and sku + + :param publisher: The publisher of the Azure Virtual Machines Marketplace Image. + For example, Canonical or MicrosoftWindowsServer. + :type publisher: str + :param offer: The offer type of the Azure Virtual Machines Marketplace Image. + For example, UbuntuServer or WindowsServer. 
+ :type offer: str + :param sku_starts_with: The start name of the sku to search + :type sku_starts_with: str + """ + options = batch_models.AccountListSupportedImagesOptions( + filter="verificationType eq 'verified'" + ) + images = self.connection.account.list_supported_images( + account_list_supported_images_options=options + ) + # pick the latest supported sku + skus_to_use = [ + (image.node_agent_sku_id, image.image_reference) + for image in images + if image.image_reference.publisher.lower() == publisher + and image.image_reference.offer.lower() == offer + and image.image_reference.sku.startswith(sku_starts_with) + ] + + # pick first + agent_sku_id, image_ref_to_use = skus_to_use[0] + return agent_sku_id, image_ref_to_use + + def wait_for_all_node_state(self, pool_id: str, node_state: Set) -> list: + """ + Wait for all nodes in a pool to reach given states + + :param pool_id: A string that identifies the pool + :type pool_id: str + :param node_state: A set of batch_models.ComputeNodeState + :type node_state: set + """ + self.log.info( + "waiting for all nodes in pool %s to reach one of: %s", pool_id, node_state + ) + while True: + # refresh pool to ensure that there is no resize error + pool = self.connection.pool.get(pool_id) + if pool.resize_errors is not None: + resize_errors = "\n".join([repr(e) for e in pool.resize_errors]) + raise RuntimeError( + f"resize error encountered for pool {pool.id}:\n{resize_errors}" + ) + nodes = list(self.connection.compute_node.list(pool.id)) + if len(nodes) >= pool.target_dedicated_nodes and all( + node.state in node_state for node in nodes + ): + return nodes + # Allow the timeout to be controlled by the AzureBatchOperator + # specified timeout. This way we don't interrupt a startTask inside + # the pool + time.sleep(10) + + def configure_job( + self, + job_id: str, + pool_id: str, + display_name: Optional[str] = None, + **kwargs, + ) -> JobAddParameter: + """ + Configures a job for use in the pool + + :param job_id: A string that uniquely identifies the job within the account + :type job_id: str + :param pool_id: A string that identifies the pool + :type pool_id: str + :param display_name: The display name for the job + :type display_name: str + """ + job = batch_models.JobAddParameter( + id=job_id, + pool_info=batch_models.PoolInformation(pool_id=pool_id), + display_name=display_name, + **kwargs, + ) + return job + + def create_job(self, job: JobAddParameter) -> None: + """ + Creates a job in the pool + + :param job: The job object to create + :type job: batch_models.JobAddParameter + """ + try: + self.connection.job.add(job) + self.log.info("Job %s created", job.id) + except batch_models.BatchErrorException as err: + if err.error.code != "JobExists": + raise + else: + self.log.info("Job %s already exists", job.id) + + def configure_task( + self, + task_id: str, + command_line: str, + display_name: Optional[str] = None, + container_settings=None, + **kwargs, + ) -> TaskAddParameter: + """ + Creates a task + + :param task_id: A string that identifies the task to create + :type task_id: str + :param command_line: The command line of the Task. + :type command_line: str + :param display_name: A display name for the Task + :type display_name: str + :param container_settings: The settings for the container under which the Task runs. + If the Pool that will run this Task has containerConfiguration set, + this must be set as well. If the Pool that will run this Task doesn't have + containerConfiguration set, this must not be set. 
+ :type container_settings: batch_models.TaskContainerSettings + """ + task = batch_models.TaskAddParameter( + id=task_id, + command_line=command_line, + display_name=display_name, + container_settings=container_settings, + **kwargs, + ) + self.log.info("Task created: %s", task_id) + return task + + def add_single_task_to_job(self, job_id: str, task: TaskAddParameter) -> None: + """ + Add a single task to given job if it doesn't exist + + :param job_id: A string that identifies the given job + :type job_id: str + :param task: The task to add + :type task: batch_models.TaskAddParameter + """ + try: + + self.connection.task.add(job_id=job_id, task=task) + except batch_models.BatchErrorException as err: + if err.error.code != "TaskExists": + raise + else: + self.log.info("Task %s already exists", task.id) + + def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> None: + """ + Wait for tasks in a particular job to complete + + :param job_id: A string that identifies the job + :type job_id: str + :param timeout: The amount of time to wait before timing out in minutes + :type timeout: int + """ + timeout_time = timezone.utcnow() + timedelta(minutes=timeout) + while timezone.utcnow() < timeout_time: + tasks = self.connection.task.list(job_id) + + incomplete_tasks = [ + task for task in tasks if task.state != batch_models.TaskState.completed + ] + if not incomplete_tasks: + return + for task in incomplete_tasks: + self.log.info( + "Waiting for %s to complete, currently on %s state", + task.id, + task.state, + ) + time.sleep(15) + raise TimeoutError("Timed out waiting for tasks to complete") diff --git a/reference/providers/microsoft/azure/hooks/azure_container_instance.py b/reference/providers/microsoft/azure/hooks/azure_container_instance.py new file mode 100644 index 0000000..45d845d --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/azure_container_instance.py @@ -0,0 +1,166 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import warnings +from typing import Any + +from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook +from azure.mgmt.containerinstance import ContainerInstanceManagementClient +from azure.mgmt.containerinstance.models import ContainerGroup + + +class AzureContainerInstanceHook(AzureBaseHook): + """ + A hook to communicate with Azure Container Instances. + + This hook requires a service principal in order to work. + After creating this service principal + (Azure Active Directory/App Registrations), you need to fill in the + client_id (Application ID) as login, the generated password as password, + and tenantId and subscriptionId in the extra's field as a json. 
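+
+    For illustration only, the extra field might then look roughly like the
+    following (both values are placeholders for your own tenant and
+    subscription identifiers):
+
+    .. code-block:: json
+
+        {"tenantId": "<my-tenant-id>", "subscriptionId": "<my-subscription-id>"}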
+ + :param conn_id: connection id of a service principal which will be used + to start the container instance + :type conn_id: str + """ + + conn_name_attr = "conn_id" + default_conn_name = "azure_default" + conn_type = "azure_container_instances" + hook_name = "Azure Container Instance" + + def __init__(self, conn_id: str = default_conn_name) -> None: + super().__init__(sdk_client=ContainerInstanceManagementClient, conn_id=conn_id) + self.connection = self.get_conn() + + def create_or_update( + self, resource_group: str, name: str, container_group: ContainerGroup + ) -> None: + """ + Create a new container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + :param container_group: the properties of the container group + :type container_group: azure.mgmt.containerinstance.models.ContainerGroup + """ + self.connection.container_groups.create_or_update( + resource_group, name, container_group + ) + + def get_state_exitcode_details(self, resource_group: str, name: str) -> tuple: + """ + Get the state and exitcode of a container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + :return: A tuple with the state, exitcode, and details. + If the exitcode is unknown 0 is returned. + :rtype: tuple(state,exitcode,details) + """ + warnings.warn( + "get_state_exitcode_details() is deprecated. Related method is get_state()", + DeprecationWarning, + stacklevel=2, + ) + cg_state = self.get_state(resource_group, name) + c_state = cg_state.containers[0].instance_view.current_state + return (c_state.state, c_state.exit_code, c_state.detail_status) + + def get_messages(self, resource_group: str, name: str) -> list: + """ + Get the messages of a container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + :return: A list of the event messages + :rtype: list[str] + """ + warnings.warn( + "get_messages() is deprecated. 
Related method is get_state()", + DeprecationWarning, + stacklevel=2, + ) + cg_state = self.get_state(resource_group, name) + instance_view = cg_state.containers[0].instance_view + return [event.message for event in instance_view.events] + + def get_state(self, resource_group: str, name: str) -> Any: + """ + Get the state of a container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + :return: ContainerGroup + :rtype: ~azure.mgmt.containerinstance.models.ContainerGroup + """ + return self.connection.container_groups.get(resource_group, name, raw=False) + + def get_logs(self, resource_group: str, name: str, tail: int = 1000) -> list: + """ + Get the tail from logs of a container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + :param tail: the size of the tail + :type tail: int + :return: A list of log messages + :rtype: list[str] + """ + logs = self.connection.container.list_logs( + resource_group, name, name, tail=tail + ) + return logs.content.splitlines(True) + + def delete(self, resource_group: str, name: str) -> None: + """ + Delete a container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + """ + self.connection.container_groups.delete(resource_group, name) + + def exists(self, resource_group: str, name: str) -> bool: + """ + Test if a container group exists + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + """ + for container in self.connection.container_groups.list_by_resource_group( + resource_group + ): + if container.name == name: + return True + return False diff --git a/reference/providers/microsoft/azure/hooks/azure_container_registry.py b/reference/providers/microsoft/azure/hooks/azure_container_registry.py new file mode 100644 index 0000000..9268ea2 --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/azure_container_registry.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Azure Container Registry""" + +from airflow.hooks.base import BaseHook +from azure.mgmt.containerinstance.models import ImageRegistryCredential + + +class AzureContainerRegistryHook(BaseHook): + """ + A hook to communicate with a Azure Container Registry. 
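+
+    A minimal usage sketch, assuming an ``azure_registry`` connection already
+    exists with the registry server stored as host and the credentials as
+    login/password:
+
+    .. code-block:: python
+
+        hook = AzureContainerRegistryHook(conn_id="azure_registry")
+        # Returns an ImageRegistryCredential built from host/login/password
+        registry_credential = hook.get_conn()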
+ + :param conn_id: connection id of a service principal which will be used + to start the container instance + :type conn_id: str + """ + + def __init__(self, conn_id: str = "azure_registry") -> None: + super().__init__() + self.conn_id = conn_id + self.connection = self.get_conn() + + def get_conn(self) -> ImageRegistryCredential: + conn = self.get_connection(self.conn_id) + return ImageRegistryCredential( + server=conn.host, username=conn.login, password=conn.password + ) diff --git a/reference/providers/microsoft/azure/hooks/azure_container_volume.py b/reference/providers/microsoft/azure/hooks/azure_container_volume.py new file mode 100644 index 0000000..d249f1e --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/azure_container_volume.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.hooks.base import BaseHook +from azure.mgmt.containerinstance.models import AzureFileVolume, Volume + + +class AzureContainerVolumeHook(BaseHook): + """ + A hook which wraps an Azure Volume. + + :param wasb_conn_id: connection id of a Azure storage account of + which file shares should be mounted + :type wasb_conn_id: str + """ + + def __init__(self, wasb_conn_id: str = "wasb_default") -> None: + super().__init__() + self.conn_id = wasb_conn_id + + def get_storagekey(self) -> str: + """Get Azure File Volume storage key""" + conn = self.get_connection(self.conn_id) + service_options = conn.extra_dejson + + if "connection_string" in service_options: + for keyvalue in service_options["connection_string"].split(";"): + key, value = keyvalue.split("=", 1) + if key == "AccountKey": + return value + return conn.password + + def get_file_volume( + self, + mount_name: str, + share_name: str, + storage_account_name: str, + read_only: bool = False, + ) -> Volume: + """Get Azure File Volume""" + return Volume( + name=mount_name, + azure_file=AzureFileVolume( + share_name=share_name, + storage_account_name=storage_account_name, + read_only=read_only, + storage_account_key=self.get_storagekey(), + ), + ) diff --git a/reference/providers/microsoft/azure/hooks/azure_cosmos.py b/reference/providers/microsoft/azure/hooks/azure_cosmos.py new file mode 100644 index 0000000..bad3dad --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/azure_cosmos.py @@ -0,0 +1,335 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains integration with Azure CosmosDB. + +AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that a +Airflow connection of type `azure_cosmos` exists. Authorization can be done by supplying a +login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify +the default database and collection to use (see connection `azure_cosmos_default` for an example). +""" +import uuid +from typing import Optional + +from airflow.exceptions import AirflowBadRequest +from airflow.hooks.base import BaseHook +from azure.cosmos.cosmos_client import CosmosClient +from azure.cosmos.errors import HTTPFailure + + +class AzureCosmosDBHook(BaseHook): + """ + Interacts with Azure CosmosDB. + + login should be the endpoint uri, password should be the master key + optionally, you can use the following extras to default these values + {"database_name": "", "collection_name": "COLLECTION_NAME"}. + + :param azure_cosmos_conn_id: Reference to the Azure CosmosDB connection. + :type azure_cosmos_conn_id: str + """ + + conn_name_attr = "azure_cosmos_conn_id" + default_conn_name = "azure_cosmos_default" + conn_type = "azure_cosmos" + hook_name = "Azure CosmosDB" + + def __init__(self, azure_cosmos_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_cosmos_conn_id + self._conn = None + + self.default_database_name = None + self.default_collection_name = None + + def get_conn(self) -> CosmosClient: + """Return a cosmos db client.""" + if not self._conn: + conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson + endpoint_uri = conn.login + master_key = conn.password + + self.default_database_name = extras.get("database_name") + self.default_collection_name = extras.get("collection_name") + + # Initialize the Python Azure Cosmos DB client + self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key}) + return self._conn + + def __get_database_name(self, database_name: Optional[str] = None) -> str: + self.get_conn() + db_name = database_name + if db_name is None: + db_name = self.default_database_name + + if db_name is None: + raise AirflowBadRequest("Database name must be specified") + + return db_name + + def __get_collection_name(self, collection_name: Optional[str] = None) -> str: + self.get_conn() + coll_name = collection_name + if coll_name is None: + coll_name = self.default_collection_name + + if coll_name is None: + raise AirflowBadRequest("Collection name must be specified") + + return coll_name + + def does_collection_exist(self, collection_name: str, database_name: str) -> bool: + """Checks if a collection exists in CosmosDB.""" + if collection_name is None: + raise AirflowBadRequest("Collection name cannot be None.") + + existing_container = list( + self.get_conn().QueryContainers( + get_database_link(self.__get_database_name(database_name)), + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": collection_name}], + }, + ) + ) + if len(existing_container) == 0: + return False + + return True + + def create_collection( + self, 
collection_name: str, database_name: Optional[str] = None + ) -> None: + """Creates a new collection in the CosmosDB database.""" + if collection_name is None: + raise AirflowBadRequest("Collection name cannot be None.") + + # We need to check to see if this container already exists so we don't try + # to create it twice + existing_container = list( + self.get_conn().QueryContainers( + get_database_link(self.__get_database_name(database_name)), + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": collection_name}], + }, + ) + ) + + # Only create if we did not find it already existing + if len(existing_container) == 0: + self.get_conn().CreateContainer( + get_database_link(self.__get_database_name(database_name)), + {"id": collection_name}, + ) + + def does_database_exist(self, database_name: str) -> bool: + """Checks if a database exists in CosmosDB.""" + if database_name is None: + raise AirflowBadRequest("Database name cannot be None.") + + existing_database = list( + self.get_conn().QueryDatabases( + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": database_name}], + } + ) + ) + if len(existing_database) == 0: + return False + + return True + + def create_database(self, database_name: str) -> None: + """Creates a new database in CosmosDB.""" + if database_name is None: + raise AirflowBadRequest("Database name cannot be None.") + + # We need to check to see if this database already exists so we don't try + # to create it twice + existing_database = list( + self.get_conn().QueryDatabases( + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": database_name}], + } + ) + ) + + # Only create if we did not find it already existing + if len(existing_database) == 0: + self.get_conn().CreateDatabase({"id": database_name}) + + def delete_database(self, database_name: str) -> None: + """Deletes an existing database in CosmosDB.""" + if database_name is None: + raise AirflowBadRequest("Database name cannot be None.") + + self.get_conn().DeleteDatabase(get_database_link(database_name)) + + def delete_collection( + self, collection_name: str, database_name: Optional[str] = None + ) -> None: + """Deletes an existing collection in the CosmosDB database.""" + if collection_name is None: + raise AirflowBadRequest("Collection name cannot be None.") + + self.get_conn().DeleteContainer( + get_collection_link( + self.__get_database_name(database_name), collection_name + ) + ) + + def upsert_document( + self, document, database_name=None, collection_name=None, document_id=None + ): + """ + Inserts a new document (or updates an existing one) into an existing + collection in the CosmosDB database. 
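+
+        A rough sketch of a call, assuming the ``azure_cosmos_default``
+        connection is configured and the database/collection names below are
+        placeholders:
+
+        .. code-block:: python
+
+            hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_default")
+            hook.upsert_document(
+                {"id": "run-2022-07-28", "state": "success"},  # document to insert or update
+                database_name="mydatabase",
+                collection_name="mycollection",
+            )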
+ """ + # Assign unique ID if one isn't provided + if document_id is None: + document_id = str(uuid.uuid4()) + + if document is None: + raise AirflowBadRequest("You cannot insert a None document") + + # Add document id if isn't found + if "id" in document: + if document["id"] is None: + document["id"] = document_id + else: + document["id"] = document_id + + created_document = self.get_conn().CreateItem( + get_collection_link( + self.__get_database_name(database_name), + self.__get_collection_name(collection_name), + ), + document, + ) + + return created_document + + def insert_documents( + self, + documents, + database_name: Optional[str] = None, + collection_name: Optional[str] = None, + ) -> list: + """Insert a list of new documents into an existing collection in the CosmosDB database.""" + if documents is None: + raise AirflowBadRequest("You cannot insert empty documents") + + created_documents = [] + for single_document in documents: + created_documents.append( + self.get_conn().CreateItem( + get_collection_link( + self.__get_database_name(database_name), + self.__get_collection_name(collection_name), + ), + single_document, + ) + ) + + return created_documents + + def delete_document( + self, + document_id: str, + database_name: Optional[str] = None, + collection_name: Optional[str] = None, + ) -> None: + """Delete an existing document out of a collection in the CosmosDB database.""" + if document_id is None: + raise AirflowBadRequest("Cannot delete a document without an id") + + self.get_conn().DeleteItem( + get_document_link( + self.__get_database_name(database_name), + self.__get_collection_name(collection_name), + document_id, + ) + ) + + def get_document( + self, + document_id: str, + database_name: Optional[str] = None, + collection_name: Optional[str] = None, + ): + """Get a document from an existing collection in the CosmosDB database.""" + if document_id is None: + raise AirflowBadRequest("Cannot get a document without an id") + + try: + return self.get_conn().ReadItem( + get_document_link( + self.__get_database_name(database_name), + self.__get_collection_name(collection_name), + document_id, + ) + ) + except HTTPFailure: + return None + + def get_documents( + self, + sql_string: str, + database_name: Optional[str] = None, + collection_name: Optional[str] = None, + partition_key: Optional[str] = None, + ) -> Optional[list]: + """Get a list of documents from an existing collection in the CosmosDB database via SQL query.""" + if sql_string is None: + raise AirflowBadRequest("SQL query string cannot be None") + + # Query them in SQL + query = {"query": sql_string} + + try: + result_iterable = self.get_conn().QueryItems( + get_collection_link( + self.__get_database_name(database_name), + self.__get_collection_name(collection_name), + ), + query, + partition_key, + ) + + return list(result_iterable) + except HTTPFailure: + return None + + +def get_database_link(database_id: str) -> str: + """Get Azure CosmosDB database link""" + return "dbs/" + database_id + + +def get_collection_link(database_id: str, collection_id: str) -> str: + """Get Azure CosmosDB collection link""" + return get_database_link(database_id) + "/colls/" + collection_id + + +def get_document_link(database_id: str, collection_id: str, document_id: str) -> str: + """Get Azure CosmosDB document link""" + return get_collection_link(database_id, collection_id) + "/docs/" + document_id diff --git a/reference/providers/microsoft/azure/hooks/azure_data_factory.py 
b/reference/providers/microsoft/azure/hooks/azure_data_factory.py new file mode 100644 index 0000000..a5205f5 --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/azure_data_factory.py @@ -0,0 +1,794 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import inspect +from functools import wraps +from typing import Any, Callable, Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from azure.core.polling import LROPoller +from azure.identity import ClientSecretCredential +from azure.mgmt.datafactory import DataFactoryManagementClient +from azure.mgmt.datafactory.models import ( + CreateRunResponse, + Dataset, + DatasetResource, + Factory, + LinkedService, + LinkedServiceResource, + PipelineResource, + PipelineRun, + Trigger, + TriggerResource, +) + + +def provide_targeted_factory(func: Callable) -> Callable: + """ + Provide the targeted factory to the decorated function in case it isn't specified. + + If ``resource_group_name`` or ``factory_name`` is not provided it defaults to the value specified in + the connection extras. + """ + signature = inspect.signature(func) + + @wraps(func) + def wrapper(*args, **kwargs) -> Callable: + bound_args = signature.bind(*args, **kwargs) + + def bind_argument(arg, default_key): + if arg not in bound_args.arguments: + self = args[0] + conn = self.get_connection(self.conn_id) + default_value = conn.extra_dejson.get(default_key) + + if not default_value: + raise AirflowException( + "Could not determine the targeted data factory." + ) + + bound_args.arguments[arg] = conn.extra_dejson[default_key] + + bind_argument("resource_group_name", "resourceGroup") + bind_argument("factory_name", "factory") + + return func(*bound_args.args, **bound_args.kwargs) + + return wrapper + + +class AzureDataFactoryHook(BaseHook): # pylint: disable=too-many-public-methods + """ + A hook to interact with Azure Data Factory. + + :param conn_id: The Azure Data Factory connection id. 
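+
+    The service principal's client id and secret are read from the connection's
+    login and password, while ``tenantId`` and ``subscriptionId`` come from the
+    extra field; ``resourceGroup`` and ``factory`` may also be set there as
+    defaults for the targeted factory. A sketch of the extra field, with
+    placeholder values:
+
+    .. code-block:: json
+
+        {
+            "tenantId": "<tenant-id>",
+            "subscriptionId": "<subscription-id>",
+            "resourceGroup": "<default-resource-group>",
+            "factory": "<default-factory-name>"
+        }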
+ """ + + conn_type: str = "azure_data_factory" + conn_name_attr: str = "azure_data_factory_conn_id" + default_conn_name: str = "azure_data_factory_default" + hook_name: str = "Azure Data Factory" + + def __init__(self, conn_id: Optional[str] = default_conn_name): + self._conn: DataFactoryManagementClient = None + self.conn_id = conn_id + super().__init__() + + def get_conn(self) -> DataFactoryManagementClient: + if self._conn is not None: + return self._conn + + conn = self.get_connection(self.conn_id) + + self._conn = DataFactoryManagementClient( + credential=ClientSecretCredential( + client_id=conn.login, + client_secret=conn.password, + tenant_id=conn.extra_dejson.get("tenantId"), + ), + subscription_id=conn.extra_dejson.get("subscriptionId"), + ) + + return self._conn + + @provide_targeted_factory + def get_factory( + self, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> Factory: + """ + Get the factory. + + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The factory. + """ + return self.get_conn().factories.get( + resource_group_name, factory_name, **config + ) + + def _factory_exists(self, resource_group_name, factory_name) -> bool: + """Return whether or not the factory already exists.""" + factories = { + factory.name + for factory in self.get_conn().factories.list_by_resource_group( + resource_group_name + ) + } + + return factory_name in factories + + @provide_targeted_factory + def update_factory( + self, + factory: Factory, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> Factory: + """ + Update the factory. + + :param factory: The factory resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the factory does not exist. + :return: The factory. + """ + if not self._factory_exists(resource_group_name, factory_name): + raise AirflowException(f"Factory {factory!r} does not exist.") + + return self.get_conn().factories.create_or_update( + resource_group_name, factory_name, factory, **config + ) + + @provide_targeted_factory + def create_factory( + self, + factory: Factory, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> Factory: + """ + Create the factory. + + :param factory: The factory resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the factory already exists. + :return: The factory. + """ + if self._factory_exists(resource_group_name, factory_name): + raise AirflowException(f"Factory {factory!r} already exists.") + + return self.get_conn().factories.create_or_update( + resource_group_name, factory_name, factory, **config + ) + + @provide_targeted_factory + def delete_factory( + self, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Delete the factory. + + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. 
+ """ + self.get_conn().factories.delete(resource_group_name, factory_name, **config) + + @provide_targeted_factory + def get_linked_service( + self, + linked_service_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> LinkedServiceRe# + """ + Get the linked service. + + :param linked_service_name: The linked service name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The linked service. + """ + return self.get_conn().linked_services.get( + resource_group_name, factory_name, linked_service_name, **config + ) + + def _linked_service_exists( + self, resource_group_name, factory_name, linked_service_name + ) -> bool: + """Return whether or not the linked service already exists.""" + linked_services = { + linked_service.name + for linked_service in self.get_conn().linked_services.list_by_factory( + resource_group_name, factory_name + ) + } + + return linked_service_name in linked_services + + @provide_targeted_factory + def update_linked_service( + self, + linked_service_name: str, + linked_service: LinkedService, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> LinkedServiceRe# + """ + Update the linked service. + + :param linked_service_name: The linked service name. + :param linked_service: The linked service resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the linked service does not exist. + :return: The linked service. + """ + if not self._linked_service_exists( + resource_group_name, factory_name, linked_service_name + ): + raise AirflowException( + f"Linked service {linked_service_name!r} does not exist." + ) + + return self.get_conn().linked_services.create_or_update( + resource_group_name, + factory_name, + linked_service_name, + linked_service, + **config, + ) + + @provide_targeted_factory + def create_linked_service( + self, + linked_service_name: str, + linked_service: LinkedService, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> LinkedServiceRe# + """ + Create the linked service. + + :param linked_service_name: The linked service name. + :param linked_service: The linked service resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the linked service already exists. + :return: The linked service. + """ + if self._linked_service_exists( + resource_group_name, factory_name, linked_service_name + ): + raise AirflowException( + f"Linked service {linked_service_name!r} already exists." + ) + + return self.get_conn().linked_services.create_or_update( + resource_group_name, + factory_name, + linked_service_name, + linked_service, + **config, + ) + + @provide_targeted_factory + def delete_linked_service( + self, + linked_service_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Delete the linked service. + + :param linked_service_name: The linked service name. + :param resource_group_name: The linked service name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. 
+ """ + self.get_conn().linked_services.delete( + resource_group_name, factory_name, linked_service_name, **config + ) + + @provide_targeted_factory + def get_dataset( + self, + dataset_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> DatasetRe# + """ + Get the dataset. + + :param dataset_name: The dataset name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The dataset. + """ + return self.get_conn().datasets.get( + resource_group_name, factory_name, dataset_name, **config + ) + + def _dataset_exists(self, resource_group_name, factory_name, dataset_name) -> bool: + """Return whether or not the dataset already exists.""" + datasets = { + dataset.name + for dataset in self.get_conn().datasets.list_by_factory( + resource_group_name, factory_name + ) + } + + return dataset_name in datasets + + @provide_targeted_factory + def update_dataset( + self, + dataset_name: str, + dataset: Dataset, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> DatasetRe# + """ + Update the dataset. + + :param dataset_name: The dataset name. + :param dataset: The dataset resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the dataset does not exist. + :return: The dataset. + """ + if not self._dataset_exists(resource_group_name, factory_name, dataset_name): + raise AirflowException(f"Dataset {dataset_name!r} does not exist.") + + return self.get_conn().datasets.create_or_update( + resource_group_name, factory_name, dataset_name, dataset, **config + ) + + @provide_targeted_factory + def create_dataset( + self, + dataset_name: str, + dataset: Dataset, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> DatasetRe# + """ + Create the dataset. + + :param dataset_name: The dataset name. + :param dataset: The dataset resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the dataset already exists. + :return: The dataset. + """ + if self._dataset_exists(resource_group_name, factory_name, dataset_name): + raise AirflowException(f"Dataset {dataset_name!r} already exists.") + + return self.get_conn().datasets.create_or_update( + resource_group_name, factory_name, dataset_name, dataset, **config + ) + + @provide_targeted_factory + def delete_dataset( + self, + dataset_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Delete the dataset. + + :param dataset_name: The dataset name. + :param resource_group_name: The dataset name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + self.get_conn().datasets.delete( + resource_group_name, factory_name, dataset_name, **config + ) + + @provide_targeted_factory + def get_pipeline( + self, + pipeline_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> PipelineRe# + """ + Get the pipeline. + + :param pipeline_name: The pipeline name. + :param resource_group_name: The resource group name. 
+ :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The pipeline. + """ + return self.get_conn().pipelines.get( + resource_group_name, factory_name, pipeline_name, **config + ) + + def _pipeline_exists( + self, resource_group_name, factory_name, pipeline_name + ) -> bool: + """Return whether or not the pipeline already exists.""" + pipelines = { + pipeline.name + for pipeline in self.get_conn().pipelines.list_by_factory( + resource_group_name, factory_name + ) + } + + return pipeline_name in pipelines + + @provide_targeted_factory + def update_pipeline( + self, + pipeline_name: str, + pipeline: PipelineResource, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> PipelineRe# + """ + Update the pipeline. + + :param pipeline_name: The pipeline name. + :param pipeline: The pipeline resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the pipeline does not exist. + :return: The pipeline. + """ + if not self._pipeline_exists(resource_group_name, factory_name, pipeline_name): + raise AirflowException(f"Pipeline {pipeline_name!r} does not exist.") + + return self.get_conn().pipelines.create_or_update( + resource_group_name, factory_name, pipeline_name, pipeline, **config + ) + + @provide_targeted_factory + def create_pipeline( + self, + pipeline_name: str, + pipeline: PipelineResource, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> PipelineRe# + """ + Create the pipeline. + + :param pipeline_name: The pipeline name. + :param pipeline: The pipeline resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the pipeline already exists. + :return: The pipeline. + """ + if self._pipeline_exists(resource_group_name, factory_name, pipeline_name): + raise AirflowException(f"Pipeline {pipeline_name!r} already exists.") + + return self.get_conn().pipelines.create_or_update( + resource_group_name, factory_name, pipeline_name, pipeline, **config + ) + + @provide_targeted_factory + def delete_pipeline( + self, + pipeline_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Delete the pipeline. + + :param pipeline_name: The pipeline name. + :param resource_group_name: The pipeline name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + self.get_conn().pipelines.delete( + resource_group_name, factory_name, pipeline_name, **config + ) + + @provide_targeted_factory + def run_pipeline( + self, + pipeline_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> CreateRunResponse: + """ + Run a pipeline. + + :param pipeline_name: The pipeline name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The pipeline run. 
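+
+        A hedged sketch of triggering a run and then polling its status
+        (the pipeline name is a placeholder, and the resource group/factory are
+        assumed to be set in the connection extras):
+
+        .. code-block:: python
+
+            hook = AzureDataFactoryHook(conn_id="azure_data_factory_default")
+            run = hook.run_pipeline("my_pipeline")
+            # CreateRunResponse exposes the run identifier for follow-up calls
+            status = hook.get_pipeline_run(run.run_id).status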
+ """ + return self.get_conn().pipelines.create_run( + resource_group_name, factory_name, pipeline_name, **config + ) + + @provide_targeted_factory + def get_pipeline_run( + self, + run_id: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> PipelineRun: + """ + Get the pipeline run. + + :param run_id: The pipeline run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The pipeline run. + """ + return self.get_conn().pipeline_runs.get( + resource_group_name, factory_name, run_id, **config + ) + + @provide_targeted_factory + def cancel_pipeline_run( + self, + run_id: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Cancel the pipeline run. + + :param run_id: The pipeline run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + self.get_conn().pipeline_runs.cancel( + resource_group_name, factory_name, run_id, **config + ) + + @provide_targeted_factory + def get_trigger( + self, + trigger_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> TriggerRe# + """ + Get the trigger. + + :param trigger_name: The trigger name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The trigger. + """ + return self.get_conn().triggers.get( + resource_group_name, factory_name, trigger_name, **config + ) + + def _trigger_exists(self, resource_group_name, factory_name, trigger_name) -> bool: + """Return whether or not the trigger already exists.""" + triggers = { + trigger.name + for trigger in self.get_conn().triggers.list_by_factory( + resource_group_name, factory_name + ) + } + + return trigger_name in triggers + + @provide_targeted_factory + def update_trigger( + self, + trigger_name: str, + trigger: Trigger, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> TriggerRe# + """ + Update the trigger. + + :param trigger_name: The trigger name. + :param trigger: The trigger resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the trigger does not exist. + :return: The trigger. + """ + if not self._trigger_exists(resource_group_name, factory_name, trigger_name): + raise AirflowException(f"Trigger {trigger_name!r} does not exist.") + + return self.get_conn().triggers.create_or_update( + resource_group_name, factory_name, trigger_name, trigger, **config + ) + + @provide_targeted_factory + def create_trigger( + self, + trigger_name: str, + trigger: Trigger, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> TriggerRe# + """ + Create the trigger. + + :param trigger_name: The trigger name. + :param trigger: The trigger resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the trigger already exists. + :return: The trigger. 
+ """ + if self._trigger_exists(resource_group_name, factory_name, trigger_name): + raise AirflowException(f"Trigger {trigger_name!r} already exists.") + + return self.get_conn().triggers.create_or_update( + resource_group_name, factory_name, trigger_name, trigger, **config + ) + + @provide_targeted_factory + def delete_trigger( + self, + trigger_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Delete the trigger. + + :param trigger_name: The trigger name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + self.get_conn().triggers.delete( + resource_group_name, factory_name, trigger_name, **config + ) + + @provide_targeted_factory + def start_trigger( + self, + trigger_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> LROPoller: + """ + Start the trigger. + + :param trigger_name: The trigger name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: An Azure operation poller. + """ + return self.get_conn().triggers.begin_start( + resource_group_name, factory_name, trigger_name, **config + ) + + @provide_targeted_factory + def stop_trigger( + self, + trigger_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> LROPoller: + """ + Stop the trigger. + + :param trigger_name: The trigger name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: An Azure operation poller. + """ + return self.get_conn().triggers.begin_stop( + resource_group_name, factory_name, trigger_name, **config + ) + + @provide_targeted_factory + def rerun_trigger( + self, + trigger_name: str, + run_id: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Rerun the trigger. + + :param trigger_name: The trigger name. + :param run_id: The trigger run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + return self.get_conn().trigger_runs.rerun( + resource_group_name, factory_name, trigger_name, run_id, **config + ) + + @provide_targeted_factory + def cancel_trigger( + self, + trigger_name: str, + run_id: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Cancel the trigger. + + :param trigger_name: The trigger name. + :param run_id: The trigger run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + self.get_conn().trigger_runs.cancel( + resource_group_name, factory_name, trigger_name, run_id, **config + ) diff --git a/reference/providers/microsoft/azure/hooks/azure_data_lake.py b/reference/providers/microsoft/azure/hooks/azure_data_lake.py new file mode 100644 index 0000000..796be12 --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/azure_data_lake.py @@ -0,0 +1,216 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +This module contains integration with Azure Data Lake. + +AzureDataLakeHook communicates via a REST API compatible with WebHDFS. Make sure that a +Airflow connection of type `azure_data_lake` exists. Authorization can be done by supplying a +login (=Client ID), password (=Client Secret) and extra fields tenant (Tenant) and account_name (Account Name) +(see connection `azure_data_lake_default` for an example). +""" +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from azure.datalake.store import core, lib, multithread + + +class AzureDataLakeHook(BaseHook): + """ + Interacts with Azure Data Lake. + + Client ID and client secret should be in user and password parameters. + Tenant and account name should be extra field as + {"tenant": "", "account_name": "ACCOUNT_NAME"}. + + :param azure_data_lake_conn_id: Reference to the Azure Data Lake connection. + :type azure_data_lake_conn_id: str + """ + + conn_name_attr = "azure_data_lake_conn_id" + default_conn_name = "azure_data_lake_default" + conn_type = "azure_data_lake" + hook_name = "Azure Data Lake" + + def __init__(self, azure_data_lake_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_data_lake_conn_id + self._conn: Optional[core.AzureDLFileSystem] = None + self.account_name: Optional[str] = None + + def get_conn(self) -> core.AzureDLFileSystem: + """Return a AzureDLFileSystem object.""" + if not self._conn: + conn = self.get_connection(self.conn_id) + service_options = conn.extra_dejson + self.account_name = service_options.get("account_name") + + adl_creds = lib.auth( + tenant_id=service_options.get("tenant"), + client_secret=conn.password, + client_id=conn.login, + ) + self._conn = core.AzureDLFileSystem(adl_creds, store_name=self.account_name) + self._conn.connect() + return self._conn + + def check_for_file(self, file_path: str) -> bool: + """ + Check if a file exists on Azure Data Lake. + + :param file_path: Path and name of the file. + :type file_path: str + :return: True if the file exists, False otherwise. + :rtype: bool + """ + try: + files = self.get_conn().glob( + file_path, details=False, invalidate_cache=True + ) + return len(files) == 1 + except FileNotFoundError: + return False + + def upload_file( + self, + local_path: str, + remote_path: str, + nthreads: int = 64, + overwrite: bool = True, + buffersize: int = 4194304, + blocksize: int = 4194304, + **kwargs, + ) -> None: + """ + Upload a file to Azure Data Lake. + + :param local_path: local path. Can be single file, directory (in which case, + upload recursively) or glob pattern. Recursive glob patterns using `**` + are not supported. + :type local_path: str + :param remote_path: Remote path to upload to; if multiple files, this is the + directory root to write within. 
+ :type remote_path: str + :param nthreads: Number of threads to use. If None, uses the number of cores. + :type nthreads: int + :param overwrite: Whether to forcibly overwrite existing files/directories. + If False and remote path is a directory, will quit regardless if any files + would be overwritten or not. If True, only matching filenames are actually + overwritten. + :type overwrite: bool + :param buffersize: int [2**22] + Number of bytes for internal buffer. This block cannot be bigger than + a chunk and cannot be smaller than a block. + :type buffersize: int + :param blocksize: int [2**22] + Number of bytes for a block. Within each chunk, we write a smaller + block for each API call. This block cannot be bigger than a chunk. + :type blocksize: int + """ + multithread.ADLUploader( + self.get_conn(), + lpath=local_path, + rpath=remote_path, + nthreads=nthreads, + overwrite=overwrite, + buffersize=buffersize, + blocksize=blocksize, + **kwargs, + ) + + def download_file( + self, + local_path: str, + remote_path: str, + nthreads: int = 64, + overwrite: bool = True, + buffersize: int = 4194304, + blocksize: int = 4194304, + **kwargs, + ) -> None: + """ + Download a file from Azure Blob Storage. + + :param local_path: local path. If downloading a single file, will write to this + specific file, unless it is an existing directory, in which case a file is + created within it. If downloading multiple files, this is the root + directory to write within. Will create directories as required. + :type local_path: str + :param remote_path: remote path/globstring to use to find remote files. + Recursive glob patterns using `**` are not supported. + :type remote_path: str + :param nthreads: Number of threads to use. If None, uses the number of cores. + :type nthreads: int + :param overwrite: Whether to forcibly overwrite existing files/directories. + If False and remote path is a directory, will quit regardless if any files + would be overwritten or not. If True, only matching filenames are actually + overwritten. + :type overwrite: bool + :param buffersize: int [2**22] + Number of bytes for internal buffer. This block cannot be bigger than + a chunk and cannot be smaller than a block. + :type buffersize: int + :param blocksize: int [2**22] + Number of bytes for a block. Within each chunk, we write a smaller + block for each API call. This block cannot be bigger than a chunk. 
+ :type blocksize: int + """ + multithread.ADLDownloader( + self.get_conn(), + lpath=local_path, + rpath=remote_path, + nthreads=nthreads, + overwrite=overwrite, + buffersize=buffersize, + blocksize=blocksize, + **kwargs, + ) + + def list(self, path: str) -> list: + """ + List files in Azure Data Lake Storage + + :param path: full path/globstring to use to list files in ADLS + :type path: str + """ + if "*" in path: + return self.get_conn().glob(path) + else: + return self.get_conn().walk(path) + + def remove( + self, path: str, recursive: bool = False, ignore_not_found: bool = True + ) -> None: + """ + Remove files in Azure Data Lake Storage + + :param path: A directory or file to remove in ADLS + :type path: str + :param recursive: Whether to loop into directories in the location and remove the files + :type recursive: bool + :param ignore_not_found: Whether to raise error if file to delete is not found + :type ignore_not_found: bool + """ + try: + self.get_conn().remove(path=path, recursive=recursive) + except FileNotFoundError: + if ignore_not_found: + self.log.info("File %s not found", path) + else: + raise AirflowException(f"File {path} not found") diff --git a/reference/providers/microsoft/azure/hooks/azure_fileshare.py b/reference/providers/microsoft/azure/hooks/azure_fileshare.py new file mode 100644 index 0000000..fd34693 --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/azure_fileshare.py @@ -0,0 +1,313 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import List, Optional + +from airflow.hooks.base import BaseHook +from azure.storage.file import File, FileService + + +class AzureFileShareHook(BaseHook): + """ + Interacts with Azure FileShare Storage. + + Additional options passed in the 'extra' field of the connection will be + passed to the `FileService()` constructor. + + :param wasb_conn_id: Reference to the wasb connection. + :type wasb_conn_id: str + """ + + def __init__(self, wasb_conn_id: str = "wasb_default") -> None: + super().__init__() + self.conn_id = wasb_conn_id + self._conn = None + + def get_conn(self) -> FileService: + """Return the FileService object.""" + if not self._conn: + conn = self.get_connection(self.conn_id) + service_options = conn.extra_dejson + self._conn = FileService( + account_name=conn.login, account_key=conn.password, **service_options + ) + return self._conn + + def check_for_directory( + self, share_name: str, directory_name: str, **kwargs + ) -> bool: + """ + Check if a directory exists on Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param kwargs: Optional keyword arguments that + `FileService.exists()` takes. 
+ :type kwargs: object + :return: True if the file exists, False otherwise. + :rtype: bool + """ + return self.get_conn().exists(share_name, directory_name, **kwargs) + + def check_for_file( + self, share_name: str, directory_name: str, file_name: str, **kwargs + ) -> bool: + """ + Check if a file exists on Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param kwargs: Optional keyword arguments that + `FileService.exists()` takes. + :type kwargs: object + :return: True if the file exists, False otherwise. + :rtype: bool + """ + return self.get_conn().exists(share_name, directory_name, file_name, **kwargs) + + def list_directories_and_files( + self, share_name: str, directory_name: Optional[str] = None, **kwargs + ) -> list: + """ + Return the list of directories and files stored on a Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param kwargs: Optional keyword arguments that + `FileService.list_directories_and_files()` takes. + :type kwargs: object + :return: A list of files and directories + :rtype: list + """ + return self.get_conn().list_directories_and_files( + share_name, directory_name, **kwargs + ) + + def list_files( + self, share_name: str, directory_name: Optional[str] = None, **kwargs + ) -> List[str]: + """ + Return the list of files stored on a Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param kwargs: Optional keyword arguments that + `FileService.list_directories_and_files()` takes. + :type kwargs: object + :return: A list of files + :rtype: list + """ + return [ + obj.name + for obj in self.list_directories_and_files( + share_name, directory_name, **kwargs + ) + if isinstance(obj, File) + ] + + def create_share(self, share_name: str, **kwargs) -> bool: + """ + Create new Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param kwargs: Optional keyword arguments that + `FileService.create_share()` takes. + :type kwargs: object + :return: True if share is created, False if share already exists. + :rtype: bool + """ + return self.get_conn().create_share(share_name, **kwargs) + + def delete_share(self, share_name: str, **kwargs) -> bool: + """ + Delete existing Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param kwargs: Optional keyword arguments that + `FileService.delete_share()` takes. + :type kwargs: object + :return: True if share is deleted, False if share does not exist. + :rtype: bool + """ + return self.get_conn().delete_share(share_name, **kwargs) + + def create_directory(self, share_name: str, directory_name: str, **kwargs) -> list: + """ + Create a new directory on a Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param kwargs: Optional keyword arguments that + `FileService.create_directory()` takes. 
+ :type kwargs: object + :return: A list of files and directories + :rtype: list + """ + return self.get_conn().create_directory(share_name, directory_name, **kwargs) + + def get_file( + self, + file_path: str, + share_name: str, + directory_name: str, + file_name: str, + **kwargs + ) -> None: + """ + Download a file from Azure File Share. + + :param file_path: Where to store the file. + :type file_path: str + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param kwargs: Optional keyword arguments that + `FileService.get_file_to_path()` takes. + :type kwargs: object + """ + self.get_conn().get_file_to_path( + share_name, directory_name, file_name, file_path, **kwargs + ) + + def get_file_to_stream( + self, + stream: str, + share_name: str, + directory_name: str, + file_name: str, + **kwargs + ) -> None: + """ + Download a file from Azure File Share. + + :param stream: A filehandle to store the file to. + :type stream: file-like object + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param kwargs: Optional keyword arguments that + `FileService.get_file_to_stream()` takes. + :type kwargs: object + """ + self.get_conn().get_file_to_stream( + share_name, directory_name, file_name, stream, **kwargs + ) + + def load_file( + self, + file_path: str, + share_name: str, + directory_name: str, + file_name: str, + **kwargs + ) -> None: + """ + Upload a file to Azure File Share. + + :param file_path: Path to the file to load. + :type file_path: str + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param kwargs: Optional keyword arguments that + `FileService.create_file_from_path()` takes. + :type kwargs: object + """ + self.get_conn().create_file_from_path( + share_name, directory_name, file_name, file_path, **kwargs + ) + + def load_string( + self, + string_data: str, + share_name: str, + directory_name: str, + file_name: str, + **kwargs + ) -> None: + """ + Upload a string to Azure File Share. + + :param string_data: String to load. + :type string_data: str + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param kwargs: Optional keyword arguments that + `FileService.create_file_from_text()` takes. + :type kwargs: object + """ + self.get_conn().create_file_from_text( + share_name, directory_name, file_name, string_data, **kwargs + ) + + def load_stream( + self, + stream: str, + share_name: str, + directory_name: str, + file_name: str, + count: str, + **kwargs + ) -> None: + """ + Upload a stream to Azure File Share. + + :param stream: Opened file/stream to upload as the file content. + :type stream: file-like + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param count: Size of the stream in bytes + :type count: int + :param kwargs: Optional keyword arguments that + `FileService.create_file_from_stream()` takes. 
+ :type kwargs: object + """ + self.get_conn().create_file_from_stream( + share_name, directory_name, file_name, stream, count, **kwargs + ) diff --git a/reference/providers/microsoft/azure/hooks/base_azure.py b/reference/providers/microsoft/azure/hooks/base_azure.py new file mode 100644 index 0000000..b4e88d9 --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/base_azure.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from azure.common.client_factory import ( + get_client_from_auth_file, + get_client_from_json_dict, +) +from azure.common.credentials import ServicePrincipalCredentials + + +class AzureBaseHook(BaseHook): + """ + This hook acts as a base hook for azure services. It offers several authentication mechanisms to + authenticate the client library used for upstream azure hooks. + + :param sdk_client: The SDKClient to use. + :param conn_id: The azure connection id which refers to the information to connect to the service. + """ + + conn_name_attr = "conn_id" + default_conn_name = "azure_default" + conn_type = "azure" + hook_name = "Azure" + + def __init__(self, sdk_client: Any, conn_id: str = "azure_default"): + self.sdk_client = sdk_client + self.conn_id = conn_id + super().__init__() + + def get_conn(self) -> Any: + """ + Authenticates the resource using the connection id passed during init. + + :return: the authenticated client. + """ + conn = self.get_connection(self.conn_id) + + key_path = conn.extra_dejson.get("key_path") + if key_path: + if not key_path.endswith(".json"): + raise AirflowException("Unrecognised extension for key file.") + self.log.info("Getting connection using a JSON key file.") + return get_client_from_auth_file( + client_class=self.sdk_client, auth_path=key_path + ) + + key_json = conn.extra_dejson.get("key_json") + if key_json: + self.log.info("Getting connection using a JSON config.") + return get_client_from_json_dict( + client_class=self.sdk_client, config_dict=key_json + ) + + self.log.info( + "Getting connection using specific credentials and subscription_id." + ) + return self.sdk_client( + credentials=ServicePrincipalCredentials( + client_id=conn.login, + secret=conn.password, + tenant=conn.extra_dejson.get("tenantId"), + ), + subscription_id=conn.extra_dejson.get("subscriptionId"), + ) diff --git a/reference/providers/microsoft/azure/hooks/wasb.py b/reference/providers/microsoft/azure/hooks/wasb.py new file mode 100644 index 0000000..ebffd89 --- /dev/null +++ b/reference/providers/microsoft/azure/hooks/wasb.py @@ -0,0 +1,410 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""
+This module contains integration with Azure Blob Storage.
+
+It communicates via the Windows Azure Storage Blob protocol. Make sure that an
+Airflow connection of type `wasb` exists. Authorization can be done by supplying a
+login (=storage account name) and password (=storage account key), or a login and SAS
+token in the extra field (see connection `wasb_default` for an example).
+
+"""
+
+from typing import Any, Dict, List, Optional
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
+from azure.identity import ClientSecretCredential
+from azure.storage.blob import (
+    BlobClient,
+    BlobServiceClient,
+    ContainerClient,
+    StorageStreamDownloader,
+)
+
+
+class WasbHook(BaseHook):
+    """
+    Interacts with Azure Blob Storage through the ``wasb://`` protocol.
+
+    The storage account name and account key should be stored in the Airflow
+    connection as login and password, respectively.
+
+    Additional options passed in the 'extra' field of the connection will be
+    passed to the `BlobServiceClient()` constructor. For example, authenticate
+    using a SAS token by adding {"sas_token": "YOUR_TOKEN"}.
+
+    :param wasb_conn_id: Reference to the wasb connection.
+    :type wasb_conn_id: str
+    :param public_read: Whether an anonymous public read access should be used.
default is False + :type public_read: bool + """ + + conn_name_attr = "wasb_conn_id" + default_conn_name = "wasb_default" + conn_type = "wasb" + hook_name = "Azure Blob Storage" + + def __init__( + self, wasb_conn_id: str = default_conn_name, public_read: bool = False + ) -> None: + super().__init__() + self.conn_id = wasb_conn_id + self.public_read = public_read + self.connection = self.get_conn() + + def get_conn( + self, + ) -> BlobServiceClient: # pylint: disable=too-many-return-statements + """Return the BlobServiceClient object.""" + conn = self.get_connection(self.conn_id) + extra = conn.extra_dejson or {} + + if self.public_read: + # Here we use anonymous public read + # more info + # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources + return BlobServiceClient(account_url=conn.host) + + if extra.get("connection_string"): + # connection_string auth takes priority + return BlobServiceClient.from_connection_string( + extra.get("connection_string") + ) + if extra.get("shared_access_key"): + # using shared access key + return BlobServiceClient( + account_url=conn.host, credential=extra.get("shared_access_key") + ) + if extra.get("tenant_id"): + # use Active Directory auth + app_id = conn.login + app_secret = conn.password + token_credential = ClientSecretCredential( + extra.get("tenant_id"), app_id, app_secret + ) + return BlobServiceClient(account_url=conn.host, credential=token_credential) + sas_token = extra.get("sas_token") + if sas_token and sas_token.startswith("https"): + return BlobServiceClient(account_url=extra.get("sas_token")) + if sas_token and not sas_token.startswith("https"): + return BlobServiceClient( + account_url=f"https://{conn.login}.blob.core.windows.net/" + sas_token + ) + else: + # Fall back to old auth + return BlobServiceClient( + account_url=f"https://{conn.login}.blob.core.windows.net/", + credential=conn.password, + **extra, + ) + + def _get_container_client(self, container_name: str) -> ContainerClient: + """ + Instantiates a container client + + :param container_name: The name of the container + :type container_name: str + :return: ContainerClient + """ + return self.connection.get_container_client(container_name) + + def _get_blob_client(self, container_name: str, blob_name: str) -> BlobClient: + """ + Instantiates a blob client + + :param container_name: The name of the blob container + :type container_name: str + :param blob_name: The name of the blob. This needs not be existing + :type blob_name: str + """ + container_client = self.create_container(container_name) + return container_client.get_blob_client(blob_name) + + def check_for_blob(self, container_name: str, blob_name: str, **kwargs) -> bool: + """ + Check if a blob exists on Azure Blob Storage. + + :param container_name: Name of the container. + :type container_name: str + :param blob_name: Name of the blob. + :type blob_name: str + :param kwargs: Optional keyword arguments for ``BlobClient.get_blob_properties`` takes. + :type kwargs: object + :return: True if the blob exists, False otherwise. + :rtype: bool + """ + try: + self._get_blob_client(container_name, blob_name).get_blob_properties( + **kwargs + ) + except ResourceNotFoundError: + return False + return True + + def check_for_prefix(self, container_name: str, prefix: str, **kwargs): + """ + Check if a prefix exists on Azure Blob storage. + + :param container_name: Name of the container. + :type container_name: str + :param prefix: Prefix of the blob. 
+ :type prefix: str + :param kwargs: Optional keyword arguments that ``ContainerClient.walk_blobs`` takes + :type kwargs: object + :return: True if blobs matching the prefix exist, False otherwise. + :rtype: bool + """ + blobs = self.get_blobs_list( + container_name=container_name, prefix=prefix, **kwargs + ) + return len(blobs) > 0 + + def get_blobs_list( + self, + container_name: str, + prefix: Optional[str] = None, + include: Optional[List[str]] = None, + delimiter: Optional[str] = "/", + **kwargs, + ) -> List: + """ + List blobs in a given container + + :param container_name: The name of the container + :type container_name: str + :param prefix: Filters the results to return only blobs whose names + begin with the specified prefix. + :type prefix: str + :param include: Specifies one or more additional datasets to include in the + response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``, + ``copy`, ``deleted``. + :type include: List[str] + :param delimiter: filters objects based on the delimiter (for e.g '.csv') + :type delimiter: str + """ + container = self._get_container_client(container_name) + blob_list = [] + blobs = container.walk_blobs( + name_starts_with=prefix, include=include, delimiter=delimiter, **kwargs + ) + for blob in blobs: + blob_list.append(blob.name) + return blob_list + + def load_file( + self, file_path: str, container_name: str, blob_name: str, **kwargs + ) -> None: + """ + Upload a file to Azure Blob Storage. + + :param file_path: Path to the file to load. + :type file_path: str + :param container_name: Name of the container. + :type container_name: str + :param blob_name: Name of the blob. + :type blob_name: str + :param kwargs: Optional keyword arguments that ``BlobClient.upload_blob()`` takes. + :type kwargs: object + """ + with open(file_path, "rb") as data: + self.upload( + container_name=container_name, blob_name=blob_name, data=data, **kwargs + ) + + def load_string( + self, string_data: str, container_name: str, blob_name: str, **kwargs + ) -> None: + """ + Upload a string to Azure Blob Storage. + + :param string_data: String to load. + :type string_data: str + :param container_name: Name of the container. + :type container_name: str + :param blob_name: Name of the blob. + :type blob_name: str + :param kwargs: Optional keyword arguments that ``BlobClient.upload()`` takes. + :type kwargs: object + """ + # Reorder the argument order from airflow.providers.amazon.aws.hooks.s3.load_string. + self.upload(container_name, blob_name, string_data, **kwargs) + + def get_file(self, file_path: str, container_name: str, blob_name: str, **kwargs): + """ + Download a file from Azure Blob Storage. + + :param file_path: Path to the file to download. + :type file_path: str + :param container_name: Name of the container. + :type container_name: str + :param blob_name: Name of the blob. + :type blob_name: str + :param kwargs: Optional keyword arguments that `BlobClient.download_blob()` takes. + :type kwargs: object + """ + with open(file_path, "wb") as fileblob: + stream = self.download( + container_name=container_name, blob_name=blob_name, **kwargs + ) + fileblob.write(stream.readall()) + + def read_file(self, container_name: str, blob_name: str, **kwargs): + """ + Read a file from Azure Blob Storage and return as a string. + + :param container_name: Name of the container. + :type container_name: str + :param blob_name: Name of the blob. + :type blob_name: str + :param kwargs: Optional keyword arguments that `BlobClient.download_blob` takes. 
+ :type kwargs: object + """ + return self.download(container_name, blob_name, **kwargs).content_as_text() + + def upload( + self, + container_name, + blob_name, + data, + blob_type: str = "BlockBlob", + length: Optional[int] = None, + **kwargs, + ) -> Dict[str, Any]: + """ + Creates a new blob from a data source with automatic chunking. + + :param container_name: The name of the container to upload data + :type container_name: str + :param blob_name: The name of the blob to upload. This need not exist in the container + :type blob_name: str + :param data: The blob data to upload + :param blob_type: The type of the blob. This can be either ``BlockBlob``, + ``PageBlob`` or ``AppendBlob``. The default value is ``BlockBlob``. + :type blob_type: storage.BlobType + :param length: Number of bytes to read from the stream. This is optional, + but should be supplied for optimal performance. + :type length: int + """ + blob_client = self._get_blob_client(container_name, blob_name) + return blob_client.upload_blob(data, blob_type, length=length, **kwargs) + + def download( + self, + container_name, + blob_name, + offset: Optional[int] = None, + length: Optional[int] = None, + **kwargs, + ) -> StorageStreamDownloader: + """ + Downloads a blob to the StorageStreamDownloader + + :param container_name: The name of the container containing the blob + :type container_name: str + :param blob_name: The name of the blob to download + :type blob_name: str + :param offset: Start of byte range to use for downloading a section of the blob. + Must be set if length is provided. + :type offset: int + :param length: Number of bytes to read from the stream. + :type length: int + """ + blob_client = self._get_blob_client(container_name, blob_name) + return blob_client.download_blob(offset=offset, length=length, **kwargs) + + def create_container(self, container_name: str) -> ContainerClient: + """ + Create container object if not already existing + + :param container_name: The name of the container to create + :type container_name: str + """ + container_client = self._get_container_client(container_name) + try: + self.log.info("Attempting to create container: %s", container_name) + container_client.create_container() + self.log.info("Created container: %s", container_name) + return container_client + except ResourceExistsError: + self.log.info("Container %s already exists", container_name) + return container_client + + def delete_container(self, container_name: str) -> None: + """ + Delete a container object + + :param container_name: The name of the container + :type container_name: str + """ + try: + self.log.info("Attempting to delete container: %s", container_name) + self._get_container_client(container_name).delete_container() + self.log.info("Deleted container: %s", container_name) + except ResourceNotFoundError: + self.log.info("Container %s not found", container_name) + + def delete_blobs(self, container_name: str, *blobs, **kwargs) -> None: + """ + Marks the specified blobs or snapshots for deletion. + + :param container_name: The name of the container containing the blobs + :type container_name: str + :param blobs: The blobs to delete. This can be a single blob, or multiple values + can be supplied, where each value is either the name of the blob (str) or BlobProperties. 
+ :type blobs: Union[str, BlobProperties] + """ + self._get_container_client(container_name).delete_blobs(*blobs, **kwargs) + self.log.info("Deleted blobs: %s", blobs) + + def delete_file( + self, + container_name: str, + blob_name: str, + is_prefix: bool = False, + ignore_if_missing: bool = False, + **kwargs, + ) -> None: + """ + Delete a file from Azure Blob Storage. + + :param container_name: Name of the container. + :type container_name: str + :param blob_name: Name of the blob. + :type blob_name: str + :param is_prefix: If blob_name is a prefix, delete all matching files + :type is_prefix: bool + :param ignore_if_missing: if True, then return success even if the + blob does not exist. + :type ignore_if_missing: bool + :param kwargs: Optional keyword arguments that ``ContainerClient.delete_blobs()`` takes. + :type kwargs: object + """ + if is_prefix: + blobs_to_delete = self.get_blobs_list( + container_name, prefix=blob_name, **kwargs + ) + elif self.check_for_blob(container_name, blob_name): + blobs_to_delete = [blob_name] + else: + blobs_to_delete = [] + if not ignore_if_missing and len(blobs_to_delete) == 0: + raise AirflowException(f"Blob(s) not found: {blob_name}") + + self.delete_blobs(container_name, *blobs_to_delete, **kwargs) diff --git a/reference/providers/microsoft/azure/log/__init__.py b/reference/providers/microsoft/azure/log/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/microsoft/azure/log/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/azure/log/wasb_task_handler.py b/reference/providers/microsoft/azure/log/wasb_task_handler.py new file mode 100644 index 0000000..32fea8e --- /dev/null +++ b/reference/providers/microsoft/azure/log/wasb_task_handler.py @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
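A minimal usage sketch for the WasbHook shown above. The `wasb_default` connection is assumed to exist (it is the hook's default connection id); the container and blob names are placeholders, and the calls simply mirror the method signatures in wasb.py. Note that the constructor resolves credentials via `get_conn()`, so it needs a working connection.

    from airflow.providers.microsoft.azure.hooks.wasb import WasbHook

    hook = WasbHook(wasb_conn_id="wasb_default")

    # The container is created on demand by _get_blob_client, so it need not exist yet.
    hook.load_string("hello", container_name="landing", blob_name="greetings/hello.txt")

    print(hook.check_for_blob("landing", "greetings/hello.txt"))   # True
    print(hook.get_blobs_list("landing", prefix="greetings/"))     # ['greetings/hello.txt']
    print(hook.read_file("landing", "greetings/hello.txt"))        # 'hello'

    # Delete everything under the prefix in one call.
    hook.delete_file("landing", "greetings/", is_prefix=True)
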
+import os +import shutil +from typing import Dict, Optional, Tuple + +from azure.common import AzureHttpError + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.configuration import conf +from airflow.utils.log.file_task_handler import FileTaskHandler +from airflow.utils.log.logging_mixin import LoggingMixin + + +class WasbTaskHandler(FileTaskHandler, LoggingMixin): + """ + WasbTaskHandler is a python log handler that handles and reads + task instance logs. It extends airflow FileTaskHandler and + uploads to and reads from Wasb remote storage. + """ + + def __init__( + self, + base_log_folder: str, + wasb_log_folder: str, + wasb_container: str, + filename_template: str, + delete_local_copy: str, + ) -> None: + super().__init__(base_log_folder, filename_template) + self.wasb_container = wasb_container + self.remote_base = wasb_log_folder + self.log_relative_path = "" + self._hook = None + self.closed = False + self.upload_on_close = True + self.delete_local_copy = delete_local_copy + + @cached_property + def hook(self): + """Returns WasbHook.""" + remote_conn_id = conf.get("logging", "REMOTE_LOG_CONN_ID") + try: + from airflow.providers.microsoft.azure.hooks.wasb import WasbHook + + return WasbHook(remote_conn_id) + except AzureHttpError as e: + self.log.error( + 'Could not create an WasbHook with connection id "%s". ' + "Please make sure that airflow[azure] is installed and " + 'the Wasb connection exists. Exception "%s"', + remote_conn_id, + e, + ) + return None + + def set_context(self, ti) -> None: + super().set_context(ti) + # Local location and remote location is needed to open and + # upload local log file to Wasb remote storage. + self.log_relative_path = self._render_filename(ti, ti.try_number) + self.upload_on_close = not ti.raw + + def close(self) -> None: + """Close and upload local log file to remote storage Wasb.""" + # When application exit, system shuts down all handlers by + # calling close method. Here we check if logger is already + # closed to prevent uploading the log to remote storage multiple + # times when `logging.shutdown` is called. + if self.closed: + return + + super().close() + + if not self.upload_on_close: + return + + local_loc = os.path.join(self.local_base, self.log_relative_path) + remote_loc = os.path.join(self.remote_base, self.log_relative_path) + if os.path.exists(local_loc): + # read log and remove old logs to get just the latest additions + with open(local_loc) as logfile: + log = logfile.read() + self.wasb_write(log, remote_loc, append=True) + + if self.delete_local_copy: + shutil.rmtree(os.path.dirname(local_loc)) + # Mark closed so we don't double write if close is called twice + self.closed = True + + def _read( + self, ti, try_number: str, metadata: Optional[str] = None + ) -> Tuple[str, Dict[str, bool]]: + """ + Read logs of given task instance and try_number from Wasb remote storage. + If failed, read the log from task instance host machine. + + :param ti: task instance object + :param try_number: task instance try_number to read logs from + :param metadata: log metadata, + can be used for steaming log reading and auto-tailing. + """ + # Explicitly getting log relative path is necessary as the given + # task instance might be different than task instance passed in + # in set_context method. 
+ log_relative_path = self._render_filename(ti, try_number) + remote_loc = os.path.join(self.remote_base, log_relative_path) + + if self.wasb_log_exists(remote_loc): + # If Wasb remote file exists, we do not fetch logs from task instance + # local machine even if there are errors reading remote logs, as + # returned remote_log will contain error messages. + remote_log = self.wasb_read(remote_loc, return_error=True) + log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n" + return log, {"end_of_log": True} + else: + return super()._read(ti, try_number) + + def wasb_log_exists(self, remote_log_location: str) -> bool: + """ + Check if remote_log_location exists in remote storage + + :param remote_log_location: log's location in remote storage + :return: True if location exists else False + """ + try: + return self.hook.check_for_blob(self.wasb_container, remote_log_location) + # pylint: disable=broad-except + except Exception as e: + self.log.debug('Exception when trying to check remote location: "%s"', e) + return False + + def wasb_read(self, remote_log_location: str, return_error: bool = False): + """ + Returns the log found at the remote_log_location. Returns '' if no + logs are found or there is an error. + + :param remote_log_location: the log's location in remote storage + :type remote_log_location: str (path) + :param return_error: if True, returns a string error message if an + error occurs. Otherwise returns '' when an error occurs. + :type return_error: bool + """ + try: + return self.hook.read_file(self.wasb_container, remote_log_location) + except AzureHttpError as e: + msg = f"Could not read logs from {remote_log_location}" + self.log.exception("Message: '%s', exception '%s'", msg, e) + # return error if needed + if return_error: + return msg + return "" + + def wasb_write( + self, log: str, remote_log_location: str, append: bool = True + ) -> None: + """ + Writes the log to the remote_log_location. Fails silently if no hook + was created. + + :param log: the log to write to the remote_log_location + :type log: str + :param remote_log_location: the log's location in remote storage + :type remote_log_location: str (path) + :param append: if False, any existing log file is overwritten. If True, + the new log is appended to any existing logs. + :type append: bool + """ + if append and self.wasb_log_exists(remote_log_location): + old_log = self.wasb_read(remote_log_location) + log = "\n".join([old_log, log]) if old_log else log + + try: + self.hook.load_string( + log, + self.wasb_container, + remote_log_location, + ) + except AzureHttpError: + self.log.exception("Could not write logs to %s", remote_log_location) diff --git a/reference/providers/microsoft/azure/operators/__init__.py b/reference/providers/microsoft/azure/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/azure/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/azure/operators/adls_delete.py b/reference/providers/microsoft/azure/operators/adls_delete.py new file mode 100644 index 0000000..c8b343d --- /dev/null +++ b/reference/providers/microsoft/azure/operators/adls_delete.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Sequence + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook +from airflow.utils.decorators import apply_defaults + + +class AzureDataLakeStorageDeleteOperator(BaseOperator): + """ + Delete files in the specified path. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureDataLakeStorageDeleteOperator` + + :param path: A directory or file to remove + :type path: str + :param recursive: Whether to loop into directories in the location and remove the files + :type recursive: bool + :param ignore_not_found: Whether to raise error if file to delete is not found + :type ignore_not_found: bool + """ + + template_fields: Sequence[str] = ("path",) + ui_color = "#901dd2" + + @apply_defaults + def __init__( + self, + *, + path: str, + recursive: bool = False, + ignore_not_found: bool = True, + azure_data_lake_conn_id: str = "azure_data_lake_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.path = path + self.recursive = recursive + self.ignore_not_found = ignore_not_found + self.azure_data_lake_conn_id = azure_data_lake_conn_id + + def execute(self, context: dict) -> Any: + hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) + + return hook.remove( + path=self.path, + recursive=self.recursive, + ignore_not_found=self.ignore_not_found, + ) diff --git a/reference/providers/microsoft/azure/operators/adls_list.py b/reference/providers/microsoft/azure/operators/adls_list.py new file mode 100644 index 0000000..4bfb0d1 --- /dev/null +++ b/reference/providers/microsoft/azure/operators/adls_list.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Sequence + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook +from airflow.utils.decorators import apply_defaults + + +class AzureDataLakeStorageListOperator(BaseOperator): + """ + List all files from the specified path + + This operator returns a python list with the names of files which can be used by + `xcom` in the downstream tasks. + + :param path: The Azure Data Lake path to find the objects. Supports glob + strings (templated) + :type path: str + :param azure_data_lake_conn_id: The connection ID to use when + connecting to Azure Data Lake Storage. + :type azure_data_lake_conn_id: str + + **Example**: + The following Operator would list all the Parquet files from ``folder/output/`` + folder in the specified ADLS account :: + + adls_files = AzureDataLakeStorageListOperator( + task_id='adls_files', + path='folder/output/*.parquet', + azure_data_lake_conn_id='azure_data_lake_default' + ) + """ + + template_fields: Sequence[str] = ("path",) + ui_color = "#901dd2" + + @apply_defaults + def __init__( + self, + *, + path: str, + azure_data_lake_conn_id: str = "azure_data_lake_default", + **kwargs + ) -> None: + super().__init__(**kwargs) + self.path = path + self.azure_data_lake_conn_id = azure_data_lake_conn_id + + def execute(self, context: dict) -> list: + + hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) + + self.log.info("Getting list of ADLS files in path: %s", self.path) + + return hook.list(path=self.path) diff --git a/reference/providers/microsoft/azure/operators/adx.py b/reference/providers/microsoft/azure/operators/adx.py new file mode 100644 index 0000000..ffd83e6 --- /dev/null +++ b/reference/providers/microsoft/azure/operators/adx.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
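A usage sketch for the two ADLS operators above. The DAG id, paths, and schedule are illustrative; only the connection id `azure_data_lake_default` comes from the operators' defaults. As the list operator's docstring notes, the file names it returns are pushed to XCom for downstream tasks.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.microsoft.azure.operators.adls_delete import AzureDataLakeStorageDeleteOperator
    from airflow.providers.microsoft.azure.operators.adls_list import AzureDataLakeStorageListOperator

    with DAG(
        "adls_cleanup_example",          # hypothetical DAG id
        start_date=datetime(2022, 7, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        # List Parquet output; the result is available to downstream tasks via XCom.
        list_files = AzureDataLakeStorageListOperator(
            task_id="list_parquet_files",
            path="folder/output/*.parquet",
            azure_data_lake_conn_id="azure_data_lake_default",
        )
        # Then clear a staging directory.
        delete_staging = AzureDataLakeStorageDeleteOperator(
            task_id="delete_staging",
            path="folder/staging",
            recursive=True,
            azure_data_lake_conn_id="azure_data_lake_default",
        )
        list_files >> delete_staging
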
+# + +"""This module contains Azure Data Explorer operators""" +from typing import Optional, Union + +from airflow.configuration import conf +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook +from airflow.utils.decorators import apply_defaults +from azure.kusto.data._models import KustoResultTable + + +class AzureDataExplorerQueryOperator(BaseOperator): + """ + Operator for querying Azure Data Explorer (Kusto). + + :param query: KQL query to run (templated). + :type query: str + :param database: Database to run the query on (templated). + :type database: str + :param options: Optional query options. See: + https://docs.microsoft.com/en-us/azure/kusto/api/netfx/request-properties#list-of-clientrequestproperties + :type options: dict + :param azure_data_explorer_conn_id: Azure Data Explorer connection to use. + :type azure_data_explorer_conn_id: str + """ + + ui_color = "#00a1f2" + template_fields = ("query", "database") + template_ext = (".kql",) + + @apply_defaults + def __init__( + self, + *, + query: str, + database: str, + options: Optional[dict] = None, + azure_data_explorer_conn_id: str = "azure_data_explorer_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.query = query + self.database = database + self.options = options + self.azure_data_explorer_conn_id = azure_data_explorer_conn_id + + def get_hook(self) -> AzureDataExplorerHook: + """Returns new instance of AzureDataExplorerHook""" + return AzureDataExplorerHook(self.azure_data_explorer_conn_id) + + def execute(self, context: dict) -> Union[KustoResultTable, str]: + """ + Run KQL Query on Azure Data Explorer (Kusto). + Returns `PrimaryResult` of Query v2 HTTP response contents + (https://docs.microsoft.com/en-us/azure/kusto/api/rest/response2) + """ + hook = self.get_hook() + response = hook.run_query(self.query, self.database, self.options) + if conf.getboolean("core", "enable_xcom_pickling"): + return response.primary_results[0] + else: + return str(response.primary_results[0]) diff --git a/reference/providers/microsoft/azure/operators/azure_batch.py b/reference/providers/microsoft/azure/operators/azure_batch.py new file mode 100644 index 0000000..f74317a --- /dev/null +++ b/reference/providers/microsoft/azure/operators/azure_batch.py @@ -0,0 +1,410 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
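A usage sketch for the AzureDataExplorerQueryOperator above. The database name, KQL query, and DAG id are placeholders; the `query` value could instead be a path to a templated `.kql` file (per `template_ext`). The query's `PrimaryResult` is returned from `execute` only when XCom pickling is enabled; otherwise its string form is.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.microsoft.azure.operators.adx import AzureDataExplorerQueryOperator

    with DAG(
        "adx_query_example",             # hypothetical DAG id
        start_date=datetime(2022, 7, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        top_states = AzureDataExplorerQueryOperator(
            task_id="top_states",
            database="Samples",          # placeholder database
            query="StormEvents | summarize events = count() by State | top 5 by events",
            azure_data_explorer_conn_id="azure_data_explorer_default",
        )
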
+# +from typing import Any, List, Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.azure_batch import AzureBatchHook +from airflow.utils.decorators import apply_defaults +from azure.batch import models as batch_models + + +# pylint: disable=too-many-instance-attributes +class AzureBatchOperator(BaseOperator): + """ + Executes a job on Azure Batch Service + + :param batch_pool_id: A string that uniquely identifies the Pool within the Account. + :type batch_pool_id: str + + :param batch_pool_vm_size: The size of virtual machines in the Pool + :type batch_pool_vm_size: str + + :param batch_job_id: A string that uniquely identifies the Job within the Account. + :type batch_job_id: str + + :param batch_task_command_line: The command line of the Task + :type batch_command_line: str + + :param batch_task_id: A string that uniquely identifies the task within the Job. + :type batch_task_id: str + + :param batch_pool_display_name: The display name for the Pool. + The display name need not be unique + :type batch_pool_display_name: Optional[str] + + :param batch_job_display_name: The display name for the Job. + The display name need not be unique + :type batch_job_display_name: Optional[str] + + :param batch_job_manager_task: Details of a Job Manager Task to be launched when the Job is started. + :type job_manager_task: Optional[batch_models.JobManagerTask] + + :param batch_job_preparation_task: The Job Preparation Task. If set, the Batch service will + run the Job Preparation Task on a Node before starting any Tasks of that + Job on that Compute Node. Required if batch_job_release_task is set. + :type batch_job_preparation_task: Optional[batch_models.JobPreparationTask] + + :param batch_job_release_task: The Job Release Task. Use to undo changes to Compute Nodes + made by the Job Preparation Task + :type batch_job_release_task: Optional[batch_models.JobReleaseTask] + + :param batch_task_display_name: The display name for the task. + The display name need not be unique + :type batch_task_display_name: Optional[str] + + :param batch_task_container_settings: The settings for the container under which the Task runs + :type batch_task_container_settings: Optional[batch_models.TaskContainerSettings] + + :param batch_start_task: A Task specified to run on each Compute Node as it joins the Pool. + The Task runs when the Compute Node is added to the Pool or + when the Compute Node is restarted. + :type batch_start_task: Optional[batch_models.StartTask] + + :param batch_max_retries: The number of times to retry this batch operation before it's + considered a failed operation. Default is 3 + :type batch_max_retries: int + + :param batch_task_resource_files: A list of files that the Batch service will + download to the Compute Node before running the command line. + :type batch_task_resource_files: Optional[List[batch_models.ResourceFile]] + + :param batch_task_output_files: A list of files that the Batch service will upload + from the Compute Node after running the command line. + :type batch_task_output_files: Optional[List[batch_models.OutputFile]] + + :param batch_task_user_identity: The user identity under which the Task runs. + If omitted, the Task runs as a non-administrative user unique to the Task. + :type batch_task_user_identity: Optional[batch_models.UserIdentity] + + :param target_low_priority_nodes: The desired number of low-priority Compute Nodes in the Pool. 
+ This property must not be specified if enable_auto_scale is set to true. + :type target_low_priority_nodes: Optional[int] + + :param target_dedicated_nodes: The desired number of dedicated Compute Nodes in the Pool. + This property must not be specified if enable_auto_scale is set to true. + :type target_dedicated_nodes: Optional[int] + + :param enable_auto_scale: Whether the Pool size should automatically adjust over time. Default is false + :type enable_auto_scale: bool + + :param auto_scale_formula: A formula for the desired number of Compute Nodes in the Pool. + This property must not be specified if enableAutoScale is set to false. + It is required if enableAutoScale is set to true. + :type auto_scale_formula: Optional[str] + + :param azure_batch_conn_id: The connection id of Azure batch service + :type azure_batch_conn_id: str + + :param use_latest_verified_vm_image_and_sku: Whether to use the latest verified virtual + machine image and sku in the batch account. Default is false. + :type use_latest_verified_vm_image_and_sku: bool + + :param vm_publisher: The publisher of the Azure Virtual Machines Marketplace Image. + For example, Canonical or MicrosoftWindowsServer. Required if + use_latest_image_and_sku is set to True + :type vm_publisher: Optional[str] + + :param vm_offer: The offer type of the Azure Virtual Machines Marketplace Image. + For example, UbuntuServer or WindowsServer. Required if + use_latest_image_and_sku is set to True + :type vm_offer: Optional[str] + + :param sku_starts_with: The starting string of the Virtual Machine SKU. Required if + use_latest_image_and_sku is set to True + :type sku_starts_with: Optional[str] + + :param vm_sku: The name of the virtual machine sku to use + :type vm_sku: Optional[str] + + :param vm_version: The version of the virtual machine + :param vm_version: Optional[str] + + :param vm_node_agent_sku_id: The node agent sku id of the virtual machine + :type vm_node_agent_sku_id: Optional[str] + + :param os_family: The Azure Guest OS family to be installed on the virtual machines in the Pool. + :type os_family: Optional[str] + + :param os_version: The OS family version + :type os_version: Optional[str] + + :param timeout: The amount of time to wait for the job to complete in minutes. Default is 25 + :type timeout: int + + :param should_delete_job: Whether to delete job after execution. Default is False + :type should_delete_job: bool + + :param should_delete_pool: Whether to delete pool after execution of jobs. 
Default is False + :type should_delete_pool: bool + + + """ + + template_fields = ( + "batch_pool_id", + "batch_pool_vm_size", + "batch_job_id", + "batch_task_id", + "batch_task_command_line", + ) + ui_color = "#f0f0e4" + + @apply_defaults + def __init__( + self, + *, # pylint: disable=too-many-arguments,too-many-locals + batch_pool_id: str, + batch_pool_vm_size: str, + batch_job_id: str, + batch_task_command_line: str, + batch_task_id: str, + vm_publisher: Optional[str] = None, + vm_offer: Optional[str] = None, + sku_starts_with: Optional[str] = None, + vm_sku: Optional[str] = None, + vm_version: Optional[str] = None, + vm_node_agent_sku_id: Optional[str] = None, + os_family: Optional[str] = None, + os_version: Optional[str] = None, + batch_pool_display_name: Optional[str] = None, + batch_job_display_name: Optional[str] = None, + batch_job_manager_task: Optional[batch_models.JobManagerTask] = None, + batch_job_preparation_task: Optional[batch_models.JobPreparationTask] = None, + batch_job_release_task: Optional[batch_models.JobReleaseTask] = None, + batch_task_display_name: Optional[str] = None, + batch_task_container_settings: Optional[ + batch_models.TaskContainerSettings + ] = None, + batch_start_task: Optional[batch_models.StartTask] = None, + batch_max_retries: int = 3, + batch_task_resource_files: Optional[List[batch_models.ResourceFile]] = None, + batch_task_output_files: Optional[List[batch_models.OutputFile]] = None, + batch_task_user_identity: Optional[batch_models.UserIdentity] = None, + target_low_priority_nodes: Optional[int] = None, + target_dedicated_nodes: Optional[int] = None, + enable_auto_scale: bool = False, + auto_scale_formula: Optional[str] = None, + azure_batch_conn_id="azure_batch_default", + use_latest_verified_vm_image_and_sku: bool = False, + timeout: int = 25, + should_delete_job: bool = False, + should_delete_pool: bool = False, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.batch_pool_id = batch_pool_id + self.batch_pool_vm_size = batch_pool_vm_size + self.batch_job_id = batch_job_id + self.batch_task_id = batch_task_id + self.batch_task_command_line = batch_task_command_line + self.batch_pool_display_name = batch_pool_display_name + self.batch_job_display_name = batch_job_display_name + self.batch_job_manager_task = batch_job_manager_task + self.batch_job_preparation_task = batch_job_preparation_task + self.batch_job_release_task = batch_job_release_task + self.batch_task_display_name = batch_task_display_name + self.batch_task_container_settings = batch_task_container_settings + self.batch_start_task = batch_start_task + self.batch_max_retries = batch_max_retries + self.batch_task_resource_files = batch_task_resource_files + self.batch_task_output_files = batch_task_output_files + self.batch_task_user_identity = batch_task_user_identity + self.target_low_priority_nodes = target_low_priority_nodes + self.target_dedicated_nodes = target_dedicated_nodes + self.enable_auto_scale = enable_auto_scale + self.auto_scale_formula = auto_scale_formula + self.azure_batch_conn_id = azure_batch_conn_id + self.use_latest_image = use_latest_verified_vm_image_and_sku + self.vm_publisher = vm_publisher + self.vm_offer = vm_offer + self.sku_starts_with = sku_starts_with + self.vm_sku = vm_sku + self.vm_version = vm_version + self.vm_node_agent_sku_id = vm_node_agent_sku_id + self.os_family = os_family + self.os_version = os_version + self.timeout = timeout + self.should_delete_job = should_delete_job + self.should_delete_pool = should_delete_pool + self.hook 
= self.get_hook() + + def _check_inputs(self) -> Any: + if not self.os_family and not self.vm_publisher: + raise AirflowException("You must specify either vm_publisher or os_family") + if self.os_family and self.vm_publisher: + raise AirflowException( + "Cloud service configuration and virtual machine configuration " + "are mutually exclusive. You must specify either of os_family and" + " vm_publisher" + ) + + if self.use_latest_image: + if not all(elem for elem in [self.vm_publisher, self.vm_offer]): + raise AirflowException( + "If use_latest_image_and_sku is" + " set to True then the parameters vm_publisher, vm_offer, " + "must all be set. Found " + "vm_publisher={}, vm_offer={}".format( + self.vm_publisher, self.vm_offer + ) + ) + if self.vm_publisher: + if not all([self.vm_sku, self.vm_offer, self.vm_node_agent_sku_id]): + raise AirflowException( + "If vm_publisher is set, then the parameters vm_sku, vm_offer," + "vm_node_agent_sku_id must be set. Found " + f"vm_publisher={self.vm_publisher}, vm_offer={self.vm_offer} " + f"vm_node_agent_sku_id={self.vm_node_agent_sku_id}, " + f"vm_version={self.vm_version}" + ) + + if not self.target_dedicated_nodes and not self.enable_auto_scale: + raise AirflowException( + "Either target_dedicated_nodes or enable_auto_scale must be set. None was set" + ) + if self.enable_auto_scale: + if self.target_dedicated_nodes or self.target_low_priority_nodes: + raise AirflowException( + "If enable_auto_scale is set, then the parameters " + "target_dedicated_nodes and target_low_priority_nodes must not " + "be set. Found target_dedicated_nodes={}," + " target_low_priority_nodes={}".format( + self.target_dedicated_nodes, self.target_low_priority_nodes + ) + ) + if not self.auto_scale_formula: + raise AirflowException( + "The auto_scale_formula is required when enable_auto_scale is set" + ) + if self.batch_job_release_task and not self.batch_job_preparation_task: + raise AirflowException( + "A batch_job_release_task cannot be specified without also " + " specifying a batch_job_preparation_task for the Job." + ) + if not all( + [ + self.batch_pool_id, + self.batch_job_id, + self.batch_pool_vm_size, + self.batch_task_id, + self.batch_task_command_line, + ] + ): + raise AirflowException( + "Some required parameters are missing.Please you must set all the required parameters. 
" + ) + + def execute(self, context: dict) -> None: + self._check_inputs() + self.hook.connection.config.retry_policy = self.batch_max_retries + + pool = self.hook.configure_pool( + pool_id=self.batch_pool_id, + vm_size=self.batch_pool_vm_size, + display_name=self.batch_pool_display_name, + target_dedicated_nodes=self.target_dedicated_nodes, + use_latest_image_and_sku=self.use_latest_image, + vm_publisher=self.vm_publisher, + vm_offer=self.vm_offer, + sku_starts_with=self.sku_starts_with, + vm_sku=self.vm_sku, + vm_version=self.vm_version, + vm_node_agent_sku_id=self.vm_node_agent_sku_id, + os_family=self.os_family, + os_version=self.os_version, + target_low_priority_nodes=self.target_low_priority_nodes, + enable_auto_scale=self.enable_auto_scale, + auto_scale_formula=self.auto_scale_formula, + start_task=self.batch_start_task, + ) + self.hook.create_pool(pool) + # Wait for nodes to reach complete state + self.hook.wait_for_all_node_state( + self.batch_pool_id, + { + batch_models.ComputeNodeState.start_task_failed, + batch_models.ComputeNodeState.unusable, + batch_models.ComputeNodeState.idle, + }, + ) + # Create job if not already exist + job = self.hook.configure_job( + job_id=self.batch_job_id, + pool_id=self.batch_pool_id, + display_name=self.batch_job_display_name, + job_manager_task=self.batch_job_manager_task, + job_preparation_task=self.batch_job_preparation_task, + job_release_task=self.batch_job_release_task, + ) + self.hook.create_job(job) + # Create task + task = self.hook.configure_task( + task_id=self.batch_task_id, + command_line=self.batch_task_command_line, + display_name=self.batch_task_display_name, + container_settings=self.batch_task_container_settings, + resource_files=self.batch_task_resource_files, + output_files=self.batch_task_output_files, + user_identity=self.batch_task_user_identity, + ) + # Add task to job + self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task) + # Wait for tasks to complete + self.hook.wait_for_job_tasks_to_complete( + job_id=self.batch_job_id, timeout=self.timeout + ) + # Clean up + if self.should_delete_job: + # delete job first + self.clean_up(job_id=self.batch_job_id) + if self.should_delete_pool: + self.clean_up(self.batch_pool_id) + + def on_kill(self) -> None: + response = self.hook.connection.job.terminate( + job_id=self.batch_job_id, terminate_reason="Job killed by user" + ) + self.log.info( + "Azure Batch job (%s) terminated: %s", self.batch_job_id, response + ) + + def get_hook(self) -> AzureBatchHook: + """Create and return an AzureBatchHook.""" + return AzureBatchHook(azure_batch_conn_id=self.azure_batch_conn_id) + + def clean_up( + self, pool_id: Optional[str] = None, job_id: Optional[str] = None + ) -> None: + """ + Delete the given pool and job in the batch account + + :param pool_id: The id of the pool to delete + :type pool_id: str + :param job_id: The id of the job to delete + :type job_id: str + + """ + if job_id: + self.log.info("Deleting job: %s", job_id) + self.hook.connection.job.delete(job_id) + if pool_id: + self.log.info("Deleting pool: %s", pool_id) + self.hook.connection.pool.delete(pool_id) diff --git a/reference/providers/microsoft/azure/operators/azure_container_instances.py b/reference/providers/microsoft/azure/operators/azure_container_instances.py new file mode 100644 index 0000000..d62afbd --- /dev/null +++ b/reference/providers/microsoft/azure/operators/azure_container_instances.py @@ -0,0 +1,421 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +from collections import namedtuple +from time import sleep +from typing import Any, Dict, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException, AirflowTaskTimeout +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.azure_container_instance import ( + AzureContainerInstanceHook, +) +from airflow.providers.microsoft.azure.hooks.azure_container_registry import ( + AzureContainerRegistryHook, +) +from airflow.providers.microsoft.azure.hooks.azure_container_volume import ( + AzureContainerVolumeHook, +) +from airflow.utils.decorators import apply_defaults +from azure.mgmt.containerinstance.models import ( + Container, + ContainerGroup, + ContainerPort, + EnvironmentVariable, + IpAddress, + ResourceRequests, + ResourceRequirements, + VolumeMount, +) +from msrestazure.azure_exceptions import CloudError + +Volume = namedtuple( + "Volume", + ["conn_id", "account_name", "share_name", "mount_path", "read_only"], +) + + +DEFAULT_ENVIRONMENT_VARIABLES: Dict[str, str] = {} +DEFAULT_SECURED_VARIABLES: Sequence[str] = [] +DEFAULT_VOLUMES: Sequence[Volume] = [] +DEFAULT_MEMORY_IN_GB = 2.0 +DEFAULT_CPU = 1.0 + + +# pylint: disable=too-many-instance-attributes +class AzureContainerInstancesOperator(BaseOperator): + """ + Start a container on Azure Container Instances + + :param ci_conn_id: connection id of a service principal which will be used + to start the container instance + :type ci_conn_id: str + :param registry_conn_id: connection id of a user which can login to a + private docker registry. If None, we assume a public registry + :type registry_conn_id: Optional[str] + :param resource_group: name of the resource group wherein this container + instance should be started + :type resource_group: str + :param name: name of this container instance. Please note this name has + to be unique in order to run containers in parallel. + :type name: str + :param image: the docker image to be used + :type image: str + :param region: the region wherein this container instance should be started + :type region: str + :param environment_variables: key,value pairs containing environment + variables which will be passed to the running container + :type environment_variables: Optional[dict] + :param secured_variables: names of environmental variables that should not + be exposed outside the container (typically passwords). + :type secured_variables: Optional[str] + :param volumes: list of ``Volume`` tuples to be mounted to the container. + Currently only Azure Fileshares are supported. + :type volumes: list[] + :param memory_in_gb: the amount of memory to allocate to this container + :type memory_in_gb: double + :param cpu: the number of cpus to allocate to this container + :type cpu: double + :param gpu: GPU Resource for the container. 
+ :type gpu: azure.mgmt.containerinstance.models.GpuResource + :param command: the command to run inside the container + :type command: Optional[List[str]] + :param container_timeout: max time allowed for the execution of + the container instance. + :type container_timeout: datetime.timedelta + :param tags: azure tags as dict of str:str + :type tags: Optional[dict[str, str]] + :param os_type: The operating system type required by the containers + in the container group. Possible values include: 'Windows', 'Linux' + :type os_type: str + :param restart_policy: Restart policy for all containers within the container group. + Possible values include: 'Always', 'OnFailure', 'Never' + :type restart_policy: str + :param ip_address: The IP address type of the container group. + :type ip_address: IpAddress + + **Example**:: + + AzureContainerInstancesOperator( + ci_conn_id = "azure_service_principal", + registry_conn_id = "azure_registry_user", + resource_group = "my-resource-group", + name = "my-container-name-{{ ds }}", + image = "myprivateregistry.azurecr.io/my_container:latest", + region = "westeurope", + environment_variables = {"MODEL_PATH": "my_value", + "POSTGRES_LOGIN": "{{ macros.connection('postgres_default').login }}", + "POSTGRES_PASSWORD": "{{ macros.connection('postgres_default').password }}", + "JOB_GUID": "{{ ti.xcom_pull(task_ids='task1', key='guid') }}" }, + secured_variables = ['POSTGRES_PASSWORD'], + volumes = [("azure_wasb_conn_id", + "my_storage_container", + "my_fileshare", + "/input-data", + True),], + memory_in_gb=14.0, + cpu=4.0, + gpu=GpuResource(count=1, sku='K80'), + command=["/bin/echo", "world"], + task_id="start_container" + ) + """ + + template_fields = ("name", "image", "command", "environment_variables") + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + ci_conn_id: str, + registry_conn_id: Optional[str], + resource_group: str, + name: str, + image: str, + region: str, + environment_variables: Optional[dict] = None, + secured_variables: Optional[str] = None, + volumes: Optional[list] = None, + memory_in_gb: Optional[Any] = None, + cpu: Optional[Any] = None, + gpu: Optional[Any] = None, + command: Optional[List[str]] = None, + remove_on_error: bool = True, + fail_if_exists: bool = True, + tags: Optional[Dict[str, str]] = None, + os_type: str = "Linux", + restart_policy: str = "Never", + ip_address: Optional[IpAddress] = None, + ports: Optional[List[ContainerPort]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.ci_conn_id = ci_conn_id + self.resource_group = resource_group + self.name = self._check_name(name) + self.image = image + self.region = region + self.registry_conn_id = registry_conn_id + self.environment_variables = ( + environment_variables or DEFAULT_ENVIRONMENT_VARIABLES + ) + self.secured_variables = secured_variables or DEFAULT_SECURED_VARIABLES + self.volumes = volumes or DEFAULT_VOLUMES + self.memory_in_gb = memory_in_gb or DEFAULT_MEMORY_IN_GB + self.cpu = cpu or DEFAULT_CPU + self.gpu = gpu + self.command = command + self.remove_on_error = remove_on_error + self.fail_if_exists = fail_if_exists + self._ci_hook: Any = None + self.tags = tags + self.os_type = os_type + if self.os_type not in ["Linux", "Windows"]: + raise AirflowException( + "Invalid value for the os_type argument. " + "Please set 'Linux' or 'Windows' as the os_type. " + f"Found `{self.os_type}`." 
+ ) + self.restart_policy = restart_policy + if self.restart_policy not in ["Always", "OnFailure", "Never"]: + raise AirflowException( + "Invalid value for the restart_policy argument. " + "Please set one of 'Always', 'OnFailure','Never' as the restart_policy. " + f"Found `{self.restart_policy}`" + ) + self.ip_address = ip_address + self.ports = ports + + def execute(self, context: dict) -> int: + # Check name again in case it was templated. + self._check_name(self.name) + + self._ci_hook = AzureContainerInstanceHook(self.ci_conn_id) + + if self.fail_if_exists: + self.log.info("Testing if container group already exists") + if self._ci_hook.exists(self.resource_group, self.name): + raise AirflowException("Container group exists") + + if self.registry_conn_id: + registry_hook = AzureContainerRegistryHook(self.registry_conn_id) + image_registry_credentials: Optional[list] = [ + registry_hook.connection, + ] + else: + image_registry_credentials = None + + environment_variables = [] + for key, value in self.environment_variables.items(): + if key in self.secured_variables: + e = EnvironmentVariable(name=key, secure_value=value) + else: + e = EnvironmentVariable(name=key, value=value) + environment_variables.append(e) + + volumes: List[Union[Volume, Volume]] = [] + volume_mounts: List[Union[VolumeMount, VolumeMount]] = [] + for conn_id, account_name, share_name, mount_path, read_only in self.volumes: + hook = AzureContainerVolumeHook(conn_id) + + mount_name = "mount-%d" % len(volumes) + volumes.append( + hook.get_file_volume(mount_name, share_name, account_name, read_only) + ) + volume_mounts.append( + VolumeMount(name=mount_name, mount_path=mount_path, read_only=read_only) + ) + + exit_code = 1 + try: + self.log.info( + "Starting container group with %.1f cpu %.1f mem", + self.cpu, + self.memory_in_gb, + ) + if self.gpu: + self.log.info( + "GPU count: %.1f, GPU SKU: %s", self.gpu.count, self.gpu.sku + ) + + resources = ResourceRequirements( + requests=ResourceRequests( + memory_in_gb=self.memory_in_gb, cpu=self.cpu, gpu=self.gpu + ) + ) + + if self.ip_address and not self.ports: + self.ports = [ContainerPort(port=80)] + self.log.info("Default port set. 
Container will listen on port 80") + + container = Container( + name=self.name, + image=self.image, + resources=resources, + command=self.command, + environment_variables=environment_variables, + volume_mounts=volume_mounts, + ports=self.ports, + ) + + container_group = ContainerGroup( + location=self.region, + containers=[ + container, + ], + image_registry_credentials=image_registry_credentials, + volumes=volumes, + restart_policy=self.restart_policy, + os_type=self.os_type, + tags=self.tags, + ip_address=self.ip_address, + ) + + self._ci_hook.create_or_update( + self.resource_group, self.name, container_group + ) + + self.log.info( + "Container group started %s/%s", self.resource_group, self.name + ) + + exit_code = self._monitor_logging(self.resource_group, self.name) + + self.log.info("Container had exit code: %s", exit_code) + if exit_code != 0: + raise AirflowException( + f"Container had a non-zero exit code, {exit_code}" + ) + return exit_code + + except CloudError: + self.log.exception("Could not start container group") + raise AirflowException("Could not start container group") + + finally: + if exit_code == 0 or self.remove_on_error: + self.on_kill() + + def on_kill(self) -> None: + if self.remove_on_error: + self.log.info("Deleting container group") + try: + self._ci_hook.delete(self.resource_group, self.name) + except Exception: # pylint: disable=broad-except + self.log.exception("Could not delete container group") + + def _monitor_logging(self, resource_group: str, name: str) -> int: + last_state = None + last_message_logged = None + last_line_logged = None + + # pylint: disable=too-many-nested-blocks + while True: + try: + cg_state = self._ci_hook.get_state(resource_group, name) + instance_view = cg_state.containers[0].instance_view + + # If there is no instance view, we show the provisioning state + if instance_view is not None: + c_state = instance_view.current_state + state, exit_code, detail_status = ( + c_state.state, + c_state.exit_code, + c_state.detail_status, + ) + + messages = [event.message for event in instance_view.events] + last_message_logged = self._log_last(messages, last_message_logged) + else: + state = cg_state.provisioning_state + exit_code = 0 + detail_status = "Provisioning" + + if state != last_state: + self.log.info("Container group state changed to %s", state) + last_state = state + + if state in ["Running", "Terminated"]: + try: + logs = self._ci_hook.get_logs(resource_group, name) + last_line_logged = self._log_last(logs, last_line_logged) + except CloudError: + self.log.exception( + "Exception while getting logs from container instance, retrying..." + ) + + if state == "Terminated": + self.log.error( + "Container exited with detail_status %s", detail_status + ) + return exit_code + + if state == "Failed": + self.log.error("Azure provision failure") + return 1 + + except AirflowTaskTimeout: + raise + except CloudError as err: + if "ResourceNotFound" in str(err): + self.log.warning( + "ResourceNotFound, container is probably removed " + "by another process " + "(make sure that the name is unique)." 
+ ) + return 1 + else: + self.log.exception("Exception while getting container groups") + except Exception: # pylint: disable=broad-except + self.log.exception("Exception while getting container groups") + + sleep(1) + + def _log_last(self, logs: Optional[list], last_line_logged: Any) -> Optional[Any]: + if logs: + # determine the last line which was logged before + last_line_index = 0 + for i in range(len(logs) - 1, -1, -1): + if logs[i] == last_line_logged: + # this line is the same, hence print from i+1 + last_line_index = i + 1 + break + + # log all new ones + for line in logs[last_line_index:]: + self.log.info(line.rstrip()) + + return logs[-1] + return None + + @staticmethod + def _check_name(name: str) -> str: + if "{{" in name: + # Let macros pass as they cannot be checked at construction time + return name + regex_check = re.match("[a-z0-9]([-a-z0-9]*[a-z0-9])?", name) + if regex_check is None or regex_check.group() != name: + raise AirflowException( + 'ACI name must match regex [a-z0-9]([-a-z0-9]*[a-z0-9])? (like "my-name")' + ) + if len(name) > 63: + raise AirflowException("ACI name cannot be longer than 63 characters") + return name diff --git a/reference/providers/microsoft/azure/operators/azure_cosmos.py b/reference/providers/microsoft/azure/operators/azure_cosmos.py new file mode 100644 index 0000000..d74793c --- /dev/null +++ b/reference/providers/microsoft/azure/operators/azure_cosmos.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook +from airflow.utils.decorators import apply_defaults + + +class AzureCosmosInsertDocumentOperator(BaseOperator): + """ + Inserts a new document into the specified Cosmos database and collection + It will create both the database and collection if they do not already exist + + :param database_name: The name of the database. (templated) + :type database_name: str + :param collection_name: The name of the collection. (templated) + :type collection_name: str + :param document: The document to insert + :type document: dict + :param azure_cosmos_conn_id: reference to a CosmosDB connection. 
+ :type azure_cosmos_conn_id: str + """ + + template_fields = ("database_name", "collection_name") + ui_color = "#e4f0e8" + + @apply_defaults + def __init__( + self, + *, + database_name: str, + collection_name: str, + document: dict, + azure_cosmos_conn_id: str = "azure_cosmos_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.database_name = database_name + self.collection_name = collection_name + self.document = document + self.azure_cosmos_conn_id = azure_cosmos_conn_id + + def execute(self, context: dict) -> None: + # Create the hook + hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.azure_cosmos_conn_id) + + # Create the DB if it doesn't already exist + if not hook.does_database_exist(self.database_name): + hook.create_database(self.database_name) + + # Create the collection as well + if not hook.does_collection_exist(self.collection_name, self.database_name): + hook.create_collection(self.collection_name, self.database_name) + + # finally insert the document + hook.upsert_document(self.document, self.database_name, self.collection_name) diff --git a/reference/providers/microsoft/azure/operators/wasb_delete_blob.py b/reference/providers/microsoft/azure/operators/wasb_delete_blob.py new file mode 100644 index 0000000..9f5db6a --- /dev/null +++ b/reference/providers/microsoft/azure/operators/wasb_delete_blob.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Any + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.wasb import WasbHook +from airflow.utils.decorators import apply_defaults + + +class WasbDeleteBlobOperator(BaseOperator): + """ + Deletes blob(s) on Azure Blob Storage. + + :param container_name: Name of the container. (templated) + :type container_name: str + :param blob_name: Name of the blob. (templated) + :type blob_name: str + :param wasb_conn_id: Reference to the wasb connection. + :type wasb_conn_id: str + :param check_options: Optional keyword arguments that + `WasbHook.check_for_blob()` takes. + :param is_prefix: If blob_name is a prefix, delete all files matching prefix. + :type is_prefix: bool + :param ignore_if_missing: if True, then return success even if the + blob does not exist. 
+ :type ignore_if_missing: bool + """ + + template_fields = ("container_name", "blob_name") + + @apply_defaults + def __init__( + self, + *, + container_name: str, + blob_name: str, + wasb_conn_id: str = "wasb_default", + check_options: Any = None, + is_prefix: bool = False, + ignore_if_missing: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if check_options is None: + check_options = {} + self.wasb_conn_id = wasb_conn_id + self.container_name = container_name + self.blob_name = blob_name + self.check_options = check_options + self.is_prefix = is_prefix + self.ignore_if_missing = ignore_if_missing + + def execute(self, context: dict) -> None: + self.log.info( + "Deleting blob: %s\nin wasb://%s", self.blob_name, self.container_name + ) + hook = WasbHook(wasb_conn_id=self.wasb_conn_id) + + hook.delete_file( + self.container_name, + self.blob_name, + self.is_prefix, + self.ignore_if_missing, + **self.check_options, + ) diff --git a/reference/providers/microsoft/azure/provider.yaml b/reference/providers/microsoft/azure/provider.yaml new file mode 100644 index 0000000..6829470 --- /dev/null +++ b/reference/providers/microsoft/azure/provider.yaml @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
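# A minimal usage sketch for the WasbDeleteBlobOperator defined in
# wasb_delete_blob.py above. The container name, blob prefix and DAG settings
# are placeholders; a "wasb_default" connection is assumed to be configured.
from datetime import datetime

from airflow import DAG
from airflow.providers.microsoft.azure.operators.wasb_delete_blob import WasbDeleteBlobOperator

with DAG(
    dag_id="example_wasb_delete_blob",
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    cleanup_exports = WasbDeleteBlobOperator(
        task_id="cleanup_exports",
        container_name="my-container",
        blob_name="exports/{{ ds }}/",  # with is_prefix=True every blob under this prefix is removed
        is_prefix=True,
        ignore_if_missing=True,         # succeed even if nothing matches the prefix
    )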
+ +--- +package-name: apache-airflow-providers-microsoft-azure +name: Microsoft Azure +description: | + `Microsoft Azure `__ + +versions: + - 1.2.0 + - 1.1.0 + - 1.0.0 + +integrations: + - integration-name: Microsoft Azure Batch + external-doc-url: https://azure.microsoft.com/en-us/services/batch/ + logo: /integration-logos/azure/Microsoft-Azure-Batch.png + tags: [azure] + - integration-name: Microsoft Azure Blob Storage + external-doc-url: https://azure.microsoft.com/en-us/services/storage/blobs/ + logo: /integration-logos/azure/Blob Storage.svg + tags: [azure] + - integration-name: Microsoft Azure Container Instances + external-doc-url: https://azure.microsoft.com/en-us/services/container-instances/ + logo: /integration-logos/azure/Container Instances.svg + tags: [azure] + - integration-name: Microsoft Azure Cosmos DB + external-doc-url: https://azure.microsoft.com/en-us/services/cosmos-db/ + logo: /integration-logos/azure/Azure Cosmos DB.svg + tags: [azure] + - integration-name: Microsoft Azure Data Explorer + external-doc-url: https://azure.microsoft.com/en-us/services/data-explorer/ + logo: /integration-logos/azure/Microsoft-Azure-Data-Explorer.png + tags: [azure] + - integration-name: Microsoft Azure Data Lake Storage + how-to-guide: + - /docs/apache-airflow-providers-microsoft-azure/operators/adls.rst + external-doc-url: https://azure.microsoft.com/en-us/services/storage/data-lake-storage/ + logo: /integration-logos/azure/Data Lake Storage.svg + tags: [azure] + - integration-name: Microsoft Azure Files + external-doc-url: https://azure.microsoft.com/en-us/services/storage/files/ + logo: /integration-logos/azure/Azure Files.svg + tags: [azure] + - integration-name: Microsoft Azure FileShare + external-doc-url: https://cloud.google.com/storage/ + logo: /integration-logos/azure/Microsoft-Azure-Fileshare.png + tags: [azure] + - integration-name: Microsoft Azure Data Factory + external-doc-url: https://azure.microsoft.com/en-us/services/data-factory/ + logo: /integration-logos/azure/Azure Data Factory.svg + tags: [azure] + - integration-name: Microsoft Azure + external-doc-url: https://azure.microsoft.com/ + logo: /integration-logos/azure/Microsoft-Azure.png + tags: [azure] + +operators: + - integration-name: Microsoft Azure Data Lake Storage + python-modules: + - airflow.providers.microsoft.azure.operators.adls_list + - airflow.providers.microsoft.azure.operators.adls_delete + - integration-name: Microsoft Azure Data Explorer + python-modules: + - airflow.providers.microsoft.azure.operators.adx + - integration-name: Microsoft Azure Batch + python-modules: + - airflow.providers.microsoft.azure.operators.azure_batch + - integration-name: Microsoft Azure Container Instances + python-modules: + - airflow.providers.microsoft.azure.operators.azure_container_instances + - integration-name: Microsoft Azure Cosmos DB + python-modules: + - airflow.providers.microsoft.azure.operators.azure_cosmos + - integration-name: Microsoft Azure Blob Storage + python-modules: + - airflow.providers.microsoft.azure.operators.wasb_delete_blob + +sensors: + - integration-name: Microsoft Azure Cosmos DB + python-modules: + - airflow.providers.microsoft.azure.sensors.azure_cosmos + - integration-name: Microsoft Azure Blob Storage + python-modules: + - airflow.providers.microsoft.azure.sensors.wasb + +hooks: + - integration-name: Microsoft Azure Container Instances + python-modules: + - airflow.providers.microsoft.azure.hooks.azure_container_volume + - airflow.providers.microsoft.azure.hooks.azure_container_registry 
+ - airflow.providers.microsoft.azure.hooks.azure_container_instance + - integration-name: Microsoft Azure Data Explorer + python-modules: + - airflow.providers.microsoft.azure.hooks.adx + - integration-name: Microsoft Azure FileShare + python-modules: + - airflow.providers.microsoft.azure.hooks.azure_fileshare + - integration-name: Microsoft Azure + python-modules: + - airflow.providers.microsoft.azure.hooks.base_azure + - integration-name: Microsoft Azure Batch + python-modules: + - airflow.providers.microsoft.azure.hooks.azure_batch + - integration-name: Microsoft Azure Data Lake Storage + python-modules: + - airflow.providers.microsoft.azure.hooks.azure_data_lake + - integration-name: Microsoft Azure Cosmos DB + python-modules: + - airflow.providers.microsoft.azure.hooks.azure_cosmos + - integration-name: Microsoft Azure Blob Storage + python-modules: + - airflow.providers.microsoft.azure.hooks.wasb + - integration-name: Microsoft Azure Data Factory + python-modules: + - airflow.providers.microsoft.azure.hooks.azure_data_factory + +transfers: + - source-integration-name: Local + target-integration-name: Microsoft Azure Data Lake Storage + how-to-guide: /docs/apache-airflow-providers-microsoft-azure/operators/local_to_adls.rst + python-module: airflow.providers.microsoft.azure.transfers.local_to_adls + - source-integration-name: Oracle + target-integration-name: Microsoft Azure Data Lake Storage + python-module: airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake + - source-integration-name: Local + target-integration-name: Microsoft Azure Blob Storage + python-module: airflow.providers.microsoft.azure.transfers.file_to_wasb + - source-integration-name: Microsoft Azure Blob Storage + target-integration-name: Google Cloud Storage (GCS) + how-to-guide: /docs/apache-airflow-providers-microsoft-azure/operators/azure_blob_to_gcs.rst + python-module: airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs + +hook-class-names: + - airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook + - airflow.providers.microsoft.azure.hooks.adx.AzureDataExplorerHook + - airflow.providers.microsoft.azure.hooks.azure_batch.AzureBatchHook + - airflow.providers.microsoft.azure.hooks.azure_cosmos.AzureCosmosDBHook + - airflow.providers.microsoft.azure.hooks.azure_data_lake.AzureDataLakeHook + - airflow.providers.microsoft.azure.hooks.azure_container_instance.AzureContainerInstanceHook + - airflow.providers.microsoft.azure.hooks.wasb.WasbHook + - airflow.providers.microsoft.azure.hooks.azure_data_factory.AzureDataFactoryHook diff --git a/reference/providers/microsoft/azure/secrets/__init__.py b/reference/providers/microsoft/azure/secrets/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/microsoft/azure/secrets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/azure/secrets/azure_key_vault.py b/reference/providers/microsoft/azure/secrets/azure_key_vault.py new file mode 100644 index 0000000..f982dbe --- /dev/null +++ b/reference/providers/microsoft/azure/secrets/azure_key_vault.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from azure.core.exceptions import ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.keyvault.secrets import SecretClient + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.secrets import BaseSecretsBackend +from airflow.utils.log.logging_mixin import LoggingMixin + + +class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin): + """ + Retrieves Airflow Connections or Variables from Azure Key Vault secrets. + + The Azure Key Vault can be configured as a secrets backend in the ``airflow.cfg``: + + .. code-block:: ini + + [secrets] + backend = airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend + backend_kwargs = {"connections_prefix": "airflow-connections", "vault_url": ""} + + For example, if the secrets prefix is ``airflow-connections-smtp-default``, this would be accessible + if you provide ``{"connections_prefix": "airflow-connections"}`` and request conn_id ``smtp-default``. + And if variables prefix is ``airflow-variables-hello``, this would be accessible + if you provide ``{"variables_prefix": "airflow-variables"}`` and request variable key ``hello``. + + :param connections_prefix: Specifies the prefix of the secret to read to get Connections + If set to None (null), requests for connections will not be sent to Azure Key Vault + :type connections_prefix: str + :param variables_prefix: Specifies the prefix of the secret to read to get Variables + If set to None (null), requests for variables will not be sent to Azure Key Vault + :type variables_prefix: str + :param config_prefix: Specifies the prefix of the secret to read to get Variables. + If set to None (null), requests for configurations will not be sent to Azure Key Vault + :type config_prefix: str + :param vault_url: The URL of an Azure Key Vault to use + :type vault_url: str + :param sep: separator used to concatenate secret_prefix and secret_id. 
Default: "-" + :type sep: str + """ + + def __init__( + self, + connections_prefix: str = "airflow-connections", + variables_prefix: str = "airflow-variables", + config_prefix: str = "airflow-config", + vault_url: str = "", + sep: str = "-", + **kwargs, + ) -> None: + super().__init__() + self.vault_url = vault_url + if connections_prefix is not None: + self.connections_prefix = connections_prefix.rstrip(sep) + else: + self.connections_prefix = connections_prefix + if variables_prefix is not None: + self.variables_prefix = variables_prefix.rstrip(sep) + else: + self.variables_prefix = variables_prefix + if config_prefix is not None: + self.config_prefix = config_prefix.rstrip(sep) + else: + self.config_prefix = config_prefix + self.sep = sep + self.kwargs = kwargs + + @cached_property + def client(self) -> SecretClient: + """Create a Azure Key Vault client.""" + credential = DefaultAzureCredential() + client = SecretClient( + vault_url=self.vault_url, credential=credential, **self.kwargs + ) + return client + + def get_conn_uri(self, conn_id: str) -> Optional[str]: + """ + Get an Airflow Connection URI from an Azure Key Vault secret + + :param conn_id: The Airflow connection id to retrieve + :type conn_id: str + """ + if self.connections_prefix is None: + return None + + return self._get_secret(self.connections_prefix, conn_id) + + def get_variable(self, key: str) -> Optional[str]: + """ + Get an Airflow Variable from an Azure Key Vault secret. + + :param key: Variable Key + :type key: str + :return: Variable Value + """ + if self.variables_prefix is None: + return None + + return self._get_secret(self.variables_prefix, key) + + def get_config(self, key: str) -> Optional[str]: + """ + Get Airflow Configuration + + :param key: Configuration Option Key + :return: Configuration Option Value + """ + if self.config_prefix is None: + return None + + return self._get_secret(self.config_prefix, key) + + @staticmethod + def build_path(path_prefix: str, secret_id: str, sep: str = "-") -> str: + """ + Given a path_prefix and secret_id, build a valid secret name for the Azure Key Vault Backend. + Also replaces underscore in the path with dashes to support easy switching between + environment variables, so ``connection_default`` becomes ``connection-default``. + + :param path_prefix: The path prefix of the secret to retrieve + :type path_prefix: str + :param secret_id: Name of the secret + :type secret_id: str + :param sep: Separator used to concatenate path_prefix and secret_id + :type sep: str + """ + path = f"{path_prefix}{sep}{secret_id}" + return path.replace("_", sep) + + def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: + """ + Get an Azure Key Vault secret value + + :param path_prefix: Prefix for the Path to get Secret + :type path_prefix: str + :param secret_id: Secret Key + :type secret_id: str + """ + name = self.build_path(path_prefix, secret_id, self.sep) + try: + secret = self.client.get_secret(name=name) + return secret.value + except ResourceNotFoundError as ex: + self.log.debug("Secret %s not found: %s", name, ex) + return None diff --git a/reference/providers/microsoft/azure/sensors/__init__.py b/reference/providers/microsoft/azure/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/azure/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/azure/sensors/azure_cosmos.py b/reference/providers/microsoft/azure/sensors/azure_cosmos.py new file mode 100644 index 0000000..5effb70 --- /dev/null +++ b/reference/providers/microsoft/azure/sensors/azure_cosmos.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class AzureCosmosDocumentSensor(BaseSensorOperator): + """ + Checks for the existence of a document which + matches the given query in CosmosDB. Example: + + >>> azure_cosmos_sensor = AzureCosmosDocumentSensor(database_name="somedatabase_name", + ... collection_name="somecollection_name", + ... document_id="unique-doc-id", + ... azure_cosmos_conn_id="azure_cosmos_default", + ... task_id="azure_cosmos_sensor") + + :param database_name: Target CosmosDB database_name. + :type database_name: str + :param collection_name: Target CosmosDB collection_name. + :type collection_name: str + :param document_id: The ID of the target document. + :type query: str + :param azure_cosmos_conn_id: Reference to the Azure CosmosDB connection. 
+    :type azure_cosmos_conn_id: str
+    """
+
+    template_fields = ("database_name", "collection_name", "document_id")
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        database_name: str,
+        collection_name: str,
+        document_id: str,
+        azure_cosmos_conn_id: str = "azure_cosmos_default",
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.azure_cosmos_conn_id = azure_cosmos_conn_id
+        self.database_name = database_name
+        self.collection_name = collection_name
+        self.document_id = document_id
+
+    def poke(self, context: dict) -> bool:
+        self.log.info("*** Entering poke")
+        hook = AzureCosmosDBHook(self.azure_cosmos_conn_id)
+        return (
+            hook.get_document(
+                self.document_id, self.database_name, self.collection_name
+            )
+            is not None
+        )
diff --git a/reference/providers/microsoft/azure/sensors/wasb.py b/reference/providers/microsoft/azure/sensors/wasb.py
new file mode 100644
index 0000000..71cc09c
--- /dev/null
+++ b/reference/providers/microsoft/azure/sensors/wasb.py
@@ -0,0 +1,113 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Optional
+
+from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class WasbBlobSensor(BaseSensorOperator):
+    """
+    Waits for a blob to arrive on Azure Blob Storage.
+
+    :param container_name: Name of the container.
+    :type container_name: str
+    :param blob_name: Name of the blob.
+    :type blob_name: str
+    :param wasb_conn_id: Reference to the wasb connection.
+    :type wasb_conn_id: str
+    :param check_options: Optional keyword arguments that
+        `WasbHook.check_for_blob()` takes.
+    :type check_options: dict
+    """
+
+    template_fields = ("container_name", "blob_name")
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        container_name: str,
+        blob_name: str,
+        wasb_conn_id: str = "wasb_default",
+        check_options: Optional[dict] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        if check_options is None:
+            check_options = {}
+        self.wasb_conn_id = wasb_conn_id
+        self.container_name = container_name
+        self.blob_name = blob_name
+        self.check_options = check_options
+
+    def poke(self, context: dict) -> bool:
+        self.log.info(
+            "Poking for blob: %s\nin wasb://%s", self.blob_name, self.container_name
+        )
+        hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
+        return hook.check_for_blob(
+            self.container_name, self.blob_name, **self.check_options
+        )
+
+
+class WasbPrefixSensor(BaseSensorOperator):
+    """
+    Waits for blobs matching a prefix to arrive on Azure Blob Storage.
+
+    :param container_name: Name of the container.
+    :type container_name: str
+    :param prefix: Prefix of the blob.
+    :type prefix: str
+    :param wasb_conn_id: Reference to the wasb connection.
+ :type wasb_conn_id: str + :param check_options: Optional keyword arguments that + `WasbHook.check_for_prefix()` takes. + :type check_options: dict + """ + + template_fields = ("container_name", "prefix") + + @apply_defaults + def __init__( + self, + *, + container_name: str, + prefix: str, + wasb_conn_id: str = "wasb_default", + check_options: Optional[dict] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if check_options is None: + check_options = {} + self.wasb_conn_id = wasb_conn_id + self.container_name = container_name + self.prefix = prefix + self.check_options = check_options + + def poke(self, context: dict) -> bool: + self.log.info( + "Poking for prefix: %s in wasb://%s", self.prefix, self.container_name + ) + hook = WasbHook(wasb_conn_id=self.wasb_conn_id) + return hook.check_for_prefix( + self.container_name, self.prefix, **self.check_options + ) diff --git a/reference/providers/microsoft/azure/transfers/__init__.py b/reference/providers/microsoft/azure/transfers/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/microsoft/azure/transfers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/azure/transfers/azure_blob_to_gcs.py b/reference/providers/microsoft/azure/transfers/azure_blob_to_gcs.py new file mode 100644 index 0000000..605288d --- /dev/null +++ b/reference/providers/microsoft/azure/transfers/azure_blob_to_gcs.py @@ -0,0 +1,139 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import tempfile +from typing import Optional, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.microsoft.azure.hooks.wasb import WasbHook +from airflow.utils.decorators import apply_defaults + + +class AzureBlobStorageToGCSOperator(BaseOperator): + """ + Operator transfers data from Azure Blob Storage to specified bucket in Google Cloud Storage + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureBlobStorageToGCSOperator` + + :param wasb_conn_id: Reference to the wasb connection. + :type wasb_conn_id: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param blob_name: Name of the blob + :type blob_name: str + :param file_path: Path to the file to download + :type file_path: str + :param container_name: Name of the container + :type container_name: str + :param bucket_name: The bucket to upload to + :type bucket_name: str + :param object_name: The object name to set when uploading the file + :type object_name: str + :param filename: The local file path to the file to be uploaded + :type filename: str + :param gzip: Option to compress local file or file data for upload + :type gzip: bool + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + :type impersonation_chain: Union[str, Sequence[str]] + """ + + @apply_defaults + def __init__( + self, + *, + wasb_conn_id="wasb_default", + gcp_conn_id: str = "google_cloud_default", + blob_name: str, + file_path: str, + container_name: str, + bucket_name: str, + object_name: str, + filename: str, + gzip: bool, + delegate_to: Optional[str], + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.wasb_conn_id = wasb_conn_id + self.gcp_conn_id = gcp_conn_id + self.blob_name = blob_name + self.file_path = file_path + self.container_name = container_name + self.bucket_name = bucket_name + self.object_name = object_name + self.filename = filename + self.gzip = gzip + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + template_fields = ( + "blob_name", + "file_path", + "container_name", + "bucket_name", + "object_name", + "filename", + ) + + def execute(self, context: dict) -> str: + azure_hook = WasbHook(wasb_conn_id=self.wasb_conn_id) + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + with tempfile.NamedTemporaryFile() as temp_file: + self.log.info("Downloading data from blob: %s", self.blob_name) + azure_hook.get_file( + file_path=temp_file.name, + container_name=self.container_name, + blob_name=self.blob_name, + ) + self.log.info( + "Uploading data from blob's: %s into GCP bucket: %s", + self.object_name, + self.bucket_name, + ) + gcs_hook.upload( + bucket_name=self.bucket_name, + object_name=self.object_name, + filename=temp_file.name, + gzip=self.gzip, + ) + self.log.info( + "Resources have been uploaded from blob: %s to GCS bucket:%s", + self.blob_name, + self.bucket_name, + ) + return f"gs://{self.bucket_name}/{self.object_name}" diff --git 
a/reference/providers/microsoft/azure/transfers/file_to_wasb.py b/reference/providers/microsoft/azure/transfers/file_to_wasb.py new file mode 100644 index 0000000..ff38cc8 --- /dev/null +++ b/reference/providers/microsoft/azure/transfers/file_to_wasb.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.wasb import WasbHook +from airflow.utils.decorators import apply_defaults + + +class FileToWasbOperator(BaseOperator): + """ + Uploads a file to Azure Blob Storage. + + :param file_path: Path to the file to load. (templated) + :type file_path: str + :param container_name: Name of the container. (templated) + :type container_name: str + :param blob_name: Name of the blob. (templated) + :type blob_name: str + :param wasb_conn_id: Reference to the wasb connection. + :type wasb_conn_id: str + :param load_options: Optional keyword arguments that + `WasbHook.load_file()` takes. + :type load_options: Optional[dict] + """ + + template_fields = ("file_path", "container_name", "blob_name") + + @apply_defaults + def __init__( + self, + *, + file_path: str, + container_name: str, + blob_name: str, + wasb_conn_id: str = "wasb_default", + load_options: Optional[dict] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if load_options is None: + load_options = {} + self.file_path = file_path + self.container_name = container_name + self.blob_name = blob_name + self.wasb_conn_id = wasb_conn_id + self.load_options = load_options + + def execute(self, context: dict) -> None: + """Upload a file to Azure Blob Storage.""" + hook = WasbHook(wasb_conn_id=self.wasb_conn_id) + self.log.info( + "Uploading %s to wasb://%s as %s", + self.file_path, + self.container_name, + self.blob_name, + ) + hook.load_file( + self.file_path, self.container_name, self.blob_name, **self.load_options + ) diff --git a/reference/providers/microsoft/azure/transfers/local_to_adls.py b/reference/providers/microsoft/azure/transfers/local_to_adls.py new file mode 100644 index 0000000..df17ef6 --- /dev/null +++ b/reference/providers/microsoft/azure/transfers/local_to_adls.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
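# A minimal sketch combining the FileToWasbOperator defined in file_to_wasb.py
# above with the WasbBlobSensor from sensors/wasb.py earlier in this commit:
# upload a local file, then wait until the blob is visible. Paths, names and
# timing values are placeholders; a "wasb_default" connection is assumed.
from datetime import datetime

from airflow import DAG
from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor
from airflow.providers.microsoft.azure.transfers.file_to_wasb import FileToWasbOperator

with DAG(
    dag_id="example_file_to_wasb",
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    upload_report = FileToWasbOperator(
        task_id="upload_report",
        file_path="/tmp/report-{{ ds }}.csv",   # templated
        container_name="reports",
        blob_name="daily/report-{{ ds }}.csv",
    )
    wait_for_report = WasbBlobSensor(
        task_id="wait_for_report",
        container_name="reports",
        blob_name="daily/report-{{ ds }}.csv",
        poke_interval=30,
        timeout=600,
    )
    upload_report >> wait_for_report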
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook +from airflow.utils.decorators import apply_defaults + + +class LocalToAzureDataLakeStorageOperator(BaseOperator): + """ + Upload file(s) to Azure Data Lake + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:LocalToAzureDataLakeStorageOperator` + + :param local_path: local path. Can be single file, directory (in which case, + upload recursively) or glob pattern. Recursive glob patterns using `**` + are not supported + :type local_path: str + :param remote_path: Remote path to upload to; if multiple files, this is the + directory root to write within + :type remote_path: str + :param nthreads: Number of threads to use. If None, uses the number of cores. + :type nthreads: int + :param overwrite: Whether to forcibly overwrite existing files/directories. + If False and remote path is a directory, will quit regardless if any files + would be overwritten or not. If True, only matching filenames are actually + overwritten + :type overwrite: bool + :param buffersize: int [2**22] + Number of bytes for internal buffer. This block cannot be bigger than + a chunk and cannot be smaller than a block + :type buffersize: int + :param blocksize: int [2**22] + Number of bytes for a block. Within each chunk, we write a smaller + block for each API call. 
This block cannot be bigger than a chunk + :type blocksize: int + :param extra_upload_options: Extra upload options to add to the hook upload method + :type extra_upload_options: dict + :param azure_data_lake_conn_id: Reference to the Azure Data Lake connection + :type azure_data_lake_conn_id: str + """ + + template_fields = ("local_path", "remote_path") + ui_color = "#e4f0e8" + + @apply_defaults + def __init__( + self, + *, + local_path: str, + remote_path: str, + overwrite: bool = True, + nthreads: int = 64, + buffersize: int = 4194304, + blocksize: int = 4194304, + extra_upload_options: Optional[Dict[str, Any]] = None, + azure_data_lake_conn_id: str = "azure_data_lake_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.local_path = local_path + self.remote_path = remote_path + self.overwrite = overwrite + self.nthreads = nthreads + self.buffersize = buffersize + self.blocksize = blocksize + self.extra_upload_options = extra_upload_options + self.azure_data_lake_conn_id = azure_data_lake_conn_id + + def execute(self, context: dict) -> None: + if "**" in self.local_path: + raise AirflowException( + "Recursive glob patterns using `**` are not supported" + ) + if not self.extra_upload_options: + self.extra_upload_options = {} + hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) + self.log.info("Uploading %s to %s", self.local_path, self.remote_path) + return hook.upload_file( + local_path=self.local_path, + remote_path=self.remote_path, + nthreads=self.nthreads, + overwrite=self.overwrite, + buffersize=self.buffersize, + blocksize=self.blocksize, + **self.extra_upload_options, + ) diff --git a/reference/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py b/reference/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py new file mode 100644 index 0000000..2858820 --- /dev/null +++ b/reference/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from tempfile import TemporaryDirectory +from typing import Any, Optional, Union + +import unicodecsv as csv +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook +from airflow.providers.oracle.hooks.oracle import OracleHook +from airflow.utils.decorators import apply_defaults + + +class OracleToAzureDataLakeOperator(BaseOperator): + """ + Moves data from Oracle to Azure Data Lake. The operator runs the query against + Oracle and stores the file locally before loading it into Azure Data Lake. + + + :param filename: file name to be used by the csv file. + :type filename: str + :param azure_data_lake_conn_id: destination azure data lake connection. 
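# A minimal usage sketch for the LocalToAzureDataLakeStorageOperator defined in
# local_to_adls.py above. Paths are placeholders and an "azure_data_lake_default"
# connection is assumed; execute() rejects recursive "**" globs, so only a
# single-level glob is used here.
from datetime import datetime

from airflow import DAG
from airflow.providers.microsoft.azure.transfers.local_to_adls import (
    LocalToAzureDataLakeStorageOperator,
)

with DAG(
    dag_id="example_local_to_adls",
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    upload_exports = LocalToAzureDataLakeStorageOperator(
        task_id="upload_exports",
        local_path="/tmp/exports/*.csv",
        remote_path="landing/exports",
        overwrite=True,
    )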
+ :type azure_data_lake_conn_id: str + :param azure_data_lake_path: destination path in azure data lake to put the file. + :type azure_data_lake_path: str + :param oracle_conn_id: source Oracle connection. + :type oracle_conn_id: str + :param sql: SQL query to execute against the Oracle database. (templated) + :type sql: str + :param sql_params: Parameters to use in sql query. (templated) + :type sql_params: Optional[dict] + :param delimiter: field delimiter in the file. + :type delimiter: str + :param encoding: encoding type for the file. + :type encoding: str + :param quotechar: Character to use in quoting. + :type quotechar: str + :param quoting: Quoting strategy. See unicodecsv quoting for more information. + :type quoting: str + """ + + template_fields = ("filename", "sql", "sql_params") + ui_color = "#e08c8c" + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + filename: str, + azure_data_lake_conn_id: str, + azure_data_lake_path: str, + oracle_conn_id: str, + sql: str, + sql_params: Optional[dict] = None, + delimiter: str = ",", + encoding: str = "utf-8", + quotechar: str = '"', + quoting: str = csv.QUOTE_MINIMAL, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if sql_params is None: + sql_params = {} + self.filename = filename + self.oracle_conn_id = oracle_conn_id + self.sql = sql + self.sql_params = sql_params + self.azure_data_lake_conn_id = azure_data_lake_conn_id + self.azure_data_lake_path = azure_data_lake_path + self.delimiter = delimiter + self.encoding = encoding + self.quotechar = quotechar + self.quoting = quoting + + def _write_temp_file( + self, cursor: Any, path_to_save: Union[str, bytes, int] + ) -> None: + with open(path_to_save, "wb") as csvfile: + csv_writer = csv.writer( + csvfile, + delimiter=self.delimiter, + encoding=self.encoding, + quotechar=self.quotechar, + quoting=self.quoting, + ) + csv_writer.writerow(map(lambda field: field[0], cursor.description)) + csv_writer.writerows(cursor) + csvfile.flush() + + def execute(self, context: dict) -> None: + oracle_hook = OracleHook(oracle_conn_id=self.oracle_conn_id) + azure_data_lake_hook = AzureDataLakeHook( + azure_data_lake_conn_id=self.azure_data_lake_conn_id + ) + + self.log.info("Dumping Oracle query results to local file") + conn = oracle_hook.get_conn() + cursor = conn.cursor() # type: ignore[attr-defined] + cursor.execute(self.sql, self.sql_params) + + with TemporaryDirectory(prefix="airflow_oracle_to_azure_op_") as temp: + self._write_temp_file(cursor, os.path.join(temp, self.filename)) + self.log.info("Uploading local file to Azure Data Lake") + azure_data_lake_hook.upload_file( + os.path.join(temp, self.filename), + os.path.join(self.azure_data_lake_path, self.filename), + ) + cursor.close() + conn.close() # type: ignore[attr-defined] diff --git a/reference/providers/microsoft/mssql/CHANGELOG.rst b/reference/providers/microsoft/mssql/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/microsoft/mssql/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/microsoft/mssql/__init__.py b/reference/providers/microsoft/mssql/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/mssql/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/mssql/hooks/__init__.py b/reference/providers/microsoft/mssql/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/mssql/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/mssql/hooks/mssql.py b/reference/providers/microsoft/mssql/hooks/mssql.py new file mode 100644 index 0000000..9ff5a0e --- /dev/null +++ b/reference/providers/microsoft/mssql/hooks/mssql.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Microsoft SQLServer hook module""" + +import pymssql +from airflow.hooks.dbapi import DbApiHook + + +class MsSqlHook(DbApiHook): + """Interact with Microsoft SQL Server.""" + + conn_name_attr = "mssql_conn_id" + default_conn_name = "mssql_default" + conn_type = "mssql" + hook_name = "Microsoft SQL Server" + supports_autocommit = True + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.schema = kwargs.pop("schema", None) + + def get_conn( + self, + ) -> pymssql.connect: # pylint: disable=protected-access # pylint: disable=c-extension-no-member + """Returns a mssql connection object""" + conn = self.get_connection( + self.mssql_conn_id # type: ignore[attr-defined] # pylint: disable=no-member + ) + # pylint: disable=c-extension-no-member + conn = pymssql.connect( + server=conn.host, + user=conn.login, + password=conn.password, + database=self.schema or conn.schema, + port=conn.port, + ) + return conn + + def set_autocommit( + self, + conn: pymssql.connect, # pylint: disable=c-extension-no-member + autocommit: bool, + ) -> None: + conn.autocommit(autocommit) + + def get_autocommit( + self, conn: pymssql.connect + ): # pylint: disable=c-extension-no-member + return conn.autocommit_state diff --git a/reference/providers/microsoft/mssql/operators/__init__.py b/reference/providers/microsoft/mssql/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/mssql/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/mssql/operators/mssql.py b/reference/providers/microsoft/mssql/operators/mssql.py new file mode 100644 index 0000000..88ec9dd --- /dev/null +++ b/reference/providers/microsoft/mssql/operators/mssql.py @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
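For orientation, a minimal usage sketch (not part of this commit) of the MsSqlHook defined above; the connection id and query here are illustrative placeholders, assuming a reachable SQL Server connection is configured in Airflow:

from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook

# Placeholder connection id and query, purely to illustrate the hook API above.
hook = MsSqlHook(mssql_conn_id="mssql_default")
for (name,) in hook.get_records("SELECT name FROM sys.databases"):
    print(name)
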
+from typing import Iterable, Mapping, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook +from airflow.providers.odbc.hooks.odbc import OdbcHook +from airflow.utils.decorators import apply_defaults + + +class MsSqlOperator(BaseOperator): + """ + Executes sql code in a specific Microsoft SQL database + + This operator may use one of two hooks, depending on the ``conn_type`` of the connection. + + If conn_type is ``'odbc'``, then :py:class:`~airflow.providers.odbc.hooks.odbc.OdbcHook` + is used. Otherwise, :py:class:`~airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook` is used. + + :param sql: the sql code to be executed + :type sql: str or string pointing to a template file with .sql + extension. (templated) + :param mssql_conn_id: reference to a specific mssql database + :type mssql_conn_id: str + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: dict or iterable + :param autocommit: if True, each command is automatically committed. + (default value: False) + :type autocommit: bool + :param database: name of database which overwrite defined one in connection + :type database: str + """ + + template_fields = ("sql",) + template_ext = (".sql",) + ui_color = "#ededed" + + @apply_defaults + def __init__( + self, + *, + sql: str, + mssql_conn_id: str = "mssql_default", + parameters: Optional[Union[Mapping, Iterable]] = None, + autocommit: bool = False, + database: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.mssql_conn_id = mssql_conn_id + self.sql = sql + self.parameters = parameters + self.autocommit = autocommit + self.database = database + self._hook: Optional[Union[MsSqlHook, OdbcHook]] = None + + def get_hook(self) -> Optional[Union[MsSqlHook, OdbcHook]]: + """ + Will retrieve hook as determined by Connection. + + If conn_type is ``'odbc'``, will use + :py:class:`~airflow.providers.odbc.hooks.odbc.OdbcHook`. + Otherwise, :py:class:`~airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook` will be used. + """ + if not self._hook: + conn = MsSqlHook.get_connection(conn_id=self.mssql_conn_id) + try: + self._hook = conn.get_hook() + self._hook.schema = self.database # type: ignore[union-attr] + except AirflowException: + self._hook = MsSqlHook( + mssql_conn_id=self.mssql_conn_id, schema=self.database + ) + return self._hook + + def execute(self, context: dict) -> None: + self.log.info("Executing: %s", self.sql) + hook = self.get_hook() + hook.run( # type: ignore[union-attr] + sql=self.sql, autocommit=self.autocommit, parameters=self.parameters + ) diff --git a/reference/providers/microsoft/mssql/provider.yaml b/reference/providers/microsoft/mssql/provider.yaml new file mode 100644 index 0000000..22e7e7a --- /dev/null +++ b/reference/providers/microsoft/mssql/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-microsoft-mssql +name: Microsoft SQL Server (MSSQL) +description: | + `Microsoft SQL Server (MSSQL) `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Microsoft SQL Server (MSSQL) + external-doc-url: https://www.microsoft.com/en-us/sql-server/sql-server-downloads + logo: /integration-logos/mssql/Microsoft-SQL-Server.png + tags: [software] + +operators: + - integration-name: Microsoft SQL Server (MSSQL) + python-modules: + - airflow.providers.microsoft.mssql.operators.mssql + +hooks: + - integration-name: Microsoft SQL Server (MSSQL) + python-modules: + - airflow.providers.microsoft.mssql.hooks.mssql + +hook-class-names: + - airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook diff --git a/reference/providers/microsoft/winrm/CHANGELOG.rst b/reference/providers/microsoft/winrm/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/microsoft/winrm/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/microsoft/winrm/__init__.py b/reference/providers/microsoft/winrm/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/winrm/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
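For orientation, a minimal sketch (not part of this commit) of wiring the MsSqlOperator above into a DAG, in the same style as the example DAGs elsewhere in this repository; the dag_id, connection id, and SQL statement are illustrative placeholders:

from airflow import DAG
from airflow.providers.microsoft.mssql.operators.mssql import MsSqlOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_mssql_sketch",  # placeholder dag_id
    start_date=days_ago(1),
    schedule_interval=None,
    tags=["example"],
) as dag:
    create_table = MsSqlOperator(
        task_id="create_table",
        mssql_conn_id="mssql_default",  # placeholder connection id
        sql="CREATE TABLE dbo.events (id INT PRIMARY KEY, payload NVARCHAR(MAX))",
        autocommit=True,
    )
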
diff --git a/reference/providers/microsoft/winrm/example_dags/__init__.py b/reference/providers/microsoft/winrm/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/winrm/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/winrm/example_dags/example_winrm.py b/reference/providers/microsoft/winrm/example_dags/example_winrm.py new file mode 100644 index 0000000..0980586 --- /dev/null +++ b/reference/providers/microsoft/winrm/example_dags/example_winrm.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# -------------------------------------------------------------------------------- +# Written By: Ekhtiar Syed +# Last Update: 8th April 2016 +# Caveat: This Dag will not run because of missing scripts. +# The purpose of this is to give you a sample of a real world example DAG! +# -------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------- +# Load The Dependencies +# -------------------------------------------------------------------------------- +""" +This is an example dag for using the WinRMOperator. 
+""" +from datetime import timedelta + +from airflow import DAG +from airflow.operators.dummy import DummyOperator +from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook +from airflow.providers.microsoft.winrm.operators.winrm import WinRMOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", +} + +with DAG( + dag_id="POC_winrm_parallel", + default_args=default_args, + schedule_interval="0 0 * * *", + start_date=days_ago(2), + dagrun_timeout=timedelta(minutes=60), + tags=["example"], +) as dag: + + cmd = "ls -l" + run_this_last = DummyOperator(task_id="run_this_last") + + winRMHook = WinRMHook(ssh_conn_id="ssh_POC1") + + t1 = WinRMOperator(task_id="wintask1", command="ls -altr", winrm_hook=winRMHook) + + t2 = WinRMOperator(task_id="wintask2", command="sleep 60", winrm_hook=winRMHook) + + t3 = WinRMOperator( + task_id="wintask3", command="echo 'luke test' ", winrm_hook=winRMHook + ) + + [t1, t2, t3] >> run_this_last diff --git a/reference/providers/microsoft/winrm/hooks/__init__.py b/reference/providers/microsoft/winrm/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/winrm/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/microsoft/winrm/hooks/winrm.py b/reference/providers/microsoft/winrm/hooks/winrm.py new file mode 100644 index 0000000..e707ba2 --- /dev/null +++ b/reference/providers/microsoft/winrm/hooks/winrm.py @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Hook for winrm remote execution.""" +import getpass +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from winrm.protocol import Protocol + + +# TODO: Fixme please - I have too complex implementation +# pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-branches +class WinRMHook(BaseHook): + """ + Hook for winrm remote execution using pywinrm. 
+ + :seealso: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py + + :param ssh_conn_id: connection id from airflow Connections from which + all the required parameters can be fetched, like username and password. + Though priority is given to the params passed during init. + :type ssh_conn_id: str + :param endpoint: When not set, endpoint will be constructed like this: + 'http://{remote_host}:{remote_port}/wsman' + :type endpoint: str + :param remote_host: Remote host to connect to. Ignored if `endpoint` is set. + :type remote_host: str + :param remote_port: Remote port to connect to. Ignored if `endpoint` is set. + :type remote_port: int + :param transport: transport type, one of 'plaintext' (default), 'kerberos', 'ssl', 'ntlm', 'credssp' + :type transport: str + :param username: username to connect to the remote_host + :type username: str + :param password: password of the username to connect to the remote_host + :type password: str + :param service: the service name, default is HTTP + :type service: str + :param keytab: the path to a keytab file if you are using one + :type keytab: str + :param ca_trust_path: Certification Authority trust path + :type ca_trust_path: str + :param cert_pem: client authentication certificate file path in PEM format + :type cert_pem: str + :param cert_key_pem: client authentication certificate key file path in PEM format + :type cert_key_pem: str + :param server_cert_validation: whether server certificate should be validated on + Python versions that support it; one of 'validate' (default), 'ignore' + :type server_cert_validation: str + :param kerberos_delegation: if True, TGT is sent to target server to + allow multiple hops + :type kerberos_delegation: bool + :param read_timeout_sec: maximum seconds to wait before an HTTP connect/read times out (default 30). + This value should be slightly higher than operation_timeout_sec, + as the server can block *at least* that long. + :type read_timeout_sec: int + :param operation_timeout_sec: maximum allowed time in seconds for any single wsman + HTTP operation (default 20). Note that operation timeouts while receiving output + (the only wsman operation that should take any significant time, + and where these timeouts are expected) will be silently retried indefinitely. + :type operation_timeout_sec: int + :param kerberos_hostname_override: the hostname to use for the kerberos exchange + (defaults to the hostname in the endpoint URL) + :type kerberos_hostname_override: str + :param message_encryption: Will encrypt the WinRM messages if set + and the transport auth supports message encryption. 
(Default 'auto') + :type message_encryption: str + :param credssp_disable_tlsv1_2: Whether to disable TLSv1.2 support and work with older + protocols like TLSv1.0, default is False + :type credssp_disable_tlsv1_2: bool + :param send_cbt: Will send the channel bindings over a HTTPS channel (Default: True) + :type send_cbt: bool + """ + + def __init__( + self, + ssh_conn_id: Optional[str] = None, + endpoint: Optional[str] = None, + remote_host: Optional[str] = None, + remote_port: int = 5985, + transport: str = "plaintext", + username: Optional[str] = None, + password: Optional[str] = None, + service: str = "HTTP", + keytab: Optional[str] = None, + ca_trust_path: Optional[str] = None, + cert_pem: Optional[str] = None, + cert_key_pem: Optional[str] = None, + server_cert_validation: str = "validate", + kerberos_delegation: bool = False, + read_timeout_sec: int = 30, + operation_timeout_sec: int = 20, + kerberos_hostname_override: Optional[str] = None, + message_encryption: Optional[str] = "auto", + credssp_disable_tlsv1_2: bool = False, + send_cbt: bool = True, + ) -> None: + super().__init__() + self.ssh_conn_id = ssh_conn_id + self.endpoint = endpoint + self.remote_host = remote_host + self.remote_port = remote_port + self.transport = transport + self.username = username + self.password = password + self.service = service + self.keytab = keytab + self.ca_trust_path = ca_trust_path + self.cert_pem = cert_pem + self.cert_key_pem = cert_key_pem + self.server_cert_validation = server_cert_validation + self.kerberos_delegation = kerberos_delegation + self.read_timeout_sec = read_timeout_sec + self.operation_timeout_sec = operation_timeout_sec + self.kerberos_hostname_override = kerberos_hostname_override + self.message_encryption = message_encryption + self.credssp_disable_tlsv1_2 = credssp_disable_tlsv1_2 + self.send_cbt = send_cbt + + self.client = None + self.winrm_protocol = None + + def get_conn(self): + if self.client: + return self.client + + self.log.debug("Creating WinRM client for conn_id: %s", self.ssh_conn_id) + if self.ssh_conn_id is not None: + conn = self.get_connection(self.ssh_conn_id) + + if self.username is None: + self.username = conn.login + if self.password is None: + self.password = conn.password + if self.remote_host is None: + self.remote_host = conn.host + + if conn.extra is not None: + extra_options = conn.extra_dejson + + if "endpoint" in extra_options: + self.endpoint = str(extra_options["endpoint"]) + if "remote_port" in extra_options: + self.remote_port = int(extra_options["remote_port"]) + if "transport" in extra_options: + self.transport = str(extra_options["transport"]) + if "service" in extra_options: + self.service = str(extra_options["service"]) + if "keytab" in extra_options: + self.keytab = str(extra_options["keytab"]) + if "ca_trust_path" in extra_options: + self.ca_trust_path = str(extra_options["ca_trust_path"]) + if "cert_pem" in extra_options: + self.cert_pem = str(extra_options["cert_pem"]) + if "cert_key_pem" in extra_options: + self.cert_key_pem = str(extra_options["cert_key_pem"]) + if "server_cert_validation" in extra_options: + self.server_cert_validation = str( + extra_options["server_cert_validation"] + ) + if "kerberos_delegation" in extra_options: + self.kerberos_delegation = ( + str(extra_options["kerberos_delegation"]).lower() == "true" + ) + if "read_timeout_sec" in extra_options: + self.read_timeout_sec = int(extra_options["read_timeout_sec"]) + if "operation_timeout_sec" in extra_options: + self.operation_timeout_sec = int( + 
extra_options["operation_timeout_sec"] + ) + if "kerberos_hostname_override" in extra_options: + self.kerberos_hostname_override = str( + extra_options["kerberos_hostname_override"] + ) + if "message_encryption" in extra_options: + self.message_encryption = str(extra_options["message_encryption"]) + if "credssp_disable_tlsv1_2" in extra_options: + self.credssp_disable_tlsv1_2 = ( + str(extra_options["credssp_disable_tlsv1_2"]).lower() == "true" + ) + if "send_cbt" in extra_options: + self.send_cbt = str(extra_options["send_cbt"]).lower() == "true" + + if not self.remote_host: + raise AirflowException("Missing required param: remote_host") + + # Auto detecting username values from system + if not self.username: + self.log.debug( + "username to WinRM to host: %s is not specified for connection id" + " %s. Using system's default provided by getpass.getuser()", + self.remote_host, + self.ssh_conn_id, + ) + self.username = getpass.getuser() + + # If endpoint is not set, then build a standard wsman endpoint from host and port. + if not self.endpoint: + self.endpoint = f"http://{self.remote_host}:{self.remote_port}/wsman" + + try: + if self.password and self.password.strip(): + self.winrm_protocol = Protocol( + endpoint=self.endpoint, + transport=self.transport, + username=self.username, + password=self.password, + service=self.service, + keytab=self.keytab, + ca_trust_path=self.ca_trust_path, + cert_pem=self.cert_pem, + cert_key_pem=self.cert_key_pem, + server_cert_validation=self.server_cert_validation, + kerberos_delegation=self.kerberos_delegation, + read_timeout_sec=self.read_timeout_sec, + operation_timeout_sec=self.operation_timeout_sec, + kerberos_hostname_override=self.kerberos_hostname_override, + message_encryption=self.message_encryption, + credssp_disable_tlsv1_2=self.credssp_disable_tlsv1_2, + send_cbt=self.send_cbt, + ) + + self.log.info("Establishing WinRM connection to host: %s", self.remote_host) + self.client = self.winrm_protocol.open_shell() + + except Exception as error: + error_msg = f"Error connecting to host: {self.remote_host}, error: {error}" + self.log.error(error_msg) + raise AirflowException(error_msg) + + return self.client diff --git a/reference/providers/microsoft/winrm/operators/__init__.py b/reference/providers/microsoft/winrm/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/microsoft/winrm/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/reference/providers/microsoft/winrm/operators/winrm.py b/reference/providers/microsoft/winrm/operators/winrm.py new file mode 100644 index 0000000..99fd775 --- /dev/null +++ b/reference/providers/microsoft/winrm/operators/winrm.py @@ -0,0 +1,164 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from base64 import b64encode +from typing import Optional, Union + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook +from airflow.utils.decorators import apply_defaults +from winrm.exceptions import WinRMOperationTimeoutError + +# Hide the following error message in urllib3 when making WinRM connections: +# requests.packages.urllib3.exceptions.HeaderParsingError: [StartBoundaryNotFoundDefect(), +# MultipartInvariantViolationDefect()], unparsed data: '' +logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) + + +class WinRMOperator(BaseOperator): + """ + WinRMOperator to execute commands on a given remote host using the winrm_hook. + + :param winrm_hook: predefined ssh_hook to use for remote execution + :type winrm_hook: airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook + :param ssh_conn_id: connection id from airflow Connections + :type ssh_conn_id: str + :param remote_host: remote host to connect to + :type remote_host: str + :param command: command to execute on remote host. (templated) + :type command: str + :param ps_path: path to powershell, `powershell` for v5.1- and `pwsh` for v6+. + If specified, it will execute the command as a powershell script. + :type ps_path: str + :param output_encoding: the encoding used to decode stdout and stderr + :type output_encoding: str + :param timeout: timeout for executing the command. 
+ :type timeout: int + """ + + template_fields = ("command",) + + @apply_defaults + def __init__( + self, + *, + winrm_hook: Optional[WinRMHook] = None, + ssh_conn_id: Optional[str] = None, + remote_host: Optional[str] = None, + command: Optional[str] = None, + ps_path: Optional[str] = None, + output_encoding: str = "utf-8", + timeout: int = 10, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.winrm_hook = winrm_hook + self.ssh_conn_id = ssh_conn_id + self.remote_host = remote_host + self.command = command + self.ps_path = ps_path + self.output_encoding = output_encoding + self.timeout = timeout + + def execute(self, context: dict) -> Union[list, str]: + if self.ssh_conn_id and not self.winrm_hook: + self.log.info("Hook not found, creating...") + self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id) + + if not self.winrm_hook: + raise AirflowException("Cannot operate without winrm_hook or ssh_conn_id.") + + if self.remote_host is not None: + self.winrm_hook.remote_host = self.remote_host + + if not self.command: + raise AirflowException("No command specified so nothing to execute here.") + + winrm_client = self.winrm_hook.get_conn() + + # pylint: disable=too-many-nested-blocks + try: + if self.ps_path is not None: + self.log.info( + "Running command as powershell script: '%s'...", self.command + ) + encoded_ps = b64encode(self.command.encode("utf_16_le")).decode("ascii") + command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined] + winrm_client, f"{self.ps_path} -encodedcommand {encoded_ps}" + ) + else: + self.log.info("Running command: '%s'...", self.command) + command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined] + winrm_client, self.command + ) + + # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py + stdout_buffer = [] + stderr_buffer = [] + command_done = False + while not command_done: + try: + # pylint: disable=protected-access + ( + stdout, + stderr, + return_code, + command_done, + ) = self.winrm_hook.winrm_protocol._raw_get_command_output( # type: ignore[attr-defined] + winrm_client, command_id + ) + + # Only buffer stdout if we need to so that we minimize memory usage. 
+ if self.do_xcom_push: + stdout_buffer.append(stdout) + stderr_buffer.append(stderr) + + for line in stdout.decode(self.output_encoding).splitlines(): + self.log.info(line) + for line in stderr.decode(self.output_encoding).splitlines(): + self.log.warning(line) + except WinRMOperationTimeoutError: + # this is an expected error when waiting for a + # long-running process, just silently retry + pass + + self.winrm_hook.winrm_protocol.cleanup_command( # type: ignore[attr-defined] + winrm_client, command_id + ) + self.winrm_hook.winrm_protocol.close_shell(winrm_client) # type: ignore[attr-defined] + + except Exception as e: + raise AirflowException(f"WinRM operator error: {str(e)}") + + if return_code == 0: + # returning output if do_xcom_push is set + enable_pickling = conf.getboolean("core", "enable_xcom_pickling") + if enable_pickling: + return stdout_buffer + else: + return b64encode(b"".join(stdout_buffer)).decode(self.output_encoding) + else: + error_msg = "Error running cmd: {}, return code: {}, error: {}".format( + self.command, + return_code, + b"".join(stderr_buffer).decode(self.output_encoding), + ) + raise AirflowException(error_msg) diff --git a/reference/providers/microsoft/winrm/provider.yaml b/reference/providers/microsoft/winrm/provider.yaml new file mode 100644 index 0000000..92280cd --- /dev/null +++ b/reference/providers/microsoft/winrm/provider.yaml @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-microsoft-winrm +name: Windows Remote Management (WinRM) +description: | + `Windows Remote Management (WinRM) `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Windows Remote Management (WinRM) + external-doc-url: https://docs.microsoft.com/en-us/windows/win32/winrm/portal + logo: /integration-logos/winrm/WinRM.png + tags: [protocol] + +operators: + - integration-name: Windows Remote Management (WinRM) + python-modules: + - airflow.providers.microsoft.winrm.operators.winrm + +hooks: + - integration-name: Windows Remote Management (WinRM) + python-modules: + - airflow.providers.microsoft.winrm.hooks.winrm diff --git a/reference/providers/mongo/CHANGELOG.rst b/reference/providers/mongo/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/mongo/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. 
http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/mongo/__init__.py b/reference/providers/mongo/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/mongo/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/mongo/hooks/__init__.py b/reference/providers/mongo/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/mongo/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/mongo/hooks/mongo.py b/reference/providers/mongo/hooks/mongo.py new file mode 100644 index 0000000..cd269fa --- /dev/null +++ b/reference/providers/mongo/hooks/mongo.py @@ -0,0 +1,363 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
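For orientation, a minimal usage sketch (not part of this commit) of the MongoHook defined just below, which can be used as a context manager thanks to its __enter__/__exit__ methods; the connection id, database, collection, and query are illustrative placeholders:

from airflow.providers.mongo.hooks.mongo import MongoHook

# Placeholder names, purely to illustrate the hook API defined below.
with MongoHook(conn_id="mongo_default") as mongo_hook:
    for doc in mongo_hook.find("events", {"status": "new"}, mongo_db="analytics"):
        print(doc["_id"])
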
+"""Hook for Mongo DB""" +from ssl import CERT_NONE +from types import TracebackType +from typing import List, Optional, Type + +import pymongo +from airflow.hooks.base import BaseHook +from pymongo import MongoClient, ReplaceOne + + +class MongoHook(BaseHook): + """ + PyMongo Wrapper to Interact With Mongo Database + Mongo Connection Documentation + https://docs.mongodb.com/manual/reference/connection-string/index.html + You can specify connection string options in extra field of your connection + https://docs.mongodb.com/manual/reference/connection-string/index.html#connection-string-options + + If you want use DNS seedlist, set `srv` to True. + + ex. + {"srv": true, "replicaSet": "test", "ssl": true, "connectTimeoutMS": 30000} + """ + + conn_name_attr = "conn_id" + default_conn_name = "mongo_default" + conn_type = "mongo" + hook_name = "MongoDB" + + def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None: + + super().__init__() + self.mongo_conn_id = conn_id + self.connection = self.get_connection(conn_id) + self.extras = self.connection.extra_dejson.copy() + self.client = None + + srv = self.extras.pop("srv", False) + scheme = "mongodb+srv" if srv else "mongodb" + + self.uri = "{scheme}://{creds}{host}{port}/{database}".format( + scheme=scheme, + creds=f"{self.connection.login}:{self.connection.password}@" + if self.connection.login + else "", + host=self.connection.host, + port="" if self.connection.port is None else f":{self.connection.port}", + database=self.connection.schema, + ) + + def __enter__(self): + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if self.client is not None: + self.close_conn() + + def get_conn(self) -> MongoClient: + """Fetches PyMongo Client""" + if self.client is not None: + return self.client + + # Mongo Connection Options dict that is unpacked when passed to MongoClient + options = self.extras + + # If we are using SSL disable requiring certs from specific hostname + if options.get("ssl", False): + options.update({"ssl_cert_reqs": CERT_NONE}) + + self.client = MongoClient(self.uri, **options) + + return self.client + + def close_conn(self) -> None: + """Closes connection""" + client = self.client + if client is not None: + client.close() + self.client = None + + def get_collection( + self, mongo_collection: str, mongo_db: Optional[str] = None + ) -> pymongo.collection.Collection: + """ + Fetches a mongo collection object for querying. + + Uses connection schema as DB unless specified. 
+ """ + mongo_db = mongo_db if mongo_db is not None else self.connection.schema + mongo_conn: MongoClient = self.get_conn() + + return mongo_conn.get_database(mongo_db).get_collection(mongo_collection) + + def aggregate( + self, + mongo_collection: str, + aggregate_query: list, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.command_cursor.CommandCursor: + """ + Runs an aggregation pipeline and returns the results + https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.aggregate + https://api.mongodb.com/python/current/examples/aggregation.html + """ + collection = self.get_collection(mongo_collection, mongo_db=mongo_db) + + return collection.aggregate(aggregate_query, **kwargs) + + def find( + self, + mongo_collection: str, + query: dict, + find_one: bool = False, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.cursor.Cursor: + """ + Runs a mongo find query and returns the results + https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.find + """ + collection = self.get_collection(mongo_collection, mongo_db=mongo_db) + + if find_one: + return collection.find_one(query, **kwargs) + else: + return collection.find(query, **kwargs) + + def insert_one( + self, mongo_collection: str, doc: dict, mongo_db: Optional[str] = None, **kwargs + ) -> pymongo.results.InsertOneResult: + """ + Inserts a single document into a mongo collection + https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.insert_one + """ + collection = self.get_collection(mongo_collection, mongo_db=mongo_db) + + return collection.insert_one(doc, **kwargs) + + def insert_many( + self, + mongo_collection: str, + docs: dict, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.results.InsertManyResult: + """ + Inserts many docs into a mongo collection. + https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.insert_many + """ + collection = self.get_collection(mongo_collection, mongo_db=mongo_db) + + return collection.insert_many(docs, **kwargs) + + def update_one( + self, + mongo_collection: str, + filter_doc: dict, + update_doc: dict, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.results.UpdateResult: + """ + Updates a single document in a mongo collection. + https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_one + + :param mongo_collection: The name of the collection to update. + :type mongo_collection: str + :param filter_doc: A query that matches the documents to update. + :type filter_doc: dict + :param update_doc: The modifications to apply. + :type update_doc: dict + :param mongo_db: The name of the database to use. + Can be omitted; then the database from the connection string is used. + :type mongo_db: str + + """ + collection = self.get_collection(mongo_collection, mongo_db=mongo_db) + + return collection.update_one(filter_doc, update_doc, **kwargs) + + def update_many( + self, + mongo_collection: str, + filter_doc: dict, + update_doc: dict, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.results.UpdateResult: + """ + Updates one or more documents in a mongo collection. + https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_many + + :param mongo_collection: The name of the collection to update. + :type mongo_collection: str + :param filter_doc: A query that matches the documents to update. 
+ :type filter_doc: dict + :param update_doc: The modifications to apply. + :type update_doc: dict + :param mongo_db: The name of the database to use. + Can be omitted; then the database from the connection string is used. + :type mongo_db: str + + """ + collection = self.get_collection(mongo_collection, mongo_db=mongo_db) + + return collection.update_many(filter_doc, update_doc, **kwargs) + + def replace_one( + self, + mongo_collection: str, + doc: dict, + filter_doc: Optional[dict] = None, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.results.UpdateResult: + """ + Replaces a single document in a mongo collection. + https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.replace_one + + .. note:: + If no ``filter_doc`` is given, it is assumed that the replacement + document contain the ``_id`` field which is then used as filters. + + :param mongo_collection: The name of the collection to update. + :type mongo_collection: str + :param doc: The new document. + :type doc: dict + :param filter_doc: A query that matches the documents to replace. + Can be omitted; then the _id field from doc will be used. + :type filter_doc: dict + :param mongo_db: The name of the database to use. + Can be omitted; then the database from the connection string is used. + :type mongo_db: str + """ + collection = self.get_collection(mongo_collection, mongo_db=mongo_db) + + if not filter_doc: + filter_doc = {"_id": doc["_id"]} + + return collection.replace_one(filter_doc, doc, **kwargs) + + def replace_many( + self, + mongo_collection: str, + docs: List[dict], + filter_docs: Optional[List[dict]] = None, + mongo_db: Optional[str] = None, + upsert: bool = False, + collation: Optional[pymongo.collation.Collation] = None, + **kwargs, + ) -> pymongo.results.BulkWriteResult: + """ + Replaces many documents in a mongo collection. + + Uses bulk_write with multiple ReplaceOne operations + https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.bulk_write + + .. note:: + If no ``filter_docs``are given, it is assumed that all + replacement documents contain the ``_id`` field which are then + used as filters. + + :param mongo_collection: The name of the collection to update. + :type mongo_collection: str + :param docs: The new documents. + :type docs: list[dict] + :param filter_docs: A list of queries that match the documents to replace. + Can be omitted; then the _id fields from docs will be used. + :type filter_docs: list[dict] + :param mongo_db: The name of the database to use. + Can be omitted; then the database from the connection string is used. + :type mongo_db: str + :param upsert: If ``True``, perform an insert if no documents + match the filters for the replace operation. + :type upsert: bool + :param collation: An instance of + :class:`~pymongo.collation.Collation`. This option is only + supported on MongoDB 3.4 and above. + :type collation: pymongo.collation.Collation + + """ + collection = self.get_collection(mongo_collection, mongo_db=mongo_db) + + if not filter_docs: + filter_docs = [{"_id": doc["_id"]} for doc in docs] + + requests = [ + ReplaceOne(filter_docs[i], docs[i], upsert=upsert, collation=collation) + for i in range(len(docs)) + ] + + return collection.bulk_write(requests, **kwargs) + + def delete_one( + self, + mongo_collection: str, + filter_doc: dict, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.results.DeleteResult: + """ + Deletes a single document in a mongo collection. 
+ https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_one + + :param mongo_collection: The name of the collection to delete from. + :type mongo_collection: str + :param filter_doc: A query that matches the document to delete. + :type filter_doc: dict + :param mongo_db: The name of the database to use. + Can be omitted; then the database from the connection string is used. + :type mongo_db: str + + """ + collection = self.get_collection(mongo_collection, mongo_db=mongo_db) + + return collection.delete_one(filter_doc, **kwargs) + + def delete_many( + self, + mongo_collection: str, + filter_doc: dict, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.results.DeleteResult: + """ + Deletes one or more documents in a mongo collection. + https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_many + + :param mongo_collection: The name of the collection to delete from. + :type mongo_collection: str + :param filter_doc: A query that matches the documents to delete. + :type filter_doc: dict + :param mongo_db: The name of the database to use. + Can be omitted; then the database from the connection string is used. + :type mongo_db: str + + """ + collection = self.get_collection(mongo_collection, mongo_db=mongo_db) + + return collection.delete_many(filter_doc, **kwargs) diff --git a/reference/providers/mongo/provider.yaml b/reference/providers/mongo/provider.yaml new file mode 100644 index 0000000..bbe3d1d --- /dev/null +++ b/reference/providers/mongo/provider.yaml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-mongo +name: MongoDB +description: | + `MongoDB `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: MongoDB + external-doc-url: https://www.mongodb.com/what-is-mongodb + logo: /integration-logos/mongo/MongoDB.png + tags: [software] + +sensors: + - integration-name: MongoDB + python-modules: + - airflow.providers.mongo.sensors.mongo +hooks: + - integration-name: MongoDB + python-modules: + - airflow.providers.mongo.hooks.mongo + +hook-class-names: + - airflow.providers.mongo.hooks.mongo.MongoHook diff --git a/reference/providers/mongo/sensors/__init__.py b/reference/providers/mongo/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/mongo/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/mongo/sensors/mongo.py b/reference/providers/mongo/sensors/mongo.py new file mode 100644 index 0000000..8519b0d --- /dev/null +++ b/reference/providers/mongo/sensors/mongo.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from airflow.providers.mongo.hooks.mongo import MongoHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class MongoSensor(BaseSensorOperator): + """ + Checks for the existence of a document which + matches the given query in MongoDB. Example: + + >>> mongo_sensor = MongoSensor(collection="coll", + ... query={"key": "value"}, + ... mongo_conn_id="mongo_default", + ... task_id="mongo_sensor") + + :param collection: Target MongoDB collection. + :type collection: str + :param query: The query to find the target document. + :type query: dict + :param mongo_conn_id: The connection ID to use + when connecting to MongoDB. + :type mongo_conn_id: str + """ + + template_fields = ("collection", "query") + + @apply_defaults + def __init__( + self, + *, + collection: str, + query: dict, + mongo_conn_id: str = "mongo_default", + **kwargs + ) -> None: + super().__init__(**kwargs) + self.mongo_conn_id = mongo_conn_id + self.collection = collection + self.query = query + + def poke(self, context: dict) -> bool: + self.log.info( + "Sensor check existence of the document that matches the following query: %s", + self.query, + ) + hook = MongoHook(self.mongo_conn_id) + return hook.find(self.collection, self.query, find_one=True) is not None diff --git a/reference/providers/mysql/CHANGELOG.rst b/reference/providers/mysql/CHANGELOG.rst new file mode 100644 index 0000000..bba4e54 --- /dev/null +++ b/reference/providers/mysql/CHANGELOG.rst @@ -0,0 +1,42 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. 
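As a hedged complement to the sensor above, this sketch wires it into a DAG; the DAG id, connection id, collection and query are illustrative assumptions.

from airflow import DAG
from airflow.providers.mongo.sensors.mongo import MongoSensor
from airflow.utils.dates import days_ago

dag = DAG(
    "example_mongo_sensor",
    start_date=days_ago(1),
    tags=["example"],
)

wait_for_order = MongoSensor(
    task_id="wait_for_order",
    mongo_conn_id="mongo_default",
    collection="orders",
    query={"status": "ready"},
    poke_interval=60,   # re-check once a minute
    timeout=60 * 60,    # give up after an hour of waiting
    dag=dag,
)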
http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +The version of MySQL server has to be 5.6.4+. The exact version upper bound depends +on the version of ``mysqlclient`` package. For example, ``mysqlclient`` 1.3.12 can only be +used with MySQL server 5.6.4 through 5.7. + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``MySQL hook respects conn_name_attr (#14240)`` + +1.0.1 +..... + +Updated documentation and readme files. + + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/mysql/__init__.py b/reference/providers/mysql/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/mysql/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/mysql/example_dags/__init__.py b/reference/providers/mysql/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/mysql/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/mysql/example_dags/example_mysql.py b/reference/providers/mysql/example_dags/example_mysql.py new file mode 100644 index 0000000..13e0d68 --- /dev/null +++ b/reference/providers/mysql/example_dags/example_mysql.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example use of MySql related operators. +""" + +from airflow import DAG +from airflow.providers.mysql.operators.mysql import MySqlOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", +} + +dag = DAG( + "example_mysql", + default_args=default_args, + start_date=days_ago(2), + tags=["example"], +) + +# [START howto_operator_mysql] + +drop_table_mysql_task = MySqlOperator( + task_id="create_table_mysql", + mysql_conn_id="mysql_conn_id", + sql=r"""DROP TABLE table_name;""", + dag=dag, +) + +# [END howto_operator_mysql] + +# [START howto_operator_mysql_external_file] + +mysql_task = MySqlOperator( + task_id="create_table_mysql_external_file", + mysql_conn_id="mysql_conn_id", + sql="/scripts/drop_table.sql", + dag=dag, +) + +# [END howto_operator_mysql_external_file] + +drop_table_mysql_task >> mysql_task diff --git a/reference/providers/mysql/hooks/__init__.py b/reference/providers/mysql/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/mysql/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/mysql/hooks/mysql.py b/reference/providers/mysql/hooks/mysql.py new file mode 100644 index 0000000..22193c4 --- /dev/null +++ b/reference/providers/mysql/hooks/mysql.py @@ -0,0 +1,265 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module allows to connect to a MySQL database.""" +import json +from typing import Dict, Optional, Tuple + +from airflow.hooks.dbapi import DbApiHook +from airflow.models import Connection + + +class MySqlHook(DbApiHook): + """ + Interact with MySQL. + + You can specify charset in the extra field of your connection + as ``{"charset": "utf8"}``. Also you can choose cursor as + ``{"cursor": "SSCursor"}``. Refer to the MySQLdb.cursors for more details. + + Note: For AWS IAM authentication, use iam in the extra connection parameters + and set it to true. Leave the password field empty. This will use the + "aws_default" connection to get the temporary token unless you override + in extras. + extras example: ``{"iam":true, "aws_conn_id":"my_aws_conn"}`` + """ + + conn_name_attr = "mysql_conn_id" + default_conn_name = "mysql_default" + conn_type = "mysql" + hook_name = "MySQL" + supports_autocommit = True + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.schema = kwargs.pop("schema", None) + self.connection = kwargs.pop("connection", None) + + def set_autocommit(self, conn: Connection, autocommit: bool) -> None: # noqa: D403 + """MySql connection sets autocommit in a different way.""" + conn.autocommit(autocommit) + + def get_autocommit(self, conn: Connection) -> bool: # noqa: D403 + """ + MySql connection gets autocommit in a different way. + + :param conn: connection to get autocommit setting from. + :type conn: connection object. + :return: connection autocommit setting + :rtype: bool + """ + return conn.get_autocommit() + + def _get_conn_config_mysql_client(self, conn: Connection) -> Dict: + conn_config = { + "user": conn.login, + "passwd": conn.password or "", + "host": conn.host or "localhost", + "db": self.schema or conn.schema or "", + } + + # check for authentication via AWS IAM + if conn.extra_dejson.get("iam", False): + conn_config["passwd"], conn.port = self.get_iam_token(conn) + conn_config["read_default_group"] = "enable-cleartext-plugin" + + conn_config["port"] = int(conn.port) if conn.port else 3306 + + if conn.extra_dejson.get("charset", False): + conn_config["charset"] = conn.extra_dejson["charset"] + if conn_config["charset"].lower() in ("utf8", "utf-8"): + conn_config["use_unicode"] = True + if conn.extra_dejson.get("cursor", False): + import MySQLdb.cursors + + if (conn.extra_dejson["cursor"]).lower() == "sscursor": + conn_config["cursorclass"] = MySQLdb.cursors.SSCursor + elif (conn.extra_dejson["cursor"]).lower() == "dictcursor": + conn_config["cursorclass"] = MySQLdb.cursors.DictCursor + elif (conn.extra_dejson["cursor"]).lower() == "ssdictcursor": + conn_config["cursorclass"] = MySQLdb.cursors.SSDictCursor + local_infile = conn.extra_dejson.get("local_infile", False) + if conn.extra_dejson.get("ssl", False): + # SSL parameter for MySQL has to be a dictionary and in case + # of extra/dejson we can get string if extra is passed via + # URL parameters + dejson_ssl = conn.extra_dejson["ssl"] + if isinstance(dejson_ssl, str): + dejson_ssl = json.loads(dejson_ssl) + conn_config["ssl"] = dejson_ssl + if conn.extra_dejson.get("unix_socket"): + conn_config["unix_socket"] = conn.extra_dejson["unix_socket"] + if local_infile: + conn_config["local_infile"] = 1 + return conn_config + + def _get_conn_config_mysql_connector_python(self, conn: Connection) -> Dict: + conn_config = { + "user": conn.login, + "password": conn.password or "", + "host": conn.host or "localhost", + "database": self.schema or conn.schema or "", + "port": 
int(conn.port) if conn.port else 3306, + } + + if conn.extra_dejson.get("allow_local_infile", False): + conn_config["allow_local_infile"] = True + + return conn_config + + def get_conn(self): + """ + Establishes a connection to a mysql database + by extracting the connection configuration from the Airflow connection. + + .. note:: By default it connects to the database via the mysqlclient library. + But you can also choose the mysql-connector-python library which lets you connect through ssl + without any further ssl parameters required. + + :return: a mysql connection object + """ + conn = self.connection or self.get_connection( + getattr(self, self.conn_name_attr) + ) # pylint: disable=no-member + + client_name = conn.extra_dejson.get("client", "mysqlclient") + + if client_name == "mysqlclient": + import MySQLdb + + conn_config = self._get_conn_config_mysql_client(conn) + return MySQLdb.connect(**conn_config) + + if client_name == "mysql-connector-python": + import mysql.connector # pylint: disable=no-name-in-module + + conn_config = self._get_conn_config_mysql_connector_python(conn) + return mysql.connector.connect(**conn_config) # pylint: disable=no-member + + raise ValueError("Unknown MySQL client name provided!") + + def get_uri(self) -> str: + conn = self.get_connection(getattr(self, self.conn_name_attr)) + uri = super().get_uri() + if conn.extra_dejson.get("charset", False): + charset = conn.extra_dejson["charset"] + return f"{uri}?charset={charset}" + return uri + + def bulk_load(self, table: str, tmp_file: str) -> None: + """Loads a tab-delimited file into a database table""" + conn = self.get_conn() + cur = conn.cursor() + cur.execute( + f""" + LOAD DATA LOCAL INFILE '{tmp_file}' + INTO TABLE {table} + """ + ) + conn.commit() + + def bulk_dump(self, table: str, tmp_file: str) -> None: + """Dumps a database table into a tab-delimited file""" + conn = self.get_conn() + cur = conn.cursor() + cur.execute( + f""" + SELECT * INTO OUTFILE '{tmp_file}' + FROM {table} + """ + ) + conn.commit() + + @staticmethod + def _serialize_cell( + cell: object, conn: Optional[Connection] = None + ) -> object: # pylint: disable=signature-differs # noqa: D403 + """ + MySQLdb converts an argument to a literal + when passing those separately to execute. Hence, this method does nothing. + + :param cell: The cell to insert into the table + :type cell: object + :param conn: The database connection + :type conn: connection object + :return: The same cell + :rtype: object + """ + return cell + + def get_iam_token(self, conn: Connection) -> Tuple[str, int]: + """ + Uses AWSHook to retrieve a temporary password to connect to MySQL + Port is required. If none is provided, default 3306 is used + """ + from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + aws_conn_id = conn.extra_dejson.get("aws_conn_id", "aws_default") + aws_hook = AwsBaseHook(aws_conn_id, client_type="rds") + if conn.port is None: + port = 3306 + else: + port = conn.port + client = aws_hook.get_conn() + token = client.generate_db_auth_token(conn.host, port, conn.login) + return token, port + + def bulk_load_custom( + self, + table: str, + tmp_file: str, + duplicate_key_handling: str = "IGNORE", + extra_options: str = "", + ) -> None: + """ + A more configurable way to load local data from a file into the database. + + .. warning:: According to the mysql docs using this function is a + `security risk `_. + If you want to use it anyway you can do so by setting a client-side + server-side option. 
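As a hedged aside, the bulk-loading helpers of this hook might be exercised as sketched below; the connection id, schema, table and file path are assumptions, and LOAD DATA LOCAL INFILE additionally requires ``local_infile`` to be enabled on both client and server.

from airflow.providers.mysql.hooks.mysql import MySqlHook

# Assumes a "mysql_default" connection; schema overrides the schema stored on the connection.
hook = MySqlHook(mysql_conn_id="mysql_default", schema="staging")

# Plain tab-delimited load into an existing table.
hook.bulk_load(table="events", tmp_file="/tmp/events.tsv")

# The configurable variant: replace rows on duplicate keys and skip a header line.
hook.bulk_load_custom(
    table="events",
    tmp_file="/tmp/events.tsv",
    duplicate_key_handling="REPLACE",
    extra_options="FIELDS TERMINATED BY '\\t' IGNORE 1 LINES",
)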
+ This depends on the mysql client library used. + + :param table: The table were the file will be loaded into. + :type table: str + :param tmp_file: The file (name) that contains the data. + :type tmp_file: str + :param duplicate_key_handling: Specify what should happen to duplicate data. + You can choose either `IGNORE` or `REPLACE`. + + .. seealso:: + https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-duplicate-key-handling + :type duplicate_key_handling: str + :param extra_options: More sql options to specify exactly how to load the data. + + .. seealso:: https://dev.mysql.com/doc/refman/8.0/en/load-data.html + :type extra_options: str + """ + conn = self.get_conn() + cursor = conn.cursor() + + cursor.execute( + f""" + LOAD DATA LOCAL INFILE '{tmp_file}' + {duplicate_key_handling} + INTO TABLE {table} + {extra_options} + """ + ) + + cursor.close() + conn.commit() diff --git a/reference/providers/mysql/operators/__init__.py b/reference/providers/mysql/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/mysql/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/mysql/operators/mysql.py b/reference/providers/mysql/operators/mysql.py new file mode 100644 index 0000000..1b16a42 --- /dev/null +++ b/reference/providers/mysql/operators/mysql.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Dict, Iterable, Mapping, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.utils.decorators import apply_defaults + + +class MySqlOperator(BaseOperator): + """ + Executes sql code in a specific MySQL database + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MySqlOperator` + + :param sql: the sql code to be executed. 
Can receive a str representing a
+        sql statement, a list of str (sql statements), or reference to a template file.
+        Template references are recognized by str ending in '.sql'
+        (templated)
+    :type sql: str or list[str]
+    :param mysql_conn_id: reference to a specific mysql database
+    :type mysql_conn_id: str
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: dict or iterable
+    :param autocommit: if True, each command is automatically committed.
+        (default value: False)
+    :type autocommit: bool
+    :param database: name of database which overwrites the one defined in the connection
+    :type database: str
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#ededed"
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        sql: str,
+        mysql_conn_id: str = "mysql_default",
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        autocommit: bool = False,
+        database: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.mysql_conn_id = mysql_conn_id
+        self.sql = sql
+        self.autocommit = autocommit
+        self.parameters = parameters
+        self.database = database
+
+    def execute(self, context: Dict) -> None:
+        self.log.info("Executing: %s", self.sql)
+        hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database)
+        hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
diff --git a/reference/providers/mysql/provider.yaml b/reference/providers/mysql/provider.yaml new file mode 100644 index 0000000..a9b408f --- /dev/null +++ b/reference/providers/mysql/provider.yaml @@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
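The bundled example DAG earlier shows static SQL; as a hedged addition, this sketch passes bind values through the operator's ``parameters`` argument (the DAG id, connection id and table are assumptions).

from airflow import DAG
from airflow.providers.mysql.operators.mysql import MySqlOperator
from airflow.utils.dates import days_ago

dag = DAG(
    "example_mysql_parameters",
    start_date=days_ago(1),
    tags=["example"],
)

insert_audit_row = MySqlOperator(
    task_id="insert_audit_row",
    mysql_conn_id="mysql_default",
    sql="INSERT INTO audit_log (event, created_at) VALUES (%s, NOW())",
    parameters=("dag_started",),  # rendered into the %s placeholder by the DB-API driver
    autocommit=True,
    dag=dag,
)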
+ +--- +package-name: apache-airflow-providers-mysql +name: MySQL +description: | + `MySQL `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: MySQL + external-doc-url: https://www.mysql.com/ + how-to-guide: + - /docs/apache-airflow-providers-mysql/operators.rst + logo: /integration-logos/mysql/MySQL.png + tags: [software] + +operators: + - integration-name: MySQL + + python-modules: + - airflow.providers.mysql.operators.mysql + +hooks: + - integration-name: MySQL + python-modules: + - airflow.providers.mysql.hooks.mysql + +transfers: + - source-integration-name: Vertica + target-integration-name: MySQL + python-module: airflow.providers.mysql.transfers.vertica_to_mysql + - source-integration-name: Amazon Simple Storage Service (S3) + target-integration-name: MySQL + python-module: airflow.providers.mysql.transfers.s3_to_mysql + - source-integration-name: Snowflake + target-integration-name: MySQL + python-module: airflow.providers.mysql.transfers.presto_to_mysql + +hook-class-names: + - airflow.providers.mysql.hooks.mysql.MySqlHook diff --git a/reference/providers/mysql/transfers/__init__.py b/reference/providers/mysql/transfers/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/mysql/transfers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/mysql/transfers/presto_to_mysql.py b/reference/providers/mysql/transfers/presto_to_mysql.py new file mode 100644 index 0000000..9cab81f --- /dev/null +++ b/reference/providers/mysql/transfers/presto_to_mysql.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from typing import Dict, Optional + +from airflow.models import BaseOperator +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.providers.presto.hooks.presto import PrestoHook +from airflow.utils.decorators import apply_defaults + + +class PrestoToMySqlOperator(BaseOperator): + """ + Moves data from Presto to MySQL, note that for now the data is loaded + into memory before being pushed to MySQL, so this operator should + be used for smallish amount of data. + + :param sql: SQL query to execute against Presto. (templated) + :type sql: str + :param mysql_table: target MySQL table, use dot notation to target a + specific database. (templated) + :type mysql_table: str + :param mysql_conn_id: source mysql connection + :type mysql_conn_id: str + :param presto_conn_id: source presto connection + :type presto_conn_id: str + :param mysql_preoperator: sql statement to run against mysql prior to + import, typically use to truncate of delete in place + of the data coming in, allowing the task to be idempotent (running + the task twice won't double load data). (templated) + :type mysql_preoperator: str + """ + + template_fields = ("sql", "mysql_table", "mysql_preoperator") + template_ext = (".sql",) + ui_color = "#a0e08c" + + @apply_defaults + def __init__( + self, + *, + sql: str, + mysql_table: str, + presto_conn_id: str = "presto_default", + mysql_conn_id: str = "mysql_default", + mysql_preoperator: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sql = sql + self.mysql_table = mysql_table + self.mysql_conn_id = mysql_conn_id + self.mysql_preoperator = mysql_preoperator + self.presto_conn_id = presto_conn_id + + def execute(self, context: Dict) -> None: + presto = PrestoHook(presto_conn_id=self.presto_conn_id) + self.log.info("Extracting data from Presto: %s", self.sql) + results = presto.get_records(self.sql) + + mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) + if self.mysql_preoperator: + self.log.info("Running MySQL preoperator") + self.log.info(self.mysql_preoperator) + mysql.run(self.mysql_preoperator) + + self.log.info("Inserting rows into MySQL") + mysql.insert_rows(table=self.mysql_table, rows=results) diff --git a/reference/providers/mysql/transfers/s3_to_mysql.py b/reference/providers/mysql/transfers/s3_to_mysql.py new file mode 100644 index 0000000..d93d429 --- /dev/null +++ b/reference/providers/mysql/transfers/s3_to_mysql.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
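A hedged usage sketch for the Presto-to-MySQL transfer above; the connection ids, query and table names are assumptions.

from airflow import DAG
from airflow.providers.mysql.transfers.presto_to_mysql import PrestoToMySqlOperator
from airflow.utils.dates import days_ago

dag = DAG(
    "example_presto_to_mysql",
    start_date=days_ago(1),
    tags=["example"],
)

load_daily_totals = PrestoToMySqlOperator(
    task_id="load_daily_totals",
    presto_conn_id="presto_default",
    mysql_conn_id="mysql_default",
    sql="SELECT dt, count(*) AS cnt FROM web.events GROUP BY dt",
    mysql_table="reporting.daily_totals",
    # Truncate first so reruns do not double-load rows.
    mysql_preoperator="TRUNCATE TABLE reporting.daily_totals",
    dag=dag,
)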
+ +import os +from typing import Dict, Optional + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.utils.decorators import apply_defaults + + +class S3ToMySqlOperator(BaseOperator): + """ + Loads a file from S3 into a MySQL table. + + :param s3_source_key: The path to the file (S3 key) that will be loaded into MySQL. + :type s3_source_key: str + :param mysql_table: The MySQL table into where the data will be sent. + :type mysql_table: str + :param mysql_duplicate_key_handling: Specify what should happen to duplicate data. + You can choose either `IGNORE` or `REPLACE`. + + .. seealso:: + https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-duplicate-key-handling + :type mysql_duplicate_key_handling: str + :param mysql_extra_options: MySQL options to specify exactly how to load the data. + :type mysql_extra_options: Optional[str] + :param aws_conn_id: The S3 connection that contains the credentials to the S3 Bucket. + :type aws_conn_id: str + :param mysql_conn_id: The MySQL connection that contains the credentials to the MySQL data base. + :type mysql_conn_id: str + """ + + template_fields = ( + "s3_source_key", + "mysql_table", + ) + template_ext = () + ui_color = "#f4a460" + + @apply_defaults + def __init__( + self, + *, + s3_source_key: str, + mysql_table: str, + mysql_duplicate_key_handling: str = "IGNORE", + mysql_extra_options: Optional[str] = None, + aws_conn_id: str = "aws_default", + mysql_conn_id: str = "mysql_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.s3_source_key = s3_source_key + self.mysql_table = mysql_table + self.mysql_duplicate_key_handling = mysql_duplicate_key_handling + self.mysql_extra_options = mysql_extra_options or "" + self.aws_conn_id = aws_conn_id + self.mysql_conn_id = mysql_conn_id + + def execute(self, context: Dict) -> None: + """ + Executes the transfer operation from S3 to MySQL. + + :param context: The context that is being provided when executing. + :type context: dict + """ + self.log.info( + "Loading %s to MySql table %s...", self.s3_source_key, self.mysql_table + ) + + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + file = s3_hook.download_file(key=self.s3_source_key) + + try: + mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) + mysql.bulk_load_custom( + table=self.mysql_table, + tmp_file=file, + duplicate_key_handling=self.mysql_duplicate_key_handling, + extra_options=self.mysql_extra_options, + ) + finally: + # Remove file downloaded from s3 to be idempotent. + os.remove(file) diff --git a/reference/providers/mysql/transfers/vertica_to_mysql.py b/reference/providers/mysql/transfers/vertica_to_mysql.py new file mode 100644 index 0000000..f0d1c39 --- /dev/null +++ b/reference/providers/mysql/transfers/vertica_to_mysql.py @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
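A hedged sketch of the S3-to-MySQL load above; the S3 key, table, load options and connection ids are assumptions.

from airflow import DAG
from airflow.providers.mysql.transfers.s3_to_mysql import S3ToMySqlOperator
from airflow.utils.dates import days_ago

dag = DAG(
    "example_s3_to_mysql",
    start_date=days_ago(1),
    tags=["example"],
)

load_users = S3ToMySqlOperator(
    task_id="load_users",
    s3_source_key="s3://my-bucket/exports/users.tsv",
    mysql_table="staging.users",
    mysql_duplicate_key_handling="REPLACE",
    mysql_extra_options="FIELDS TERMINATED BY '\\t'",
    aws_conn_id="aws_default",
    mysql_conn_id="mysql_default",
    dag=dag,
)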
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from contextlib import closing
+from tempfile import NamedTemporaryFile
+from typing import Optional
+
+import MySQLdb
+import unicodecsv as csv
+from airflow.models import BaseOperator
+from airflow.providers.mysql.hooks.mysql import MySqlHook
+from airflow.providers.vertica.hooks.vertica import VerticaHook
+from airflow.utils.decorators import apply_defaults
+
+
+class VerticaToMySqlOperator(BaseOperator):
+    """
+    Moves data from Vertica to MySQL.
+
+    :param sql: SQL query to execute against the Vertica database. (templated)
+    :type sql: str
+    :param vertica_conn_id: source Vertica connection
+    :type vertica_conn_id: str
+    :param mysql_table: target MySQL table, use dot notation to target a
+        specific database. (templated)
+    :type mysql_table: str
+    :param mysql_conn_id: target mysql connection
+    :type mysql_conn_id: str
+    :param mysql_preoperator: sql statement to run against MySQL prior to
+        import, typically used to truncate or delete in place of the data
+        coming in, allowing the task to be idempotent (running the task
+        twice won't double load data). (templated)
+    :type mysql_preoperator: str
+    :param mysql_postoperator: sql statement to run against MySQL after the
+        import, typically used to move data from staging to production
+        and issue cleanup commands. (templated)
+    :type mysql_postoperator: str
+    :param bulk_load: flag to use bulk_load option. This loads MySQL directly
+        from a tab-delimited text file using the LOAD DATA LOCAL INFILE command.
+        This option requires an extra connection parameter for the
+        destination MySQL connection: {'local_infile': true}.
+ :type bulk_load: bool + """ + + template_fields = ("sql", "mysql_table", "mysql_preoperator", "mysql_postoperator") + template_ext = (".sql",) + ui_color = "#a0e08c" + + @apply_defaults + def __init__( + self, + sql: str, + mysql_table: str, + vertica_conn_id: str = "vertica_default", + mysql_conn_id: str = "mysql_default", + mysql_preoperator: Optional[str] = None, + mysql_postoperator: Optional[str] = None, + bulk_load: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.sql = sql + self.mysql_table = mysql_table + self.mysql_conn_id = mysql_conn_id + self.mysql_preoperator = mysql_preoperator + self.mysql_postoperator = mysql_postoperator + self.vertica_conn_id = vertica_conn_id + self.bulk_load = bulk_load + + def execute(self, context): + vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id) + mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) + + tmpfile = None + result = None + + selected_columns = [] + + count = 0 + with closing(vertica.get_conn()) as conn: + with closing(conn.cursor()) as cursor: + cursor.execute(self.sql) + selected_columns = [d.name for d in cursor.description] + + if self.bulk_load: + tmpfile = NamedTemporaryFile("w") + + self.log.info( + "Selecting rows from Vertica to local file %s...", tmpfile.name + ) + self.log.info(self.sql) + + csv_writer = csv.writer(tmpfile, delimiter="\t", encoding="utf-8") + for row in cursor.iterate(): + csv_writer.writerow(row) + count += 1 + + tmpfile.flush() + else: + self.log.info("Selecting rows from Vertica...") + self.log.info(self.sql) + + result = cursor.fetchall() + count = len(result) + + self.log.info("Selected rows from Vertica %s", count) + + if self.mysql_preoperator: + self.log.info("Running MySQL preoperator...") + mysql.run(self.mysql_preoperator) + + try: + if self.bulk_load: + self.log.info("Bulk inserting rows into MySQL...") + with closing(mysql.get_conn()) as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + "LOAD DATA LOCAL INFILE '%s' INTO " + "TABLE %s LINES TERMINATED BY '\r\n' (%s)" + % ( + tmpfile.name, + self.mysql_table, + ", ".join(selected_columns), + ) + ) + conn.commit() + tmpfile.close() + else: + self.log.info("Inserting rows into MySQL...") + mysql.insert_rows( + table=self.mysql_table, rows=result, target_fields=selected_columns + ) + self.log.info("Inserted rows into MySQL %s", count) + except (MySQLdb.Error, MySQLdb.Warning): # pylint: disable=no-member + self.log.info("Inserted rows into MySQL 0") + raise + + if self.mysql_postoperator: + self.log.info("Running MySQL postoperator...") + mysql.run(self.mysql_postoperator) + + self.log.info("Done") diff --git a/reference/providers/neo4j/CHANGELOG.rst b/reference/providers/neo4j/CHANGELOG.rst new file mode 100644 index 0000000..54c7707 --- /dev/null +++ b/reference/providers/neo4j/CHANGELOG.rst @@ -0,0 +1,35 @@ + + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
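For the bulk-load path above, a hedged example follows; the connection ids and table names are assumptions, and as the docstring notes the destination MySQL connection would need ``{'local_infile': true}`` in its extras.

from airflow import DAG
from airflow.providers.mysql.transfers.vertica_to_mysql import VerticaToMySqlOperator
from airflow.utils.dates import days_ago

dag = DAG(
    "example_vertica_to_mysql",
    start_date=days_ago(1),
    tags=["example"],
)

copy_metrics = VerticaToMySqlOperator(
    task_id="copy_metrics",
    vertica_conn_id="vertica_default",
    mysql_conn_id="mysql_default",
    sql="SELECT metric_name, metric_value FROM public.metrics",
    mysql_table="reporting.metrics",
    mysql_preoperator="TRUNCATE TABLE reporting.metrics",
    bulk_load=True,  # stream through a temp file and LOAD DATA LOCAL INFILE
    dag=dag,
)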
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Changelog +--------- + +1.0.1 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/neo4j/README.md b/reference/providers/neo4j/README.md new file mode 100644 index 0000000..ef14aff --- /dev/null +++ b/reference/providers/neo4j/README.md @@ -0,0 +1,18 @@ + diff --git a/reference/providers/neo4j/__init__.py b/reference/providers/neo4j/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/neo4j/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/neo4j/example_dags/__init__.py b/reference/providers/neo4j/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/neo4j/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/neo4j/example_dags/example_neo4j.py b/reference/providers/neo4j/example_dags/example_neo4j.py new file mode 100644 index 0000000..956fc7b --- /dev/null +++ b/reference/providers/neo4j/example_dags/example_neo4j.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example use of Neo4j related operators. +""" + +from airflow import DAG +from airflow.providers.neo4j.operators.neo4j import Neo4jOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", +} + +dag = DAG( + "example_neo4j", + default_args=default_args, + start_date=days_ago(2), + tags=["example"], +) + +# [START run_query_neo4j_operator] + +neo4j_task = Neo4jOperator( + task_id="run_neo4j_query", + neo4j_conn_id="neo4j_conn_id", + sql='MATCH (tom {name: "Tom Hanks"}) RETURN tom', + dag=dag, +) + +# [END run_query_neo4j_operator] + +neo4j_task diff --git a/reference/providers/neo4j/hooks/__init__.py b/reference/providers/neo4j/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/neo4j/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/neo4j/hooks/neo4j.py b/reference/providers/neo4j/hooks/neo4j.py new file mode 100644 index 0000000..110e8eb --- /dev/null +++ b/reference/providers/neo4j/hooks/neo4j.py @@ -0,0 +1,118 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module allows to connect to a Neo4j database.""" + +from airflow.hooks.base import BaseHook +from airflow.models import Connection +from neo4j import GraphDatabase, Neo4jDriver, Result + + +class Neo4jHook(BaseHook): + """ + Interact with Neo4j. + + Performs a connection to Neo4j and runs the query. 
+ """ + + conn_name_attr = "neo4j_conn_id" + default_conn_name = "neo4j_default" + conn_type = "neo4j" + hook_name = "Neo4j" + + def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.neo4j_conn_id = conn_id + self.connection = kwargs.pop("connection", None) + self.client = None + self.extras = None + self.uri = None + + def get_conn(self) -> Neo4jDriver: + """ + Function that initiates a new Neo4j connection + with username, password and database schema. + """ + self.connection = self.get_connection(self.neo4j_conn_id) + self.extras = self.connection.extra_dejson.copy() + + self.uri = self.get_uri(self.connection) + self.log.info("URI: %s", self.uri) + + if self.client is not None: + return self.client + + is_encrypted = self.connection.extra_dejson.get("encrypted", False) + + self.client = GraphDatabase.driver( + self.uri, + auth=(self.connection.login, self.connection.password), + encrypted=is_encrypted, + ) + + return self.client + + def get_uri(self, conn: Connection) -> str: + """ + Build the uri based on extras + - Default - uses bolt scheme(bolt://) + - neo4j_scheme - neo4j:// + - certs_self_signed - neo4j+ssc:// + - certs_trusted_ca - neo4j+s:// + :param conn: connection object. + :return: uri + """ + use_neo4j_scheme = conn.extra_dejson.get("neo4j_scheme", False) + scheme = "neo4j" if use_neo4j_scheme else "bolt" + + # Self signed certificates + ssc = conn.extra_dejson.get("certs_self_signed", False) + + # Only certificates signed by CA. + trusted_ca = conn.extra_dejson.get("certs_trusted_ca", False) + encryption_scheme = "" + + if ssc: + encryption_scheme = "+ssc" + elif trusted_ca: + encryption_scheme = "+s" + + return "{scheme}{encryption_scheme}://{host}:{port}".format( + scheme=scheme, + encryption_scheme=encryption_scheme, + host=conn.host, + port="7687" if conn.port is None else f"{conn.port}", + ) + + def run(self, query) -> Result: + """ + Function to create a neo4j session + and execute the query in the session. + + + :param query: Neo4j query + :return: Result + """ + driver = self.get_conn() + if not self.connection.schema: + with driver.session() as session: + result = session.run(query) + else: + with driver.session(database=self.connection.schema) as session: + result = session.run(query) + return result diff --git a/reference/providers/neo4j/operators/__init__.py b/reference/providers/neo4j/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/neo4j/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/reference/providers/neo4j/operators/neo4j.py b/reference/providers/neo4j/operators/neo4j.py new file mode 100644 index 0000000..4e96611 --- /dev/null +++ b/reference/providers/neo4j/operators/neo4j.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Dict, Iterable, Mapping, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.neo4j.hooks.neo4j import Neo4jHook +from airflow.utils.decorators import apply_defaults + + +class Neo4jOperator(BaseOperator): + """ + Executes sql code in a specific Neo4j database + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:Neo4jOperator` + + :param sql: the sql code to be executed. Can receive a str representing a + sql statement, a list of str (sql statements) + :type sql: str or list[str] + :param neo4j_conn_id: reference to a specific Neo4j database + :type neo4j_conn_id: str + """ + + @apply_defaults + def __init__( + self, + *, + sql: str, + neo4j_conn_id: str = "neo4j_default", + parameters: Optional[Union[Mapping, Iterable]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.neo4j_conn_id = neo4j_conn_id + self.sql = sql + self.parameters = parameters + self.hook = None + + def get_hook(self): + """Function to retrieve the Neo4j Hook.""" + return Neo4jHook(conn_id=self.neo4j_conn_id) + + def execute(self, context: Dict) -> None: + self.log.info("Executing: %s", self.sql) + self.hook = self.get_hook() + self.hook.run(self.sql) diff --git a/reference/providers/neo4j/provider.yaml b/reference/providers/neo4j/provider.yaml new file mode 100644 index 0000000..9d249d6 --- /dev/null +++ b/reference/providers/neo4j/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +package-name: apache-airflow-providers-neo4j +name: Neo4j +description: | + `Neo4j `__ + +versions: + - 1.0.1 + - 1.0.0 +integrations: + - integration-name: Neo4j + external-doc-url: https://neo4j.com/ + how-to-guide: + - /docs/apache-airflow-providers-neo4j/operators/neo4j.rst + tags: [software] + +operators: + - integration-name: Neo4j + python-modules: + - airflow.providers.neo4j.operators.neo4j + +hooks: + - integration-name: Neo4j + python-modules: + - airflow.providers.neo4j.hooks.neo4j + +hook-class-names: + - airflow.providers.neo4j.hooks.neo4j.Neo4jHook diff --git a/reference/providers/odbc/CHANGELOG.rst b/reference/providers/odbc/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/odbc/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/odbc/__init__.py b/reference/providers/odbc/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/odbc/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/odbc/hooks/__init__.py b/reference/providers/odbc/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/odbc/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/odbc/hooks/odbc.py b/reference/providers/odbc/hooks/odbc.py new file mode 100644 index 0000000..4243682 --- /dev/null +++ b/reference/providers/odbc/hooks/odbc.py @@ -0,0 +1,208 @@ +# pylint: disable=c-extension-no-member +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains ODBC hook.""" +from typing import Any, Optional +from urllib.parse import quote_plus + +import pyodbc +from airflow.hooks.dbapi import DbApiHook +from airflow.utils.helpers import merge_dicts + + +class OdbcHook(DbApiHook): + """ + Interact with odbc data sources using pyodbc. + + See :doc:`/connections/odbc` for full documentation. + """ + + DEFAULT_SQLALCHEMY_SCHEME = "mssql+pyodbc" + conn_name_attr = "odbc_conn_id" + default_conn_name = "odbc_default" + conn_type = "odbc" + hook_name = "ODBC" + supports_autocommit = True + + def __init__( + self, + *args, + database: Optional[str] = None, + driver: Optional[str] = None, + dsn: Optional[str] = None, + connect_kwargs: Optional[dict] = None, + sqlalchemy_scheme: Optional[str] = None, + **kwargs, + ) -> None: + """ + :param args: passed to DbApiHook + :param database: database to use -- overrides connection ``schema`` + :param driver: name of driver or path to driver. overrides driver supplied in connection ``extra`` + :param dsn: name of DSN to use. overrides DSN supplied in connection ``extra`` + :param connect_kwargs: keyword arguments passed to ``pyodbc.connect`` + :param sqlalchemy_scheme: Scheme sqlalchemy connection. Default is ``mssql+pyodbc`` Only used for + ``get_sqlalchemy_engine`` and ``get_sqlalchemy_connection`` methods. 
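As a hedged illustration of these constructor overrides and the connection-string assembly below, consider the following; the connection id, driver name and database are assumptions and must match what is actually installed and configured.

from airflow.providers.odbc.hooks.odbc import OdbcHook

# Assumes an "odbc_default" connection; the driver name must be one registered
# with the local ODBC driver manager.
hook = OdbcHook(
    odbc_conn_id="odbc_default",
    database="analytics",
    driver="ODBC Driver 17 for SQL Server",
    connect_kwargs={"autocommit": True},
)

conn = hook.get_conn()  # pyodbc.Connection built from odbc_connection_string
cursor = conn.cursor()
cursor.execute("SELECT 1")
print(cursor.fetchone())
conn.close()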
+ :param kwargs: passed to DbApiHook + """ + super().__init__(*args, **kwargs) + self._database = database + self._driver = driver + self._dsn = dsn + self._conn_str = None + self._sqlalchemy_scheme = sqlalchemy_scheme + self._connection = None + self._connect_kwargs = connect_kwargs + + @property + def connection(self): + """``airflow.Connection`` object with connection id ``odbc_conn_id``""" + if not self._connection: + self._connection = self.get_connection(getattr(self, self.conn_name_attr)) + return self._connection + + @property + def database(self) -> Optional[str]: + """Database provided in init if exists; otherwise, ``schema`` from ``Connection`` object.""" + return self._database or self.connection.schema + + @property + def sqlalchemy_scheme(self) -> Optional[str]: + """Sqlalchemy scheme from init param if given; otherwise from connection extra, falling back to ``DEFAULT_SQLALCHEMY_SCHEME``.""" + return ( + self._sqlalchemy_scheme + or self.connection_extra_lower.get("sqlalchemy_scheme") + or self.DEFAULT_SQLALCHEMY_SCHEME + ) + + @property + def connection_extra_lower(self) -> dict: + """ + ``connection.extra_dejson`` but where keys are converted to lower case. + + This is used internally for case-insensitive access of odbc params. + """ + return {k.lower(): v for k, v in self.connection.extra_dejson.items()} + + @property + def driver(self) -> Optional[str]: + """Driver from init param if given; else try to find one in connection extra.""" + if not self._driver: + driver = self.connection_extra_lower.get("driver") + if driver: + self._driver = driver + return self._driver and self._driver.strip().lstrip("{").rstrip("}").strip() + + @property + def dsn(self) -> Optional[str]: + """DSN from init param if given; else try to find one in connection extra.""" + if not self._dsn: + dsn = self.connection_extra_lower.get("dsn") + if dsn: + self._dsn = dsn.strip() + return self._dsn + + @property + def odbc_connection_string(self): + """ + ODBC connection string. + We build the connection string instead of using ``pyodbc.connect`` params because, for example, there is + no param representing ``ApplicationIntent=ReadOnly``. Any key-value pairs provided in + ``Connection.extra`` will be added to the connection string. + """ + if not self._conn_str: + conn_str = "" + if self.driver: + conn_str += f"DRIVER={{{self.driver}}};" + if self.dsn: + conn_str += f"DSN={self.dsn};" + if self.connection.host: + conn_str += f"SERVER={self.connection.host};" + database = self.database or self.connection.schema + if database: + conn_str += f"DATABASE={database};" + if self.connection.login: + conn_str += f"UID={self.connection.login};" + if self.connection.password: + conn_str += f"PWD={self.connection.password};" + if self.connection.port: + conn_str += f"PORT={self.connection.port};" + + extra_exclude = {"driver", "dsn", "connect_kwargs", "sqlalchemy_scheme"} + extra_params = { + k: v + for k, v in self.connection.extra_dejson.items() + if not k.lower() in extra_exclude + } + for k, v in extra_params.items(): + conn_str += f"{k}={v};" + + self._conn_str = conn_str + return self._conn_str + + @property + def connect_kwargs(self) -> dict: + """ + Returns effective kwargs to be passed to ``pyodbc.connect`` after merging between conn extra, + ``connect_kwargs`` and hook init. + + Hook ``connect_kwargs`` takes precedence over ``connect_kwargs`` from conn extra. + + String values for 'true' and 'false' are converted to bool type. + + If ``attrs_before`` provided, keys and values are converted to int, as required by pyodbc. 
+ """ + + def clean_bool(val): # pylint: disable=inconsistent-return-statements + if hasattr(val, "lower"): + if val.lower() == "true": + return True + elif val.lower() == "false": + return False + else: + return val + + conn_connect_kwargs = self.connection_extra_lower.get("connect_kwargs", {}) + hook_connect_kwargs = self._connect_kwargs or {} + merged_connect_kwargs = merge_dicts(conn_connect_kwargs, hook_connect_kwargs) + + if "attrs_before" in merged_connect_kwargs: + merged_connect_kwargs["attrs_before"] = { + int(k): int(v) for k, v in merged_connect_kwargs["attrs_before"].items() + } + + return {k: clean_bool(v) for k, v in merged_connect_kwargs.items()} + + def get_conn(self) -> pyodbc.Connection: + """Returns a pyodbc connection object.""" + conn = pyodbc.connect(self.odbc_connection_string, **self.connect_kwargs) + return conn + + def get_uri(self) -> str: + """URI invoked in :py:meth:`~airflow.hooks.dbapi.DbApiHook.get_sqlalchemy_engine` method""" + quoted_conn_str = quote_plus(self.odbc_connection_string) + uri = f"{self.sqlalchemy_scheme}:///?odbc_connect={quoted_conn_str}" + return uri + + def get_sqlalchemy_connection( + self, + connect_kwargs: Optional[dict] = None, + engine_kwargs: Optional[dict] = None, + ) -> Any: + """Sqlalchemy connection object""" + engine = self.get_sqlalchemy_engine(engine_kwargs=engine_kwargs) + cnx = engine.connect(**(connect_kwargs or {})) + return cnx diff --git a/reference/providers/odbc/provider.yaml b/reference/providers/odbc/provider.yaml new file mode 100644 index 0000000..a84e554 --- /dev/null +++ b/reference/providers/odbc/provider.yaml @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-odbc +name: ODBC +description: | + `ODBC `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: ODBC + external-doc-url: https://github.com/mkleehammer/pyodbc/wiki + logo: /integration-logos/odbc/ODBC.png + tags: [protocol] + +hooks: + - integration-name: ODBC + python-modules: + - airflow.providers.odbc.hooks.odbc + +hook-class-names: + - airflow.providers.odbc.hooks.odbc.OdbcHook diff --git a/reference/providers/openfaas/CHANGELOG.rst b/reference/providers/openfaas/CHANGELOG.rst new file mode 100644 index 0000000..dcc9add --- /dev/null +++ b/reference/providers/openfaas/CHANGELOG.rst @@ -0,0 +1,41 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. 
http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.1.1 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + + +1.1.0 +..... + +Updated documentation and readme files. + +* ``Add openfaas sync call (#13356)`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/openfaas/__init__.py b/reference/providers/openfaas/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/openfaas/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/openfaas/hooks/__init__.py b/reference/providers/openfaas/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/openfaas/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/openfaas/hooks/openfaas.py b/reference/providers/openfaas/hooks/openfaas.py new file mode 100644 index 0000000..26e455f --- /dev/null +++ b/reference/providers/openfaas/hooks/openfaas.py @@ -0,0 +1,123 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict + +import requests +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + +OK_STATUS_CODE = 202 + + +class OpenFaasHook(BaseHook): + """ + Interact with OpenFaaS to query, deploy, invoke and update function + + :param function_name: Name of the function, Defaults to None + :type function_name: str + :param conn_id: openfaas connection to use, Defaults to open_faas_default + for example host : http://openfaas.faas.com, Conn Type : Http + :type conn_id: str + """ + + GET_FUNCTION = "/system/function/" + INVOKE_ASYNC_FUNCTION = "/async-function/" + INVOKE_FUNCTION = "/function/" + DEPLOY_FUNCTION = "/system/functions" + UPDATE_FUNCTION = "/system/functions" + + def __init__( + self, function_name=None, conn_id: str = "open_faas_default", *args, **kwargs + ) -> None: + super().__init__(*args, **kwargs) + self.function_name = function_name + self.conn_id = conn_id + + def get_conn(self): + conn = self.get_connection(self.conn_id) + return conn + + def deploy_function( + self, overwrite_function_if_exist: bool, body: Dict[str, Any] + ) -> None: + """Deploy OpenFaaS function""" + if overwrite_function_if_exist: + self.log.info( + "Function already exist %s going to update", self.function_name + ) + self.update_function(body) + else: + url = self.get_conn().host + self.DEPLOY_FUNCTION + self.log.info("Deploying function %s", url) + response = requests.post(url, body) + if response.status_code != OK_STATUS_CODE: + self.log.error("Response status %d", response.status_code) + self.log.error("Failed to deploy") + raise AirflowException("failed to deploy") + else: + self.log.info("Function deployed %s", self.function_name) + + def invoke_async_function(self, body: Dict[str, Any]) -> None: + """Invoking function asynchronously""" + url = self.get_conn().host + self.INVOKE_ASYNC_FUNCTION + self.function_name + self.log.info("Invoking function asynchronously %s", url) + response = requests.post(url, body) + if response.ok: + self.log.info("Invoked %s", self.function_name) + else: + self.log.error("Response status %d", response.status_code) + raise AirflowException("failed to invoke function") + + def invoke_function(self, body: Dict[str, Any]) -> None: + """Invoking function synchronously, will block until function completes and returns""" + url = self.get_conn().host + self.INVOKE_FUNCTION + self.function_name + self.log.info("Invoking function synchronously %s", url) + response = requests.post(url, body) + if response.ok: + self.log.info("Invoked %s", self.function_name) + self.log.info("Response code %s", response.status_code) + self.log.info("Response %s", response.text) + else: + self.log.error("Response status %d", response.status_code) + raise AirflowException("failed to invoke function") + + def update_function(self, body: Dict[str, Any]) -> None: + """Update OpenFaaS function""" + url = self.get_conn().host + self.UPDATE_FUNCTION + self.log.info("Updating function %s", url) + response = requests.put(url, body) + if response.status_code != OK_STATUS_CODE: + self.log.error("Response status %d", response.status_code) + 
self.log.error( + "Failed to update response %s", response.content.decode("utf-8") + ) + raise AirflowException("failed to update " + self.function_name) + else: + self.log.info("Function was updated") + + def does_function_exist(self) -> bool: + """Whether OpenFaaS function exists or not""" + url = self.get_conn().host + self.GET_FUNCTION + self.function_name + + response = requests.get(url) + if response.ok: + return True + else: + self.log.error("Failed to find function %s", self.function_name) + return False diff --git a/reference/providers/openfaas/provider.yaml b/reference/providers/openfaas/provider.yaml new file mode 100644 index 0000000..8036fc1 --- /dev/null +++ b/reference/providers/openfaas/provider.yaml @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-openfaas +name: OpenFaaS +description: | + `OpenFaaS `__ + +versions: + - 1.1.1 + - 1.1.0 + - 1.0.0 + +integrations: + - integration-name: OpenFaaS + external-doc-url: https://www.openfaas.com/ + logo: /integration-logos/openfaas/OpenFaaS.png + tags: [software] + +hooks: + - integration-name: OpenFaaS + python-modules: + - airflow.providers.openfaas.hooks.openfaas diff --git a/reference/providers/opsgenie/CHANGELOG.rst b/reference/providers/opsgenie/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/opsgenie/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/opsgenie/__init__.py b/reference/providers/opsgenie/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/opsgenie/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/opsgenie/hooks/__init__.py b/reference/providers/opsgenie/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/opsgenie/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/opsgenie/hooks/opsgenie_alert.py b/reference/providers/opsgenie/hooks/opsgenie_alert.py new file mode 100644 index 0000000..9e24f69 --- /dev/null +++ b/reference/providers/opsgenie/hooks/opsgenie_alert.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import json +from typing import Any, Optional + +import requests +from airflow.exceptions import AirflowException +from airflow.providers.http.hooks.http import HttpHook + + +class OpsgenieAlertHook(HttpHook): + """ + This hook allows you to post alerts to Opsgenie. + Accepts a connection that has an Opsgenie API key as the connection's password. + This hook sets the domain to conn_id.host, and if not set will default + to ``https://api.opsgenie.com``. + + Each Opsgenie API key can be pre-configured to a team integration. + You can override these defaults in this hook. 
+ + :param opsgenie_conn_id: The name of the Opsgenie connection to use + :type opsgenie_conn_id: str + + """ + + def __init__( + self, opsgenie_conn_id: str = "opsgenie_default", *args, **kwargs + ) -> None: + super().__init__(http_conn_id=opsgenie_conn_id, *args, **kwargs) # type: ignore[misc] + + def _get_api_key(self) -> str: + """Get Opsgenie api_key for creating alert""" + conn = self.get_connection(self.http_conn_id) + api_key = conn.password + if not api_key: + raise AirflowException( + "Opsgenie API Key is required for this hook, please check your conn_id configuration." + ) + return api_key + + def get_conn(self, headers: Optional[dict] = None) -> requests.Session: + """ + Overwrite HttpHook get_conn because this hook just needs base_url + and headers, and does not need generic params + + :param headers: additional headers to be passed through as a dictionary + :type headers: dict + """ + conn = self.get_connection(self.http_conn_id) + self.base_url = conn.host if conn.host else "https://api.opsgenie.com" + session = requests.Session() + if headers: + session.headers.update(headers) + return session + + def execute(self, payload: Optional[dict] = None) -> Any: + """ + Execute the Opsgenie Alert call + + :param payload: Opsgenie API Create Alert payload values + See https://docs.opsgenie.com/docs/alert-api#section-create-alert + :type payload: dict + """ + payload = payload or {} + api_key = self._get_api_key() + return self.run( + endpoint="v2/alerts", + data=json.dumps(payload), + headers={ + "Content-Type": "application/json", + "Authorization": f"GenieKey {api_key}", + }, + ) diff --git a/reference/providers/opsgenie/operators/__init__.py b/reference/providers/opsgenie/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/opsgenie/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/opsgenie/operators/opsgenie_alert.py b/reference/providers/opsgenie/operators/opsgenie_alert.py new file mode 100644 index 0000000..4840b92 --- /dev/null +++ b/reference/providers/opsgenie/operators/opsgenie_alert.py @@ -0,0 +1,143 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Any, Dict, List, Optional + +from airflow.models import BaseOperator +from airflow.providers.opsgenie.hooks.opsgenie_alert import OpsgenieAlertHook +from airflow.utils.decorators import apply_defaults + + +class OpsgenieAlertOperator(BaseOperator): + """ + This operator allows you to post alerts to Opsgenie. + Accepts a connection that has an Opsgenie API key as the connection's password. + This operator sets the domain to conn_id.host, and if not set will default + to ``https://api.opsgenie.com``. + + Each Opsgenie API key can be pre-configured to a team integration. + You can override these defaults in this operator. + + :param opsgenie_conn_id: The name of the Opsgenie connection to use + :type opsgenie_conn_id: str + :param message: The Message of the Opsgenie alert (templated) + :type message: str + :param alias: Client-defined identifier of the alert (templated) + :type alias: str + :param description: Description field of the alert (templated) + :type description: str + :param responders: Teams, users, escalations and schedules that + the alert will be routed to send notifications. + :type responders: list[dict] + :param visible_to: Teams and users that the alert will become visible + to without sending any notification. + :type visible_to: list[dict] + :param actions: Custom actions that will be available for the alert. + :type actions: list[str] + :param tags: Tags of the alert. + :type tags: list[str] + :param details: Map of key-value pairs to use as custom properties of the alert. + :type details: dict + :param entity: Entity field of the alert that is + generally used to specify which domain alert is related to. (templated) + :type entity: str + :param source: Source field of the alert. Default value is + IP address of the incoming request. + :type source: str + :param priority: Priority level of the alert. Default value is P3. (templated) + :type priority: str + :param user: Display name of the request owner. + :type user: str + :param note: Additional note that will be added while creating the alert.
(templated) + :type note: str + """ + + template_fields = ("message", "alias", "description", "entity", "priority", "note") + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + message: str, + opsgenie_conn_id: str = "opsgenie_default", + alias: Optional[str] = None, + description: Optional[str] = None, + responders: Optional[List[dict]] = None, + visible_to: Optional[List[dict]] = None, + actions: Optional[List[str]] = None, + tags: Optional[List[str]] = None, + details: Optional[dict] = None, + entity: Optional[str] = None, + source: Optional[str] = None, + priority: Optional[str] = None, + user: Optional[str] = None, + note: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.message = message + self.opsgenie_conn_id = opsgenie_conn_id + self.alias = alias + self.description = description + self.responders = responders + self.visible_to = visible_to + self.actions = actions + self.tags = tags + self.details = details + self.entity = entity + self.source = source + self.priority = priority + self.user = user + self.note = note + self.hook: Optional[OpsgenieAlertHook] = None + + def _build_opsgenie_payload(self) -> Dict[str, Any]: + """ + Construct the Opsgenie JSON payload. All relevant parameters are combined here + to a valid Opsgenie JSON payload. + + :return: Opsgenie payload (dict) to send + """ + payload = {} + + for key in [ + "message", + "alias", + "description", + "responders", + "visible_to", + "actions", + "tags", + "details", + "entity", + "source", + "priority", + "user", + "note", + ]: + val = getattr(self, key) + if val: + payload[key] = val + return payload + + def execute(self, context) -> None: + """Call the OpsgenieAlertHook to post message""" + self.hook = OpsgenieAlertHook(self.opsgenie_conn_id) + self.hook.execute(self._build_opsgenie_payload()) diff --git a/reference/providers/opsgenie/provider.yaml b/reference/providers/opsgenie/provider.yaml new file mode 100644 index 0000000..bd86dd8 --- /dev/null +++ b/reference/providers/opsgenie/provider.yaml @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
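The operator above simply collects its non-empty fields into a JSON payload and posts it through ``OpsgenieAlertHook``. A minimal usage sketch, assuming the default ``opsgenie_default`` connection (Opsgenie API key stored in the connection's password); the DAG id and alert fields are illustrative only::

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.opsgenie.operators.opsgenie_alert import OpsgenieAlertOperator

    with DAG(
        dag_id="example_opsgenie_alert",  # hypothetical DAG id
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        notify = OpsgenieAlertOperator(
            task_id="notify_opsgenie",
            message="Nightly load failed",  # required field, templated
            priority="P3",
            tags=["airflow", "etl"],
        )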
+ +--- +package-name: apache-airflow-providers-opsgenie +name: Opsgenie +description: | + `Opsgenie `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Opsgenie + external-doc-url: https://www.opsgenie.com/ + logo: /integration-logos/opsgenie/Opsgenie.png + tags: [service] + +operators: + - integration-name: Opsgenie + python-modules: + - airflow.providers.opsgenie.operators.opsgenie_alert + +hooks: + - integration-name: Opsgenie + python-modules: + - airflow.providers.opsgenie.hooks.opsgenie_alert diff --git a/reference/providers/oracle/CHANGELOG.rst b/reference/providers/oracle/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/oracle/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/oracle/__init__.py b/reference/providers/oracle/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/oracle/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/oracle/hooks/__init__.py b/reference/providers/oracle/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/oracle/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/oracle/hooks/oracle.py b/reference/providers/oracle/hooks/oracle.py new file mode 100644 index 0000000..dadfb22 --- /dev/null +++ b/reference/providers/oracle/hooks/oracle.py @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from typing import List, Optional + +import cx_Oracle +import numpy +from airflow.hooks.dbapi import DbApiHook + + +class OracleHook(DbApiHook): + """Interact with Oracle SQL.""" + + conn_name_attr = "oracle_conn_id" + default_conn_name = "oracle_default" + conn_type = "oracle" + hook_name = "Oracle" + + supports_autocommit = False + + # pylint: disable=c-extension-no-member + def get_conn(self) -> "OracleHook": + """ + Returns a oracle connection object + Optional parameters for using a custom DSN connection + (instead of using a server alias from tnsnames.ora) + The dsn (data source name) is the TNS entry + (from the Oracle names server or tnsnames.ora file) + or is a string like the one returned from makedsn(). 
+ + :param dsn: the host address for the Oracle server + :param service_name: the db_unique_name of the database + that you are connecting to (CONNECT_DATA part of TNS) + + You can set these parameters in the extra fields of your connection + as in ``{ "dsn":"some.host.address" , "service_name":"some.service.name" }`` + see more param detail in + `cx_Oracle.connect `_ + """ + conn = self.get_connection( + self.oracle_conn_id # type: ignore[attr-defined] # pylint: disable=no-member + ) + conn_config = {"user": conn.login, "password": conn.password} + dsn = conn.extra_dejson.get("dsn") + sid = conn.extra_dejson.get("sid") + mod = conn.extra_dejson.get("module") + + service_name = conn.extra_dejson.get("service_name") + port = conn.port if conn.port else 1521 + if dsn and sid and not service_name: + conn_config["dsn"] = cx_Oracle.makedsn(dsn, port, sid) + elif dsn and service_name and not sid: + conn_config["dsn"] = cx_Oracle.makedsn(dsn, port, service_name=service_name) + else: + conn_config["dsn"] = conn.host + + if "encoding" in conn.extra_dejson: + conn_config["encoding"] = conn.extra_dejson.get("encoding") + # if `encoding` is specific but `nencoding` is not + # `nencoding` should use same values as `encoding` to set encoding, inspired by + # https://github.com/oracle/python-cx_Oracle/issues/157#issuecomment-371877993 + if "nencoding" not in conn.extra_dejson: + conn_config["nencoding"] = conn.extra_dejson.get("encoding") + if "nencoding" in conn.extra_dejson: + conn_config["nencoding"] = conn.extra_dejson.get("nencoding") + if "threaded" in conn.extra_dejson: + conn_config["threaded"] = conn.extra_dejson.get("threaded") + if "events" in conn.extra_dejson: + conn_config["events"] = conn.extra_dejson.get("events") + + mode = conn.extra_dejson.get("mode", "").lower() + if mode == "sysdba": + conn_config["mode"] = cx_Oracle.SYSDBA + elif mode == "sysasm": + conn_config["mode"] = cx_Oracle.SYSASM + elif mode == "sysoper": + conn_config["mode"] = cx_Oracle.SYSOPER + elif mode == "sysbkp": + conn_config["mode"] = cx_Oracle.SYSBKP + elif mode == "sysdgd": + conn_config["mode"] = cx_Oracle.SYSDGD + elif mode == "syskmt": + conn_config["mode"] = cx_Oracle.SYSKMT + elif mode == "sysrac": + conn_config["mode"] = cx_Oracle.SYSRAC + + purity = conn.extra_dejson.get("purity", "").lower() + if purity == "new": + conn_config["purity"] = cx_Oracle.ATTR_PURITY_NEW + elif purity == "self": + conn_config["purity"] = cx_Oracle.ATTR_PURITY_SELF + elif purity == "default": + conn_config["purity"] = cx_Oracle.ATTR_PURITY_DEFAULT + + conn = cx_Oracle.connect(**conn_config) + if mod is not None: + conn.module = mod + + return conn + + def insert_rows( + self, + table: str, + rows: List[tuple], + target_fields=None, + commit_every: int = 1000, + replace: Optional[bool] = False, + **kwargs, + ) -> None: + """ + A generic way to insert a set of tuples into a table, + the whole set of inserts is treated as one transaction + Changes from standard DbApiHook implementation: + + - Oracle SQL queries in cx_Oracle can not be terminated with a semicolon (`;`) + - Replace NaN values with NULL using `numpy.nan_to_num` (not using + `is_nan()` because of input types error for strings) + - Coerce datetime cells to Oracle DATETIME format during insert + + :param table: target Oracle table, use dot notation to target a + specific database + :type table: str + :param rows: the rows to insert into the table + :type rows: iterable of tuples + :param target_fields: the names of the columns to fill in the table + :type target_fields: 
iterable of str + :param commit_every: the maximum number of rows to insert in one transaction + Default 1000, Set greater than 0. + Set 1 to insert each row in each single transaction + :type commit_every: int + :param replace: Whether to replace instead of insert + :type replace: bool + """ + if target_fields: + target_fields = ", ".join(target_fields) + target_fields = f"({target_fields})" + else: + target_fields = "" + conn = self.get_conn() + cur = conn.cursor() # type: ignore[attr-defined] + if self.supports_autocommit: + cur.execute("SET autocommit = 0") + conn.commit() # type: ignore[attr-defined] + i = 0 + for row in rows: + i += 1 + lst = [] + for cell in row: + if isinstance(cell, str): + lst.append("'" + str(cell).replace("'", "''") + "'") + elif cell is None: + lst.append("NULL") + elif isinstance(cell, float) and numpy.isnan( + cell + ): # coerce numpy NaN to NULL + lst.append("NULL") + elif isinstance(cell, numpy.datetime64): + lst.append("'" + str(cell) + "'") + elif isinstance(cell, datetime): + lst.append( + "to_date('" + + cell.strftime("%Y-%m-%d %H:%M:%S") + + "','YYYY-MM-DD HH24:MI:SS')" + ) + else: + lst.append(str(cell)) + values = tuple(lst) + sql = f"INSERT /*+ APPEND */ INTO {table} {target_fields} VALUES ({','.join(values)})" + cur.execute(sql) + if i % commit_every == 0: + conn.commit() # type: ignore[attr-defined] + self.log.info("Loaded %s into %s rows so far", i, table) + conn.commit() # type: ignore[attr-defined] + cur.close() + conn.close() # type: ignore[attr-defined] + self.log.info("Done loading. Loaded a total of %s rows", i) + + def bulk_insert_rows( + self, + table: str, + rows: List[tuple], + target_fields: Optional[List[str]] = None, + commit_every: int = 5000, + ): + """ + A performant bulk insert for cx_Oracle + that uses prepared statements via `executemany()`. + For best performance, pass in `rows` as an iterator. + + :param table: target Oracle table, use dot notation to target a + specific database + :type table: str + :param rows: the rows to insert into the table + :type rows: iterable of tuples + :param target_fields: the names of the columns to fill in the table, default None. + If None, each rows should have some order as table columns name + :type target_fields: iterable of str Or None + :param commit_every: the maximum number of rows to insert in one transaction + Default 5000. Set greater than 0. 
Set 1 to insert each row in each transaction + :type commit_every: int + """ + if not rows: + raise ValueError("parameter rows could not be None or empty iterable") + conn = self.get_conn() + cursor = conn.cursor() # type: ignore[attr-defined] + values_base = target_fields if target_fields else rows[0] + prepared_stm = "insert into {tablename} {columns} values ({values})".format( + tablename=table, + columns="({})".format(", ".join(target_fields)) if target_fields else "", + values=", ".join(":%s" % i for i in range(1, len(values_base) + 1)), + ) + row_count = 0 + # Chunk the rows + row_chunk = [] + for row in rows: + row_chunk.append(row) + row_count += 1 + if row_count % commit_every == 0: + cursor.prepare(prepared_stm) + cursor.executemany(None, row_chunk) + conn.commit() # type: ignore[attr-defined] + self.log.info("[%s] inserted %s rows", table, row_count) + # Empty chunk + row_chunk = [] + # Commit the leftover chunk + cursor.prepare(prepared_stm) + cursor.executemany(None, row_chunk) + conn.commit() # type: ignore[attr-defined] + self.log.info("[%s] inserted %s rows", table, row_count) + cursor.close() + conn.close() # type: ignore[attr-defined] diff --git a/reference/providers/oracle/operators/__init__.py b/reference/providers/oracle/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/oracle/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/oracle/operators/oracle.py b/reference/providers/oracle/operators/oracle.py new file mode 100644 index 0000000..ac7b876 --- /dev/null +++ b/reference/providers/oracle/operators/oracle.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
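``bulk_insert_rows`` above prepares one positional-bind ``INSERT`` statement and feeds it to ``cursor.executemany()`` chunk by chunk, committing every ``commit_every`` rows. A short sketch of calling it from inside a task callable, assuming an ``oracle_default`` connection whose extras carry ``dsn``/``service_name``; the table and column names are hypothetical::

    from airflow.providers.oracle.hooks.oracle import OracleHook

    # Connection "oracle_default" is assumed to define extras such as
    # {"dsn": "db.example.com", "service_name": "ORCLPDB1"}, which
    # get_conn() turns into a DSN via cx_Oracle.makedsn().
    hook = OracleHook(oracle_conn_id="oracle_default")

    rows = [
        (1, "alice"),
        (2, "bob"),
    ]
    # One prepared statement, executed in chunks of commit_every rows.
    hook.bulk_insert_rows(
        table="app_schema.users",  # hypothetical target table
        rows=rows,
        target_fields=["id", "username"],
        commit_every=1000,
    )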
+from typing import Iterable, Mapping, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.oracle.hooks.oracle import OracleHook +from airflow.utils.decorators import apply_defaults + + +class OracleOperator(BaseOperator): + """ + Executes sql code in a specific Oracle database + + :param sql: the sql code to be executed. Can receive a str representing a sql statement, + a list of str (sql statements), or reference to a template file. + Template reference are recognized by str ending in '.sql' + (templated) + :type sql: str or list[str] + :param oracle_conn_id: reference to a specific Oracle database + :type oracle_conn_id: str + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: dict or iterable + :param autocommit: if True, each command is automatically committed. + (default value: False) + :type autocommit: bool + """ + + template_fields = ("sql",) + template_ext = (".sql",) + ui_color = "#ededed" + + @apply_defaults + def __init__( + self, + *, + sql: str, + oracle_conn_id: str = "oracle_default", + parameters: Optional[Union[Mapping, Iterable]] = None, + autocommit: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.oracle_conn_id = oracle_conn_id + self.sql = sql + self.autocommit = autocommit + self.parameters = parameters + + def execute(self, context) -> None: + self.log.info("Executing: %s", self.sql) + hook = OracleHook(oracle_conn_id=self.oracle_conn_id) + hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) diff --git a/reference/providers/oracle/provider.yaml b/reference/providers/oracle/provider.yaml new file mode 100644 index 0000000..2b41715 --- /dev/null +++ b/reference/providers/oracle/provider.yaml @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
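A minimal DAG sketch for the operator above; the DAG id, connection id and SQL statement are illustrative assumptions. ``sql`` may also be a list of statements or a path to a ``.sql`` template::

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.oracle.operators.oracle import OracleOperator

    with DAG(
        dag_id="example_oracle_cleanup",  # hypothetical DAG id
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        # Delegates to OracleHook.run(); autocommit is passed straight through.
        purge = OracleOperator(
            task_id="purge_staging",
            oracle_conn_id="oracle_default",
            sql="DELETE FROM staging_events WHERE load_date < SYSDATE - 7",
            autocommit=True,
        )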
+ +--- +package-name: apache-airflow-providers-oracle +name: Oracle +description: | + `Oracle `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Oracle + external-doc-url: https://www.oracle.com/en/database/ + logo: /integration-logos/oracle/Oracle.png + tags: [software] + +operators: + - integration-name: Oracle + python-modules: + - airflow.providers.oracle.operators.oracle + +hooks: + - integration-name: Oracle + python-modules: + - airflow.providers.oracle.hooks.oracle + +transfers: + - source-integration-name: Oracle + target-integration-name: Oracle + python-module: airflow.providers.oracle.transfers.oracle_to_oracle + +hook-class-names: + - airflow.providers.oracle.hooks.oracle.OracleHook diff --git a/reference/providers/oracle/transfers/__init__.py b/reference/providers/oracle/transfers/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/oracle/transfers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/oracle/transfers/oracle_to_oracle.py b/reference/providers/oracle/transfers/oracle_to_oracle.py new file mode 100644 index 0000000..5e047e8 --- /dev/null +++ b/reference/providers/oracle/transfers/oracle_to_oracle.py @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.models import BaseOperator +from airflow.providers.oracle.hooks.oracle import OracleHook +from airflow.utils.decorators import apply_defaults + + +class OracleToOracleOperator(BaseOperator): + """ + Moves data from Oracle to Oracle. + + + :param oracle_destination_conn_id: destination Oracle connection. + :type oracle_destination_conn_id: str + :param destination_table: destination table to insert rows. + :type destination_table: str + :param oracle_source_conn_id: source Oracle connection. + :type oracle_source_conn_id: str + :param source_sql: SQL query to execute against the source Oracle + database. 
(templated) + :type source_sql: str + :param source_sql_params: Parameters to use in sql query. (templated) + :type source_sql_params: dict + :param rows_chunk: number of rows per chunk to commit. + :type rows_chunk: int + """ + + template_fields = ("source_sql", "source_sql_params") + ui_color = "#e08c8c" + + @apply_defaults + def __init__( + self, + *, + oracle_destination_conn_id: str, + destination_table: str, + oracle_source_conn_id: str, + source_sql: str, + source_sql_params: Optional[dict] = None, + rows_chunk: int = 5000, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if source_sql_params is None: + source_sql_params = {} + self.oracle_destination_conn_id = oracle_destination_conn_id + self.destination_table = destination_table + self.oracle_source_conn_id = oracle_source_conn_id + self.source_sql = source_sql + self.source_sql_params = source_sql_params + self.rows_chunk = rows_chunk + + # pylint: disable=unused-argument + def _execute(self, src_hook, dest_hook, context) -> None: + with src_hook.get_conn() as src_conn: + cursor = src_conn.cursor() + self.log.info("Querying data from source: %s", self.oracle_source_conn_id) + cursor.execute(self.source_sql, self.source_sql_params) + target_fields = list(map(lambda field: field[0], cursor.description)) + + rows_total = 0 + rows = cursor.fetchmany(self.rows_chunk) + while len(rows) > 0: + rows_total += len(rows) + dest_hook.bulk_insert_rows( + self.destination_table, + rows, + target_fields=target_fields, + commit_every=self.rows_chunk, + ) + rows = cursor.fetchmany(self.rows_chunk) + self.log.info("Total inserted: %s rows", rows_total) + + self.log.info("Finished data transfer.") + cursor.close() + + def execute(self, context) -> None: + src_hook = OracleHook(oracle_conn_id=self.oracle_source_conn_id) + dest_hook = OracleHook(oracle_conn_id=self.oracle_destination_conn_id) + self._execute(src_hook, dest_hook, context) diff --git a/reference/providers/pagerduty/CHANGELOG.rst b/reference/providers/pagerduty/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/pagerduty/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/pagerduty/__init__.py b/reference/providers/pagerduty/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/pagerduty/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/pagerduty/hooks/__init__.py b/reference/providers/pagerduty/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/pagerduty/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/pagerduty/hooks/pagerduty.py b/reference/providers/pagerduty/hooks/pagerduty.py new file mode 100644 index 0000000..c6f32d4 --- /dev/null +++ b/reference/providers/pagerduty/hooks/pagerduty.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for sending or receiving data from PagerDuty as well as creating PagerDuty incidents.""" +from typing import Any, Dict, List, Optional + +import pdpyras +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class PagerdutyHook(BaseHook): + """ + Takes both PagerDuty API token directly and connection that has PagerDuty API token. + + If both supplied, PagerDuty API token will be used. 
+ + :param token: PagerDuty API token + :param pagerduty_conn_id: connection that has PagerDuty API token in the password field + """ + + def __init__( + self, token: Optional[str] = None, pagerduty_conn_id: Optional[str] = None + ) -> None: + super().__init__() + self.routing_key = None + self._session = None + + if pagerduty_conn_id is not None: + conn = self.get_connection(pagerduty_conn_id) + self.token = conn.get_password() + + routing_key = conn.extra_dejson.get("routing_key") + if routing_key: + self.routing_key = routing_key + + if token is not None: # token takes higher priority + self.token = token + + if self.token is None: + raise AirflowException( + "Cannot get token: No valid api token nor pagerduty_conn_id supplied." + ) + + def get_session(self) -> pdpyras.APISession: + """ + Returns `pdpyras.APISession` for use with sending or receiving data through the PagerDuty REST API. + + The `pdpyras` library supplies a class `pdpyras.APISession` extending `requests.Session` from the + Requests HTTP library. + + Documentation on how to use the `APISession` class can be found at: + https://pagerduty.github.io/pdpyras/#data-access-abstraction + """ + self._session = pdpyras.APISession(self.token) + return self._session + + # pylint: disable=too-many-arguments + def create_event( + self, + summary: str, + severity: str, + source: str = "airflow", + action: str = "trigger", + routing_key: Optional[str] = None, + dedup_key: Optional[str] = None, + custom_details: Optional[Any] = None, + group: Optional[str] = None, + component: Optional[str] = None, + class_type: Optional[str] = None, + images: Optional[List[Any]] = None, + links: Optional[List[Any]] = None, + ) -> Dict: + """ + Create event for service integration. + + :param summary: Summary for the event + :type summary: str + :param severity: Severity for the event, needs to be one of: info, warning, error, critical + :type severity: str + :param source: Specific human-readable unique identifier, such as a + hostname, for the system having the problem. + :type source: str + :param action: Event action, needs to be one of: trigger, acknowledge, + resolve. Default to trigger if not specified. + :type action: str + :param routing_key: Integration key. If not specified, will try to read + from connection's extra json blob. + :type routing_key: str + :param dedup_key: A string which identifies the alert triggered for the given event. + Required for the actions acknowledge and resolve. + :type dedup_key: str + :param custom_details: Free-form details from the event. Can be a dictionary or a string. + If a dictionary is passed it will show up in PagerDuty as a table. + :type custom_details: dict or str + :param group: A cluster or grouping of sources. For example, sources + “prod-datapipe-02” and “prod-datapipe-03” might both be part of “prod-datapipe” + :type group: str + :param component: The part or component of the affected system that is broken. + :type component: str + :param class_type: The class/type of the event. + :type class_type: str + :param images: List of images to include. Each dictionary in the list accepts the following keys: + `src`: The source (URL) of the image being attached to the incident. This image must be served via + HTTPS. + `href`: [Optional] URL to make the image a clickable link. + `alt`: [Optional] Alternative text for the image. + :type images: list[dict] + :param links: List of links to include. Each dictionary in the list accepts the following keys: + `href`: URL of the link to be attached. 
+ `text`: [Optional] Plain text that describes the purpose of the link, and can be used as the + link's text. + :type links: list[dict] + :return: PagerDuty Events API v2 response. + :rtype: dict + """ + if routing_key is None: + routing_key = self.routing_key + if routing_key is None: + raise AirflowException("No routing/integration key specified.") + payload = { + "summary": summary, + "severity": severity, + "source": source, + } + if custom_details is not None: + payload["custom_details"] = custom_details + if component: + payload["component"] = component + if group: + payload["group"] = group + if class_type: + payload["class"] = class_type + + actions = ("trigger", "acknowledge", "resolve") + if action not in actions: + raise ValueError(f"Event action must be one of: {', '.join(actions)}") + data = { + "event_action": action, + "payload": payload, + } + if dedup_key: + data["dedup_key"] = dedup_key + elif action != "trigger": + raise ValueError( + "The dedup_key property is required for event_action=%s events, and it must \ + be a string." + % action + ) + if images is not None: + data["images"] = images + if links is not None: + data["links"] = links + + session = pdpyras.EventsAPISession(routing_key) + resp = session.post("/v2/enqueue", json=data) + resp.raise_for_status() + return resp.json() diff --git a/reference/providers/pagerduty/provider.yaml b/reference/providers/pagerduty/provider.yaml new file mode 100644 index 0000000..6585199 --- /dev/null +++ b/reference/providers/pagerduty/provider.yaml @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-pagerduty +name: Pagerduty +description: | + `Pagerduty `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Pagerduty + external-doc-url: https://www.pagerduty.com/ + logo: /integration-logos/pagerduty/PagerDuty.png + tags: [service] + + +hooks: + - integration-name: Pagerduty + python-modules: + - airflow.providers.pagerduty.hooks.pagerduty diff --git a/reference/providers/papermill/CHANGELOG.rst b/reference/providers/papermill/CHANGELOG.rst new file mode 100644 index 0000000..3e97015 --- /dev/null +++ b/reference/providers/papermill/CHANGELOG.rst @@ -0,0 +1,39 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/papermill/__init__.py b/reference/providers/papermill/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/papermill/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/papermill/example_dags/__init__.py b/reference/providers/papermill/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/papermill/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/papermill/example_dags/example_papermill.py b/reference/providers/papermill/example_dags/example_papermill.py new file mode 100644 index 0000000..ad4b609 --- /dev/null +++ b/reference/providers/papermill/example_dags/example_papermill.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This DAG will use Papermill to run the notebook "hello_world"; based on the execution date
+it will create an output notebook "out-<execution_date>". All fields, including the keys in the parameters, are
+templated.
+"""
+import os
+from datetime import timedelta
+
+import scrapbook as sb
+from airflow import DAG
+from airflow.lineage import AUTO
+from airflow.operators.python import PythonOperator
+from airflow.providers.papermill.operators.papermill import PapermillOperator
+from airflow.utils.dates import days_ago
+
+default_args = {
+    "owner": "airflow",
+}
+
+with DAG(
+    dag_id="example_papermill_operator",
+    default_args=default_args,
+    schedule_interval="0 0 * * *",
+    start_date=days_ago(2),
+    dagrun_timeout=timedelta(minutes=60),
+    tags=["example"],
+) as dag_1:
+    # [START howto_operator_papermill]
+    run_this = PapermillOperator(
+        task_id="run_example_notebook",
+        input_nb="/tmp/hello_world.ipynb",
+        output_nb="/tmp/out-{{ execution_date }}.ipynb",
+        parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"},
+    )
+    # [END howto_operator_papermill]
+
+
+def check_notebook(inlets, execution_date):
+    """
+    Verify the message in the notebook
+    """
+    notebook = sb.read_notebook(inlets[0].url)
+    message = notebook.scraps["message"]
+    print(f"Message in notebook {message} for {execution_date}")
+
+    if message.data != f"Ran from Airflow at {execution_date}!":
+        return False
+
+    return True
+
+
+# The second DAG gets its own dag_id so it does not collide with the one above.
+with DAG(
+    dag_id="example_papermill_operator_2",
+    default_args=default_args,
+    schedule_interval="0 0 * * *",
+    start_date=days_ago(2),
+    dagrun_timeout=timedelta(minutes=60),
+) as dag_2:
+
+    run_this = PapermillOperator(
+        task_id="run_example_notebook",
+        input_nb=os.path.join(
+            os.path.dirname(os.path.realpath(__file__)), "input_notebook.ipynb"
+        ),
+        output_nb="/tmp/out-{{ execution_date }}.ipynb",
+        parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"},
+    )
+
+    check_output = PythonOperator(
+        task_id="check_out", python_callable=check_notebook, inlets=AUTO
+    )
+
+    check_output.set_upstream(run_this)
diff --git a/reference/providers/papermill/example_dags/input_notebook.ipynb b/reference/providers/papermill/example_dags/input_notebook.ipynb
new file mode 100644
index 0000000..eb73f82
--- /dev/null
+++ b/reference/providers/papermill/example_dags/input_notebook.ipynb
@@ -0,0 +1,120 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " Licensed to the Apache Software Foundation (ASF) under one\n",
+    " or more contributor license agreements. See the NOTICE file\n",
+    " distributed with this work for additional information\n",
+    " regarding copyright ownership. The ASF licenses this file\n",
+    " to you under the Apache License, Version 2.0 (the\n",
+    " \"License\"); you may not use this file except in compliance\n",
+    " with the License. 
You may obtain a copy of the License at\n", + "\n", + " http://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + " Unless required by applicable law or agreed to in writing,\n", + " software distributed under the License is distributed on an\n", + " \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", + " KIND, either express or implied. See the License for the\n", + " specific language governing permissions and limitations\n", + " under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is an example jupyter notebook for Apache Airflow that shows how to use\n", + "papermill in combination with scrapbook" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import scrapbook as sb" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The parameter tag for cells is used to tell papermill where it can find variables it needs to set" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "msgs = \"Hello!\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inside the notebook you can save data by calling the glue function. Then later you can read the results of that notebook by “scrap” name (see the Airflow Papermill example DAG)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/scrapbook.scrap.text+json": { + "data": "Hello!", + "encoder": "text", + "name": "message", + "version": 1 + } + }, + "metadata": { + "scrapbook": { + "data": true, + "display": false, + "name": "message" + } + }, + "output_type": "display_data" + } + ], + "source": [ + "sb.glue('message', msgs)" + ] + } + ], + "metadata": { + "celltoolbar": "Tags", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/reference/providers/papermill/operators/__init__.py b/reference/providers/papermill/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/papermill/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
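The input notebook above glues its result with scrapbook (sb.glue('message', msgs)), and the example DAG reads it back by scrap name. A minimal sketch of that read-back step on its own, assuming scrapbook is installed; the output path is a hypothetical stand-in for the "/tmp/out-{{ execution_date }}.ipynb" file the PapermillOperator writes:

    import scrapbook as sb

    # Hypothetical output notebook produced by the example DAG above.
    output_nb = "/tmp/out-2021-01-01T00:00:00+00:00.ipynb"

    notebook = sb.read_notebook(output_nb)  # load the executed notebook
    message = notebook.scraps["message"]    # scrap glued as 'message' in the input notebook
    print(message.data)                     # e.g. "Ran from Airflow at 2021-01-01T00:00:00+00:00!"
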
diff --git a/reference/providers/papermill/operators/papermill.py b/reference/providers/papermill/operators/papermill.py new file mode 100644 index 0000000..7e0b122 --- /dev/null +++ b/reference/providers/papermill/operators/papermill.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Dict, Optional + +import attr +import papermill as pm +from airflow.lineage.entities import File +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + + +@attr.s(auto_attribs=True) +class NoteBook(File): + """Jupyter notebook""" + + type_hint: Optional[str] = "jupyter_notebook" + parameters: Optional[Dict] = {} + + meta_schema: str = __name__ + ".NoteBook" + + +class PapermillOperator(BaseOperator): + """ + Executes a jupyter notebook through papermill that is annotated with parameters + + :param input_nb: input notebook (can also be a NoteBook or a File inlet) + :type input_nb: str + :param output_nb: output notebook (can also be a NoteBook or File outlet) + :type output_nb: str + :param parameters: the notebook parameters to set + :type parameters: dict + """ + + supports_lineage = True + + @apply_defaults + def __init__( + self, + *, + input_nb: Optional[str] = None, + output_nb: Optional[str] = None, + parameters: Optional[Dict] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if input_nb: + self.inlets.append(NoteBook(url=input_nb, parameters=parameters)) + if output_nb: + self.outlets.append(NoteBook(url=output_nb)) + + def execute(self, context): + if not self.inlets or not self.outlets: + raise ValueError("Input notebook or output notebook is not specified") + + for i in range(len(self.inlets)): + pm.execute_notebook( + self.inlets[i].url, + self.outlets[i].url, + parameters=self.inlets[i].parameters, + progress_bar=False, + report_mode=True, + ) diff --git a/reference/providers/papermill/provider.yaml b/reference/providers/papermill/provider.yaml new file mode 100644 index 0000000..dada8ce --- /dev/null +++ b/reference/providers/papermill/provider.yaml @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-papermill +name: Papermill +description: | + `Papermill `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Papermill + external-doc-url: https://github.com/nteract/papermill + how-to-guide: + - /docs/apache-airflow-providers-papermill/operators.rst + logo: /integration-logos/papermill/Papermill.png + tags: [software] + +operators: + - integration-name: Papermill + python-modules: + - airflow.providers.papermill.operators.papermill diff --git a/reference/providers/plexus/CHANGELOG.rst b/reference/providers/plexus/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/plexus/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/plexus/__init__.py b/reference/providers/plexus/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/plexus/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/plexus/example_dags/__init__.py b/reference/providers/plexus/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/plexus/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/plexus/example_dags/example_plexus.py b/reference/providers/plexus/example_dags/example_plexus.py new file mode 100644 index 0000000..38f06b4 --- /dev/null +++ b/reference/providers/plexus/example_dags/example_plexus.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow import DAG +from airflow.providers.plexus.operators.job import PlexusJobOperator +from airflow.utils.dates import days_ago + +HOME = "/home/acc" +T3_PRERUN_SCRIPT = ( + "cp {home}/imdb/run_scripts/mlflow.sh {home}/ && chmod +x mlflow.sh".format( + home=HOME + ) +) + +args = {"owner": "core scientific", "retries": 1} + +dag = DAG( + "test", + default_args=args, + description="testing plexus operator", + start_date=days_ago(1), + schedule_interval="@once", + catchup=False, +) + +t1 = PlexusJobOperator( + task_id="test", + job_params={ + "name": "test", + "app": "MLFlow Pipeline 01", + "queue": "DGX-2 (gpu:Tesla V100-SXM3-32GB)", + "num_nodes": 1, + "num_cores": 1, + "prerun_script": T3_PRERUN_SCRIPT, + }, + dag=dag, +) + +t1 diff --git a/reference/providers/plexus/hooks/__init__.py b/reference/providers/plexus/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/plexus/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
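The PlexusHook introduced in the next file authenticates with "email" and "password" Airflow Variables (its docstring below shows the AIRFLOW_VAR_* environment-variable form). A minimal sketch of seeding those Variables from Python instead, assuming an initialized Airflow metadata database; the values are placeholders:

    from airflow.models import Variable

    # Placeholder credentials; PlexusHook reads these via Variable.get("email") / Variable.get("password").
    Variable.set("email", "user@corescientific.com")
    Variable.set("password", "change-me")
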
diff --git a/reference/providers/plexus/hooks/plexus.py b/reference/providers/plexus/hooks/plexus.py new file mode 100644 index 0000000..13e03a1 --- /dev/null +++ b/reference/providers/plexus/hooks/plexus.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any + +import arrow +import jwt +import requests +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.models import Variable + + +class PlexusHook(BaseHook): + """ + Used for jwt token generation and storage to + make Plexus API calls. Requires email and password + Airflow variables be created. + + Example: + - export AIRFLOW_VAR_EMAIL = user@corescientific.com + - export AIRFLOW_VAR_PASSWORD = ******* + + """ + + def __init__(self) -> None: + super().__init__() + self.__token = None + self.__token_exp = None + self.host = "https://apiplexus.corescientific.com/" + self.user_id = None + + def _generate_token(self) -> Any: + login = Variable.get("email") + pwd = Variable.get("password") + if login is None or pwd is None: + raise AirflowException("No valid email/password supplied.") + token_endpoint = self.host + "sso/jwt-token/" + response = requests.post( + token_endpoint, data={"email": login, "password": pwd}, timeout=5 + ) + if not response.ok: + raise AirflowException( + "Could not retrieve JWT Token. Status Code: [{}]. " + "Reason: {} - {}".format( + response.status_code, response.reason, response.text + ) + ) + token = response.json()["access"] + payload = jwt.decode(token, verify=False) + self.user_id = payload["user_id"] + self.__token_exp = payload["exp"] + + return token + + @property + def token(self) -> Any: + """Returns users token""" + if self.__token is not None: + if arrow.get(self.__token_exp) <= arrow.now(): + self.__token = self._generate_token() + return self.__token + else: + self.__token = self._generate_token() + return self.__token diff --git a/reference/providers/plexus/operators/__init__.py b/reference/providers/plexus/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/plexus/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/plexus/operators/job.py b/reference/providers/plexus/operators/job.py new file mode 100644 index 0000000..7d38ac2 --- /dev/null +++ b/reference/providers/plexus/operators/job.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +import time +from typing import Any, Dict, Optional + +import requests +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.plexus.hooks.plexus import PlexusHook +from airflow.utils.decorators import apply_defaults + +logger = logging.getLogger(__name__) + + +class PlexusJobOperator(BaseOperator): + """ + Submits a Plexus job. + + :param job_params: parameters required to launch a job. + :type job_params: dict + + Required job parameters are the following + - "name": job name created by user. + - "app": name of the application to run. found in Plexus UI. + - "queue": public cluster name. found in Plexus UI. + - "num_nodes": number of nodes. + - "num_cores": number of cores per node. + + """ + + @apply_defaults + def __init__(self, job_params: Dict, **kwargs) -> None: + super().__init__(**kwargs) + + self.job_params = job_params + self.required_params = {"name", "app", "queue", "num_cores", "num_nodes"} + self.lookups = { + "app": ("apps/", "id", "name"), + "billing_account_id": ("users/{}/billingaccounts/", "id", None), + "queue": ("queues/", "id", "public_name"), + } + self.job_params.update({"billing_account_id": None}) + self.is_service = None + + def execute(self, context: Any) -> Any: + hook = PlexusHook() + params = self.construct_job_params(hook) + if self.is_service is True: + if self.job_params.get("expected_runtime") is None: + end_state = "Running" + else: + end_state = "Finished" + elif self.is_service is False: + end_state = "Completed" + else: + raise AirflowException( + "Unable to determine if application " + "is running as a batch job or service. " + "Contact Core Scientific AI Team." 
+ ) + logger.info("creating job w/ following params: %s", params) + jobs_endpoint = hook.host + "jobs/" + headers = {"Authorization": f"Bearer {hook.token}"} + create_job = requests.post( + jobs_endpoint, headers=headers, data=params, timeout=5 + ) + if create_job.ok: + job = create_job.json() + jid = job["id"] + state = job["last_state"] + while state != end_state: + time.sleep(3) + jid_endpoint = jobs_endpoint + f"{jid}/" + get_job = requests.get(jid_endpoint, headers=headers, timeout=5) + if not get_job.ok: + raise AirflowException( + "Could not retrieve job status. Status Code: [{}]. " + "Reason: {} - {}".format( + get_job.status_code, get_job.reason, get_job.text + ) + ) + new_state = get_job.json()["last_state"] + if new_state in ("Cancelled", "Failed"): + raise AirflowException(f"Job {new_state}") + elif new_state != state: + logger.info("job is %s", new_state) + state = new_state + else: + raise AirflowException( + "Could not start job. Status Code: [{}]. " + "Reason: {} - {}".format( + create_job.status_code, create_job.reason, create_job.text + ) + ) + + def _api_lookup(self, param: str, hook): + lookup = self.lookups[param] + key = lookup[1] + mapping = None if lookup[2] is None else (lookup[2], self.job_params[param]) + + if param == "billing_account_id": + endpoint = hook.host + lookup[0].format(hook.user_id) + else: + endpoint = hook.host + lookup[0] + headers = {"Authorization": f"Bearer {hook.token}"} + response = requests.get(endpoint, headers=headers, timeout=5) + results = response.json()["results"] + + v = None + if mapping is None: + v = results[0][key] + else: + for dct in results: + if dct[mapping[0]] == mapping[1]: + v = dct[key] + if param == "app": + self.is_service = dct["is_service"] + if v is None: + raise AirflowException( + f"Could not locate value for param:{key} at endpoint: {endpoint}" + ) + + return v + + def construct_job_params(self, hook: Any) -> Dict[Any, Optional[Any]]: + """ + Creates job_params dict for api call to + launch a Plexus job. + + Some parameters required to launch a job + are not available to the user in the Plexus + UI. For example, an app id is required, but + only the app name is provided in the UI. + This function acts as a backend lookup + of the required param value using the + user-provided value. + + :param hook: plexus hook object + :type hook: airflow hook + """ + missing_params = self.required_params - set(self.job_params) + if len(missing_params) > 0: + raise AirflowException( + f"Missing the following required job_params: {', '.join(missing_params)}" + ) + params = {} + for prm in self.job_params: + if prm in self.lookups: + v = self._api_lookup(param=prm, hook=hook) + params[prm] = v + else: + params[prm] = self.job_params[prm] + return params diff --git a/reference/providers/plexus/provider.yaml b/reference/providers/plexus/provider.yaml new file mode 100644 index 0000000..b3b8009 --- /dev/null +++ b/reference/providers/plexus/provider.yaml @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-plexus +name: Plexus +description: | + `Plexus `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Plexus + external-doc-url: https://plexus.corescientific.com/ + logo: /integration-logos/plexus/Plexus.png + tags: [service] + +operators: + - integration-name: Plexus + python-modules: + - airflow.providers.plexus.operators.job +hooks: + - integration-name: Plexus + python-modules: + - airflow.providers.plexus.hooks.plexus diff --git a/reference/providers/postgres/CHANGELOG.rst b/reference/providers/postgres/CHANGELOG.rst new file mode 100644 index 0000000..722ece7 --- /dev/null +++ b/reference/providers/postgres/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. Added HowTo guide for Postgres Operator. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/postgres/__init__.py b/reference/providers/postgres/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/postgres/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/postgres/example_dags/__init__.py b/reference/providers/postgres/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/postgres/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/reference/providers/postgres/example_dags/example_postgres.py b/reference/providers/postgres/example_dags/example_postgres.py
new file mode 100644
index 0000000..0ca8ce3
--- /dev/null
+++ b/reference/providers/postgres/example_dags/example_postgres.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# [START postgres_operator_howto_guide]
+import datetime
+
+from airflow import DAG
+from airflow.providers.postgres.operators.postgres import PostgresOperator
+
+default_args = {"owner": "airflow"}
+
+# create_pet_table, populate_pet_table, get_all_pets, and get_birth_date are examples of tasks created by
+# instantiating the Postgres Operator
+
+with DAG(
+    dag_id="postgres_operator_dag",
+    start_date=datetime.datetime(2020, 2, 2),
+    schedule_interval="@once",
+    default_args=default_args,
+    catchup=False,
+) as dag:
+    # [START postgres_operator_howto_guide_create_pet_table]
+    create_pet_table = PostgresOperator(
+        task_id="create_pet_table",
+        postgres_conn_id="postgres_default",
+        sql="""
+            CREATE TABLE IF NOT EXISTS pet (
+            pet_id SERIAL PRIMARY KEY,
+            name VARCHAR NOT NULL,
+            pet_type VARCHAR NOT NULL,
+            birth_date DATE NOT NULL,
+            OWNER VARCHAR NOT NULL);
+            """,
+    )
+    # [END postgres_operator_howto_guide_create_pet_table]
+    # [START postgres_operator_howto_guide_populate_pet_table]
+    populate_pet_table = PostgresOperator(
+        task_id="populate_pet_table",
+        postgres_conn_id="postgres_default",
+        sql="""
+            -- pet_id is SERIAL, so list the remaining columns explicitly.
+            INSERT INTO pet (name, pet_type, birth_date, OWNER) VALUES ( 'Max', 'Dog', '2018-07-05', 'Jane');
+            INSERT INTO pet (name, pet_type, birth_date, OWNER) VALUES ( 'Susie', 'Cat', '2019-05-01', 'Phil');
+            INSERT INTO pet (name, pet_type, birth_date, OWNER) VALUES ( 'Lester', 'Hamster', '2020-06-23', 'Lily');
+            INSERT INTO pet (name, pet_type, birth_date, OWNER) VALUES ( 'Quincy', 'Parrot', '2013-08-11', 'Anne');
+            """,
+    )
+    # [END postgres_operator_howto_guide_populate_pet_table]
+    # [START postgres_operator_howto_guide_get_all_pets]
+    get_all_pets = PostgresOperator(
+        task_id="get_all_pets",
+        postgres_conn_id="postgres_default",
+        sql="SELECT * FROM pet;",
+    )
+    # [END postgres_operator_howto_guide_get_all_pets]
+    # [START postgres_operator_howto_guide_get_birth_date]
+    get_birth_date = PostgresOperator(
+        task_id="get_birth_date",
+        postgres_conn_id="postgres_default",
+        sql="""
+            SELECT * FROM pet
+            WHERE birth_date
+            BETWEEN SYMMETRIC '{{ params.begin_date }}' AND '{{ params.end_date }}';
+            """,
+        params={"begin_date": "2020-01-01", "end_date": "2020-12-31"},
+    )
+    # [END postgres_operator_howto_guide_get_birth_date]
+
+    create_pet_table >> populate_pet_table >> get_all_pets >> get_birth_date
+    # [END postgres_operator_howto_guide]
diff --git a/reference/providers/postgres/hooks/__init__.py b/reference/providers/postgres/hooks/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/reference/providers/postgres/hooks/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/reference/providers/postgres/hooks/postgres.py b/reference/providers/postgres/hooks/postgres.py
new file mode 100644
index 0000000..d80f837
--- /dev/null
+++ b/reference/providers/postgres/hooks/postgres.py
@@ -0,0 +1,255 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+from contextlib import closing
+from typing import Iterable, Optional, Tuple, Union
+
+import psycopg2
+import psycopg2.extensions
+import psycopg2.extras
+from airflow.hooks.dbapi import DbApiHook
+from airflow.models.connection import Connection
+from psycopg2.extensions import connection
+from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
+
+CursorType = Union[DictCursor, RealDictCursor, NamedTupleCursor]
+
+
+class PostgresHook(DbApiHook):
+    """
+    Interact with Postgres.
+
+    You can specify ssl parameters in the extra field of your connection
+    as ``{"sslmode": "require", "sslcert": "/path/to/cert.pem", etc}``.
+    Also you can choose cursor as ``{"cursor": "dictcursor"}``. Refer to
+    psycopg2.extras for more details.
+
+    Note: For Redshift, use keepalives_idle in the extra connection parameters
+    and set it to less than 300 seconds.
+
+    Note: For AWS IAM authentication, use iam in the extra connection parameters
+    and set it to true. Leave the password field empty. 
This will use the + "aws_default" connection to get the temporary token unless you override + in extras. + extras example: ``{"iam":true, "aws_conn_id":"my_aws_conn"}`` + For Redshift, also use redshift in the extra connection parameters and + set it to true. The cluster-identifier is extracted from the beginning of + the host field, so is optional. It can however be overridden in the extra field. + extras example: ``{"iam":true, "redshift":true, "cluster-identifier": "my_cluster_id"}`` + """ + + conn_name_attr = "postgres_conn_id" + default_conn_name = "postgres_default" + conn_type = "postgres" + hook_name = "Postgres" + supports_autocommit = True + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.schema: Optional[str] = kwargs.pop("schema", None) + self.connection: Optional[Connection] = kwargs.pop("connection", None) + self.conn: connection = None + + def _get_cursor(self, raw_cursor: str) -> CursorType: + _cursor = raw_cursor.lower() + if _cursor == "dictcursor": + return psycopg2.extras.DictCursor + if _cursor == "realdictcursor": + return psycopg2.extras.RealDictCursor + if _cursor == "namedtuplecursor": + return psycopg2.extras.NamedTupleCursor + raise ValueError(f"Invalid cursor passed {_cursor}") + + def get_conn(self) -> connection: + """Establishes a connection to a postgres database.""" + conn_id = getattr(self, self.conn_name_attr) + conn = self.connection or self.get_connection(conn_id) + + # check for authentication via AWS IAM + if conn.extra_dejson.get("iam", False): + conn.login, conn.password, conn.port = self.get_iam_token(conn) + + conn_args = dict( + host=conn.host, + user=conn.login, + password=conn.password, + dbname=self.schema or conn.schema, + port=conn.port, + ) + raw_cursor = conn.extra_dejson.get("cursor", False) + if raw_cursor: + conn_args["cursor_factory"] = self._get_cursor(raw_cursor) + + for arg_name, arg_val in conn.extra_dejson.items(): + if arg_name not in [ + "iam", + "redshift", + "cursor", + ]: + conn_args[arg_name] = arg_val + + self.conn = psycopg2.connect(**conn_args) + return self.conn + + def copy_expert(self, sql: str, filename: str) -> None: + """ + Executes SQL using psycopg2 copy_expert method. + Necessary to execute COPY command without access to a superuser. + + Note: if this method is called with a "COPY FROM" statement and + the specified input file does not exist, it creates an empty + file and no data is loaded, but the operation succeeds. + So if users want to be aware when the input file does not exist, + they have to check its existence by themselves. + """ + if not os.path.isfile(filename): + with open(filename, "w"): + pass + + with open(filename, "r+") as file: + with closing(self.get_conn()) as conn: + with closing(conn.cursor()) as cur: + cur.copy_expert(sql, file) + file.truncate(file.tell()) + conn.commit() + + def bulk_load(self, table: str, tmp_file: str) -> None: + """Loads a tab-delimited file into a database table""" + self.copy_expert(f"COPY {table} FROM STDIN", tmp_file) + + def bulk_dump(self, table: str, tmp_file: str) -> None: + """Dumps a database table into a tab-delimited file""" + self.copy_expert(f"COPY {table} TO STDOUT", tmp_file) + + # pylint: disable=signature-differs + @staticmethod + def _serialize_cell(cell: object, conn: Optional[connection] = None) -> object: + """ + Postgresql will adapt all arguments to the execute() method internally, + hence we return cell without any conversion. 
+
+        See http://initd.org/psycopg/docs/advanced.html#adapting-new-types for
+        more information.
+
+        :param cell: The cell to insert into the table
+        :type cell: object
+        :param conn: The database connection
+        :type conn: connection object
+        :return: The cell
+        :rtype: object
+        """
+        return cell
+
+    def get_iam_token(self, conn: Connection) -> Tuple[str, str, int]:
+        """
+        Uses AwsBaseHook to retrieve a temporary password to connect to Postgres
+        or Redshift. Port is required. If none is provided, the default is used for
+        each service.
+        """
+        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+        redshift = conn.extra_dejson.get("redshift", False)
+        aws_conn_id = conn.extra_dejson.get("aws_conn_id", "aws_default")
+        aws_hook = AwsBaseHook(aws_conn_id, client_type="rds")
+        login = conn.login
+        if conn.port is None:
+            port = 5439 if redshift else 5432
+        else:
+            port = conn.port
+        if redshift:
+            # Pull the cluster-identifier from the beginning of the Redshift URL
+            # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
+            cluster_identifier = conn.extra_dejson.get(
+                "cluster-identifier", conn.host.split(".")[0]
+            )
+            client = aws_hook.get_client_type("redshift")
+            cluster_creds = client.get_cluster_credentials(
+                DbUser=conn.login,
+                DbName=self.schema or conn.schema,
+                ClusterIdentifier=cluster_identifier,
+                AutoCreate=False,
+            )
+            token = cluster_creds["DbPassword"]
+            login = cluster_creds["DbUser"]
+        else:
+            token = aws_hook.conn.generate_db_auth_token(conn.host, port, conn.login)
+        return login, token, port
+
+    @staticmethod
+    def _generate_insert_sql(
+        table: str,
+        values: Tuple[str, ...],
+        target_fields: Iterable[str],
+        replace: bool,
+        **kwargs,
+    ) -> str:
+        """
+        Static helper method that generates the INSERT SQL statement.
+        When ``replace`` is True, PostgreSQL's ``ON CONFLICT ... DO UPDATE`` syntax
+        is used (rather than MySQL's ``REPLACE``). 
+ + :param table: Name of the target table + :type table: str + :param values: The row to insert into the table + :type values: tuple of cell values + :param target_fields: The names of the columns to fill in the table + :type target_fields: iterable of strings + :param replace: Whether to replace instead of insert + :type replace: bool + :param replace_index: the column or list of column names to act as + index for the ON CONFLICT clause + :type replace_index: str or list + :return: The generated INSERT or REPLACE SQL statement + :rtype: str + """ + placeholders = [ + "%s", + ] * len(values) + replace_index = kwargs.get("replace_index") + + if target_fields: + target_fields_fragment = ", ".join(target_fields) + target_fields_fragment = f"({target_fields_fragment})" + else: + target_fields_fragment = "" + + sql = f"INSERT INTO {table} {target_fields_fragment} VALUES ({','.join(placeholders)})" + + if replace: + if target_fields is None: + raise ValueError( + "PostgreSQL ON CONFLICT upsert syntax requires column names" + ) + if replace_index is None: + raise ValueError( + "PostgreSQL ON CONFLICT upsert syntax requires an unique index" + ) + if isinstance(replace_index, str): + replace_index = [replace_index] + replace_index_set = set(replace_index) + + replace_target = [ + "{0} = excluded.{0}".format(col) + for col in target_fields + if col not in replace_index_set + ] + sql += " ON CONFLICT ({}) DO UPDATE SET {}".format( + ", ".join(replace_index), + ", ".join(replace_target), + ) + return sql diff --git a/reference/providers/postgres/operators/__init__.py b/reference/providers/postgres/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/postgres/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/postgres/operators/postgres.py b/reference/providers/postgres/operators/postgres.py new file mode 100644 index 0000000..5273451 --- /dev/null +++ b/reference/providers/postgres/operators/postgres.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Iterable, Mapping, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.postgres.hooks.postgres import PostgresHook +from airflow.utils.decorators import apply_defaults + + +class PostgresOperator(BaseOperator): + """ + Executes sql code in a specific Postgres database + + :param sql: the sql code to be executed. (templated) + :type sql: Can receive a str representing a sql statement, + a list of str (sql statements), or reference to a template file. + Template reference are recognized by str ending in '.sql' + :param postgres_conn_id: reference to a specific postgres database + :type postgres_conn_id: str + :param autocommit: if True, each command is automatically committed. + (default value: False) + :type autocommit: bool + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: dict or iterable + :param database: name of database which overwrite defined one in connection + :type database: str + """ + + template_fields = ("sql",) + template_fields_renderers = {"sql": "sql"} + template_ext = (".sql",) + ui_color = "#ededed" + + @apply_defaults + def __init__( + self, + *, + sql: str, + postgres_conn_id: str = "postgres_default", + autocommit: bool = False, + parameters: Optional[Union[Mapping, Iterable]] = None, + database: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sql = sql + self.postgres_conn_id = postgres_conn_id + self.autocommit = autocommit + self.parameters = parameters + self.database = database + self.hook = None + + def execute(self, context): + self.log.info("Executing: %s", self.sql) + self.hook = PostgresHook( + postgres_conn_id=self.postgres_conn_id, schema=self.database + ) + self.hook.run(self.sql, self.autocommit, parameters=self.parameters) + for output in self.hook.conn.notices: + self.log.info(output) diff --git a/reference/providers/postgres/provider.yaml b/reference/providers/postgres/provider.yaml new file mode 100644 index 0000000..3099b39 --- /dev/null +++ b/reference/providers/postgres/provider.yaml @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +package-name: apache-airflow-providers-postgres +name: PostgreSQL +description: | + `PostgreSQL `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: PostgreSQL + external-doc-url: https://www.postgresql.org/ + how-to-guide: + - /docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst + logo: /integration-logos/postgress/Postgress.png + tags: [software] + +operators: + - integration-name: PostgreSQL + python-modules: + - airflow.providers.postgres.operators.postgres + +hooks: + - integration-name: PostgreSQL + python-modules: + - airflow.providers.postgres.hooks.postgres + +hook-class-names: + - airflow.providers.postgres.hooks.postgres.PostgresHook diff --git a/reference/providers/presto/CHANGELOG.rst b/reference/providers/presto/CHANGELOG.rst new file mode 100644 index 0000000..cbdc115 --- /dev/null +++ b/reference/providers/presto/CHANGELOG.rst @@ -0,0 +1,38 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/presto/__init__.py b/reference/providers/presto/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/presto/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/presto/hooks/__init__.py b/reference/providers/presto/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/presto/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/presto/hooks/presto.py b/reference/providers/presto/hooks/presto.py new file mode 100644 index 0000000..46205a9 --- /dev/null +++ b/reference/providers/presto/hooks/presto.py @@ -0,0 +1,198 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +from typing import Any, Iterable, Optional + +import prestodb +from airflow import AirflowException +from airflow.configuration import conf +from airflow.hooks.dbapi import DbApiHook +from airflow.models import Connection +from prestodb.exceptions import DatabaseError +from prestodb.transaction import IsolationLevel + + +class PrestoException(Exception): + """Presto exception""" + + +def _boolify(value): + if isinstance(value, bool): + return value + if isinstance(value, str): + if value.lower() == "false": + return False + elif value.lower() == "true": + return True + return value + + +class PrestoHook(DbApiHook): + """ + Interact with Presto through prestodb. 
+ + >>> ph = PrestoHook() + >>> sql = "SELECT count(1) AS num FROM airflow.static_babynames" + >>> ph.get_records(sql) + [[340698]] + """ + + conn_name_attr = "presto_conn_id" + default_conn_name = "presto_default" + conn_type = "presto" + hook_name = "Presto" + + def get_conn(self) -> Connection: + """Returns a connection object""" + db = self.get_connection( + self.presto_conn_id # type: ignore[attr-defined] # pylint: disable=no-member + ) + extra = db.extra_dejson + auth = None + if db.password and extra.get("auth") == "kerberos": + raise AirflowException("Kerberos authorization doesn't support password.") + elif db.password: + auth = prestodb.auth.BasicAuthentication(db.login, db.password) + elif extra.get("auth") == "kerberos": + auth = prestodb.auth.KerberosAuthentication( + config=extra.get("kerberos__config", os.environ.get("KRB5_CONFIG")), + service_name=extra.get("kerberos__service_name"), + mutual_authentication=_boolify( + extra.get("kerberos__mutual_authentication", False) + ), + force_preemptive=_boolify( + extra.get("kerberos__force_preemptive", False) + ), + hostname_override=extra.get("kerberos__hostname_override"), + sanitize_mutual_error_response=_boolify( + extra.get("kerberos__sanitize_mutual_error_response", True) + ), + principal=extra.get( + "kerberos__principal", conf.get("kerberos", "principal") + ), + delegate=_boolify(extra.get("kerberos__delegate", False)), + ca_bundle=extra.get("kerberos__ca_bundle"), + ) + + presto_conn = prestodb.dbapi.connect( + host=db.host, + port=db.port, + user=db.login, + source=db.extra_dejson.get("source", "airflow"), + http_scheme=db.extra_dejson.get("protocol", "http"), + catalog=db.extra_dejson.get("catalog", "hive"), + schema=db.schema, + auth=auth, + isolation_level=self.get_isolation_level(), # type: ignore[func-returns-value] + ) + if extra.get("verify") is not None: + # Unfortunately verify parameter is available via public API. + # The PR is merged in the presto library, but has not been released. 
+ # See: https://github.com/prestosql/presto-python-client/pull/31 + presto_conn._http_session.verify = _boolify( + extra["verify"] + ) # pylint: disable=protected-access + + return presto_conn + + def get_isolation_level(self) -> Any: + """Returns an isolation level""" + db = self.get_connection( + self.presto_conn_id # type: ignore[attr-defined] # pylint: disable=no-member + ) + isolation_level = db.extra_dejson.get("isolation_level", "AUTOCOMMIT").upper() + return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT) + + @staticmethod + def _strip_sql(sql: str) -> str: + return sql.strip().rstrip(";") + + def get_records(self, hql, parameters: Optional[dict] = None): + """Get a set of records from Presto""" + try: + return super().get_records(self._strip_sql(hql), parameters) + except DatabaseError as e: + raise PrestoException(e) + + def get_first(self, hql: str, parameters: Optional[dict] = None) -> Any: + """Returns only the first row, regardless of how many rows the query returns.""" + try: + return super().get_first(self._strip_sql(hql), parameters) + except DatabaseError as e: + raise PrestoException(e) + + def get_pandas_df(self, hql, parameters=None, **kwargs): + """Get a pandas dataframe from a sql query.""" + import pandas + + cursor = self.get_cursor() + try: + cursor.execute(self._strip_sql(hql), parameters) + data = cursor.fetchall() + except DatabaseError as e: + raise PrestoException(e) + column_descriptions = cursor.description + if data: + df = pandas.DataFrame(data, **kwargs) + df.columns = [c[0] for c in column_descriptions] + else: + df = pandas.DataFrame(**kwargs) + return df + + def run( + self, + hql, + autocommit: bool = False, + parameters: Optional[dict] = None, + ) -> None: + """Execute the statement against Presto. Can be used to create views.""" + return super().run(sql=self._strip_sql(hql), parameters=parameters) + + def insert_rows( + self, + table: str, + rows: Iterable[tuple], + target_fields: Optional[Iterable[str]] = None, + commit_every: int = 0, + replace: bool = False, + **kwargs, + ) -> None: + """ + A generic way to insert a set of tuples into a table. + + :param table: Name of the target table + :type table: str + :param rows: The rows to insert into the table + :type rows: iterable of tuples + :param target_fields: The names of the columns to fill in the table + :type target_fields: iterable of strings + :param commit_every: The maximum number of rows to insert in one + transaction. Set to 0 to insert all rows in one transaction. + :type commit_every: int + :param replace: Whether to replace instead of insert + :type replace: bool + """ + if self.get_isolation_level() == IsolationLevel.AUTOCOMMIT: + self.log.info( + "Transactions are not enable in presto connection. " + "Please use the isolation_level property to enable it. " + "Falling back to insert all rows in one transaction." + ) + commit_every = 0 + + super().insert_rows(table, rows, target_fields, commit_every) diff --git a/reference/providers/presto/provider.yaml b/reference/providers/presto/provider.yaml new file mode 100644 index 0000000..6d003ff --- /dev/null +++ b/reference/providers/presto/provider.yaml @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-presto +name: Presto +description: | + `Presto `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Presto + external-doc-url: http://prestodb.github.io/ + logo: /integration-logos/presto/PrestoDB.png + tags: [software] + +hooks: + - integration-name: Presto + python-modules: + - airflow.providers.presto.hooks.presto + +hook-class-names: + - airflow.providers.presto.hooks.presto.PrestoHook diff --git a/reference/providers/qubole/CHANGELOG.rst b/reference/providers/qubole/CHANGELOG.rst new file mode 100644 index 0000000..eac8d8e --- /dev/null +++ b/reference/providers/qubole/CHANGELOG.rst @@ -0,0 +1,38 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Features +~~~~~~~~ + +* ``Refactor SQL/BigQuery/Qubole/Druid Check operators (#12677)`` + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/qubole/__init__.py b/reference/providers/qubole/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/qubole/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
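Stepping back to the PrestoHook a few files above, a short, hypothetical sketch of ad-hoc usage. The connection id is the hook's default and the airflow.static_babynames table is the one named in its docstring; everything else is illustrative.

# Hypothetical PrestoHook usage sketch (not part of the diff).
from airflow.providers.presto.hooks.presto import PrestoHook

ph = PrestoHook(presto_conn_id="presto_default")

# _strip_sql() removes a trailing semicolon before the query is submitted.
rows = ph.get_records("SELECT count(1) AS num FROM airflow.static_babynames;")

# get_pandas_df() returns an empty DataFrame when the query yields no rows.
df = ph.get_pandas_df("SELECT * FROM airflow.static_babynames LIMIT 10")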
diff --git a/reference/providers/qubole/example_dags/__init__.py b/reference/providers/qubole/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/qubole/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/qubole/example_dags/example_qubole.py b/reference/providers/qubole/example_dags/example_qubole.py new file mode 100644 index 0000000..6cdb1ad --- /dev/null +++ b/reference/providers/qubole/example_dags/example_qubole.py @@ -0,0 +1,286 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import filecmp +import random +import textwrap + +from airflow import DAG +from airflow.operators.dummy import DummyOperator +from airflow.operators.python import BranchPythonOperator, PythonOperator +from airflow.providers.qubole.operators.qubole import QuboleOperator +from airflow.providers.qubole.sensors.qubole import ( + QuboleFileSensor, + QubolePartitionSensor, +) +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, +} + +with DAG( + dag_id="example_qubole_operator", + default_args=default_args, + schedule_interval=None, + start_date=days_ago(2), + tags=["example"], +) as dag: + dag.doc_md = textwrap.dedent( + """ + This is only an example DAG to highlight usage of QuboleOperator in various scenarios, + some of these tasks may or may not work based on your Qubole account setup. + + Run a shell command from Qubole Analyze against your Airflow cluster with following to + trigger it manually `airflow dags trigger example_qubole_operator`. + + *Note: Make sure that connection `qubole_default` is properly set before running this + example. Also be aware that it might spin up clusters to run these examples.* + """ + ) + + def compare_result_fn(**kwargs): + """ + Compares the results of two QuboleOperator tasks. 
+ + :param kwargs: The context of the executed task. + :type kwargs: dict + :return: True if the files are the same, False otherwise. + :rtype: bool + """ + ti = kwargs["ti"] + qubole_result_1 = hive_show_table.get_results(ti) + qubole_result_2 = hive_s3_location.get_results(ti) + return filecmp.cmp(qubole_result_1, qubole_result_2) + + hive_show_table = QuboleOperator( + task_id="hive_show_table", + command_type="hivecmd", + query="show tables", + cluster_label="{{ params.cluster_label }}", + fetch_logs=True, + # If `fetch_logs`=true, will fetch qubole command logs and concatenate + # them into corresponding airflow task logs + tags="airflow_example_run", + # To attach tags to qubole command, auto attach 3 tags - dag_id, task_id, run_id + qubole_conn_id="qubole_default", + # Connection id to submit commands inside QDS, if not set "qubole_default" is used + params={ + "cluster_label": "default", + }, + ) + + hive_s3_location = QuboleOperator( + task_id="hive_s3_location", + command_type="hivecmd", + script_location="s3n://public-qubole/qbol-library/scripts/show_table.hql", + notify=True, + tags=["tag1", "tag2"], + # If the script at s3 location has any qubole specific macros to be replaced + # macros='[{"date": "{{ ds }}"}, {"name" : "abc"}]', + trigger_rule="all_done", + ) + + compare_result = PythonOperator( + task_id="compare_result", + python_callable=compare_result_fn, + trigger_rule="all_done", + ) + + compare_result << [hive_show_table, hive_s3_location] + + options = ["hadoop_jar_cmd", "presto_cmd", "db_query", "spark_cmd"] + + branching = BranchPythonOperator( + task_id="branching", python_callable=lambda: random.choice(options) + ) + + branching << compare_result + + join = DummyOperator(task_id="join", trigger_rule="one_success") + + hadoop_jar_cmd = QuboleOperator( + task_id="hadoop_jar_cmd", + command_type="hadoopcmd", + sub_command="jar s3://paid-qubole/HadoopAPIExamples/" + "jars/hadoop-0.20.1-dev-streaming.jar " + "-mapper wc " + "-numReduceTasks 0 -input s3://paid-qubole/HadoopAPITests/" + "data/3.tsv -output " + "s3://paid-qubole/HadoopAPITests/data/3_wc", + cluster_label="{{ params.cluster_label }}", + fetch_logs=True, + params={ + "cluster_label": "default", + }, + ) + + pig_cmd = QuboleOperator( + task_id="pig_cmd", + command_type="pigcmd", + script_location="s3://public-qubole/qbol-library/scripts/script1-hadoop-s3-small.pig", + parameters="key1=value1 key2=value2", + trigger_rule="all_done", + ) + + pig_cmd << hadoop_jar_cmd << branching + pig_cmd >> join + + presto_cmd = QuboleOperator( + task_id="presto_cmd", command_type="prestocmd", query="show tables" + ) + + shell_cmd = QuboleOperator( + task_id="shell_cmd", + command_type="shellcmd", + script_location="s3://public-qubole/qbol-library/scripts/shellx.sh", + parameters="param1 param2", + trigger_rule="all_done", + ) + + shell_cmd << presto_cmd << branching + shell_cmd >> join + + db_query = QuboleOperator( + task_id="db_query", + command_type="dbtapquerycmd", + query="show tables", + db_tap_id=2064, + ) + + db_export = QuboleOperator( + task_id="db_export", + command_type="dbexportcmd", + mode=1, + hive_table="default_qubole_airline_origin_destination", + db_table="exported_airline_origin_destination", + partition_spec="dt=20110104-02", + dbtap_id=2064, + trigger_rule="all_done", + ) + + db_export << db_query << branching + db_export >> join + + db_import = QuboleOperator( + task_id="db_import", + command_type="dbimportcmd", + mode=1, + hive_table="default_qubole_airline_origin_destination", + 
db_table="exported_airline_origin_destination", + where_clause="id < 10", + parallelism=2, + dbtap_id=2064, + trigger_rule="all_done", + ) + + prog = """ + import scala.math.random + + import org.apache.spark._ + + /** Computes an approximation to pi */ + object SparkPi { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("Spark Pi") + val spark = new SparkContext(conf) + val slices = if (args.length > 0) args(0).toInt else 2 + val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow + val count = spark.parallelize(1 until n, slices).map { i => + val x = random * 2 - 1 + val y = random * 2 - 1 + if (x*x + y*y < 1) 1 else 0 + }.reduce(_ + _) + println("Pi is roughly " + 4.0 * count / n) + spark.stop() + } + } + + """ + + spark_cmd = QuboleOperator( + task_id="spark_cmd", + command_type="sparkcmd", + program=prog, + language="scala", + arguments="--class SparkPi", + tags="airflow_example_run", + ) + + spark_cmd << db_import << branching + spark_cmd >> join + +with DAG( + dag_id="example_qubole_sensor", + default_args=default_args, + schedule_interval=None, + start_date=days_ago(2), + doc_md=__doc__, + tags=["example"], +) as dag2: + dag2.doc_md = textwrap.dedent( + """ + This is only an example DAG to highlight usage of QuboleSensor in various scenarios, + some of these tasks may or may not work based on your QDS account setup. + + Run a shell command from Qubole Analyze against your Airflow cluster with following to + trigger it manually `airflow dags trigger example_qubole_sensor`. + + *Note: Make sure that connection `qubole_default` is properly set before running + this example.* + """ + ) + + check_s3_file = QuboleFileSensor( + task_id="check_s3_file", + qubole_conn_id="qubole_default", + poke_interval=60, + timeout=600, + data={ + "files": [ + "s3://paid-qubole/HadoopAPIExamples/jars/hadoop-0.20.1-dev-streaming.jar", + "s3://paid-qubole/HadoopAPITests/data/{{ ds.split('-')[2] }}.tsv", + ] # will check for availability of all the files in array + }, + ) + + check_hive_partition = QubolePartitionSensor( + task_id="check_hive_partition", + poke_interval=10, + timeout=60, + data={ + "schema": "default", + "table": "my_partitioned_table", + "columns": [ + {"column": "month", "values": ["{{ ds.split('-')[1] }}"]}, + { + "column": "day", + "values": [ + "{{ ds.split('-')[2] }}", + "{{ yesterday_ds.split('-')[2] }}", + ], + }, + ], # will check for partitions like [month=12/day=12,month=12/day=13] + }, + ) + + check_s3_file >> check_hive_partition diff --git a/reference/providers/qubole/hooks/__init__.py b/reference/providers/qubole/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/qubole/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/qubole/hooks/qubole.py b/reference/providers/qubole/hooks/qubole.py new file mode 100644 index 0000000..6064b60 --- /dev/null +++ b/reference/providers/qubole/hooks/qubole.py @@ -0,0 +1,301 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Qubole hook""" +import datetime +import logging +import os +import pathlib +import time +from typing import Dict, List, Tuple + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.utils.state import State +from qds_sdk.commands import ( + Command, + DbExportCommand, + DbImportCommand, + DbTapQueryCommand, + HadoopCommand, + HiveCommand, + JupyterNotebookCommand, + PigCommand, + PrestoCommand, + ShellCommand, + SparkCommand, + SqlCommand, +) +from qds_sdk.qubole import Qubole + +log = logging.getLogger(__name__) + +COMMAND_CLASSES = { + "hivecmd": HiveCommand, + "prestocmd": PrestoCommand, + "hadoopcmd": HadoopCommand, + "shellcmd": ShellCommand, + "pigcmd": PigCommand, + "sparkcmd": SparkCommand, + "dbtapquerycmd": DbTapQueryCommand, + "dbexportcmd": DbExportCommand, + "dbimportcmd": DbImportCommand, + "sqlcmd": SqlCommand, + "jupytercmd": JupyterNotebookCommand, +} + +POSITIONAL_ARGS = { + "hadoopcmd": ["sub_command"], + "shellcmd": ["parameters"], + "pigcmd": ["parameters"], +} + + +def flatten_list(list_of_lists) -> list: + """Flatten the list""" + return [element for array in list_of_lists for element in array] + + +def filter_options(options: list) -> list: + """Remove options from the list""" + options_to_remove = ["help", "print-logs-live", "print-logs", "pool"] + return [option for option in options if option not in options_to_remove] + + +def get_options_list(command_class) -> list: + """Get options list""" + options_list = [ + option.get_opt_string().strip("--") + for option in command_class.optparser.option_list + ] + return filter_options(options_list) + + +def build_command_args() -> Tuple[Dict[str, list], list]: + """Build Command argument from command and options""" + command_args, hyphen_args = {}, set() + for cmd in COMMAND_CLASSES: + + # get all available options from the class + opts_list = get_options_list(COMMAND_CLASSES[cmd]) + + # append positional args if any for the command + if cmd in POSITIONAL_ARGS: + opts_list += POSITIONAL_ARGS[cmd] + + # get args with a hyphen and replace them with underscore + for index, opt in enumerate(opts_list): + if "-" in opt: + opts_list[index] = opt.replace("-", "_") + hyphen_args.add(opts_list[index]) + + command_args[cmd] = opts_list + return command_args, list(hyphen_args) + + +COMMAND_ARGS, HYPHEN_ARGS = build_command_args() 
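# Illustrative aside (not part of the hook): build_command_args() above stores qds-sdk
# option names with hyphens replaced by underscores so they can be passed to
# QuboleOperator as keyword arguments, and create_cmd_args() later in this file
# reverses the mapping when it renders "--option-name=value" strings. A standalone
# sketch of that round trip, using a hypothetical option name:

def to_kwarg_name(option: str) -> str:
    # "sample-size" -> "sample_size", as done while building COMMAND_ARGS / HYPHEN_ARGS
    return option.replace("-", "_")


def to_cli_arg(kwarg: str, value) -> str:
    # "sample_size", 1024 -> "--sample-size=1024", as done in create_cmd_args()
    return f"--{kwarg.replace('_', '-')}={value}"


assert to_kwarg_name("sample-size") == "sample_size"
assert to_cli_arg("sample_size", 1024) == "--sample-size=1024"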
+ + +class QuboleHook(BaseHook): + """Hook for Qubole communication""" + + conn_name_attr = "qubole_conn_id" + default_conn_name = "qubole_default" + conn_type = "qubole" + hook_name = "Qubole" + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["login", "schema", "port", "extra"], + "relabeling": { + "host": "API Endpoint", + "password": "Auth Token", + }, + "placeholders": {"host": "https://.qubole.com/api"}, + } + + def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument + super().__init__() + conn = self.get_connection(kwargs.get("qubole_conn_id", self.default_conn_name)) + Qubole.configure(api_token=conn.password, api_url=conn.host) + self.task_id = kwargs["task_id"] + self.dag_id = kwargs["dag"].dag_id + self.kwargs = kwargs + self.cls = COMMAND_CLASSES[self.kwargs["command_type"]] + self.cmd = None + self.task_instance = None + + @staticmethod + def handle_failure_retry(context) -> None: + """Handle retries in case of failures""" + ti = context["ti"] + cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=ti.task_id) + + if cmd_id is not None: + cmd = Command.find(cmd_id) + if cmd is not None: + if cmd.status == "done": + log.info( + "Command ID: %s has been succeeded, hence marking this TI as Success.", + cmd_id, + ) + ti.state = State.SUCCESS + elif cmd.status == "running": + log.info("Cancelling the Qubole Command Id: %s", cmd_id) + cmd.cancel() + + def execute(self, context) -> None: + """Execute call""" + args = self.cls.parse(self.create_cmd_args(context)) + self.cmd = self.cls.create(**args) + self.task_instance = context["task_instance"] + context["task_instance"].xcom_push(key="qbol_cmd_id", value=self.cmd.id) # type: ignore[attr-defined] + self.log.info( + "Qubole command created with Id: %s and Status: %s", + self.cmd.id, # type: ignore[attr-defined] + self.cmd.status, # type: ignore[attr-defined] + ) + + while not Command.is_done(self.cmd.status): # type: ignore[attr-defined] + time.sleep(Qubole.poll_interval) + self.cmd = self.cls.find(self.cmd.id) # type: ignore[attr-defined] + self.log.info( + "Command Id: %s and Status: %s", self.cmd.id, self.cmd.status # type: ignore[attr-defined] + ) + + if "fetch_logs" in self.kwargs and self.kwargs["fetch_logs"] is True: + self.log.info( + "Logs for Command Id: %s \n%s", self.cmd.id, self.cmd.get_log() # type: ignore[attr-defined] + ) + + if self.cmd.status != "done": # type: ignore[attr-defined] + raise AirflowException( + "Command Id: {} failed with Status: {}".format( + self.cmd.id, self.cmd.status # type: ignore[attr-defined] + ) + ) + + def kill(self, ti): + """ + Kill (cancel) a Qubole command + + :param ti: Task Instance of the dag, used to determine the Quboles command id + :return: response from Qubole + """ + if self.cmd is None: + if not ti and not self.task_instance: + raise Exception( + "Unable to cancel Qubole Command, context is unavailable!" 
+ ) + elif not ti: + ti = self.task_instance + cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=ti.task_id) + self.cmd = self.cls.find(cmd_id) + if self.cls and self.cmd: + self.log.info("Sending KILL signal to Qubole Command Id: %s", self.cmd.id) + self.cmd.cancel() + + def get_results( + self, ti=None, fp=None, inline: bool = True, delim=None, fetch: bool = True + ) -> str: + """ + Get results (or just s3 locations) of a command from Qubole and save into a file + + :param ti: Task Instance of the dag, used to determine the Quboles command id + :param fp: Optional file pointer, will create one and return if None passed + :param inline: True to download actual results, False to get s3 locations only + :param delim: Replaces the CTL-A chars with the given delim, defaults to ',' + :param fetch: when inline is True, get results directly from s3 (if large) + :return: file location containing actual results or s3 locations of results + """ + if fp is None: + iso = datetime.datetime.utcnow().isoformat() + logpath = os.path.expanduser(conf.get("logging", "BASE_LOG_FOLDER")) + resultpath = logpath + "/" + self.dag_id + "/" + self.task_id + "/results" + pathlib.Path(resultpath).mkdir(parents=True, exist_ok=True) + fp = open(resultpath + "/" + iso, "wb") + + if self.cmd is None: + cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id) + self.cmd = self.cls.find(cmd_id) + + self.cmd.get_results(fp, inline, delim, fetch) # type: ignore[attr-defined] + fp.flush() + fp.close() + return fp.name + + def get_log(self, ti) -> None: + """ + Get Logs of a command from Qubole + + :param ti: Task Instance of the dag, used to determine the Quboles command id + :return: command log as text + """ + if self.cmd is None: + cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id) + Command.get_log_id(cmd_id) + + def get_jobs_id(self, ti) -> None: + """ + Get jobs associated with a Qubole commands + + :param ti: Task Instance of the dag, used to determine the Quboles command id + :return: Job information associated with command + """ + if self.cmd is None: + cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id) + Command.get_jobs_id(cmd_id) + + def create_cmd_args(self, context) -> List[str]: + """Creates command arguments""" + args = [] + cmd_type = self.kwargs["command_type"] + inplace_args = None + tags = {self.dag_id, self.task_id, context["run_id"]} + positional_args_list = flatten_list(POSITIONAL_ARGS.values()) + + for key, value in self.kwargs.items(): # pylint: disable=too-many-nested-blocks + if key in COMMAND_ARGS[cmd_type]: + if key in HYPHEN_ARGS: + args.append(f"--{key.replace('_', '-')}={value}") + elif key in positional_args_list: + inplace_args = value + elif key == "tags": + self._add_tags(tags, value) + elif key == "notify": + if value is True: + args.append("--notify") + else: + args.append(f"--{key}={value}") + + args.append(f"--tags={','.join(filter(None, tags))}") + + if inplace_args is not None: + args += inplace_args.split(" ") + + return args + + @staticmethod + def _add_tags(tags, value) -> None: + if isinstance(value, str): + tags.add(value) + elif isinstance(value, (list, tuple)): + tags.update(value) diff --git a/reference/providers/qubole/hooks/qubole_check.py b/reference/providers/qubole/hooks/qubole_check.py new file mode 100644 index 0000000..faf00fb --- /dev/null +++ b/reference/providers/qubole/hooks/qubole_check.py @@ -0,0 +1,125 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import logging +from io import StringIO +from typing import List, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.providers.qubole.hooks.qubole import QuboleHook +from qds_sdk.commands import Command + +log = logging.getLogger(__name__) + +COL_DELIM = "\t" +ROW_DELIM = "\r\n" + + +def isint(value) -> bool: + """Whether Qubole column are integer""" + try: + int(value) + return True + except ValueError: + return False + + +def isfloat(value) -> bool: + """Whether Qubole column are float""" + try: + float(value) + return True + except ValueError: + return False + + +def isbool(value) -> bool: + """Whether Qubole column are boolean""" + try: + return value.lower() in ["true", "false"] + except ValueError: + return False + + +def parse_first_row(row_list) -> List[Union[bool, float, int, str]]: + """Parse Qubole first record list""" + record_list = [] + first_row = row_list[0] if row_list else "" + + for col_value in first_row.split(COL_DELIM): + if isint(col_value): + col_value = int(col_value) + elif isfloat(col_value): + col_value = float(col_value) + elif isbool(col_value): + col_value = col_value.lower() == "true" + record_list.append(col_value) + + return record_list + + +class QuboleCheckHook(QuboleHook): + """Qubole check hook""" + + def __init__(self, context, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.results_parser_callable = parse_first_row + if ( + "results_parser_callable" in kwargs + and kwargs["results_parser_callable"] is not None + ): + if not callable(kwargs["results_parser_callable"]): + raise AirflowException( + "`results_parser_callable` param must be callable" + ) + self.results_parser_callable = kwargs["results_parser_callable"] + self.context = context + + @staticmethod + def handle_failure_retry(context) -> None: + ti = context["ti"] + cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=ti.task_id) + + if cmd_id is not None: + cmd = Command.find(cmd_id) + if cmd is not None: + if cmd.status == "running": + log.info("Cancelling the Qubole Command Id: %s", cmd_id) + cmd.cancel() + + def get_first(self, sql): # pylint: disable=unused-argument + """Get Qubole query first record list""" + self.execute(context=self.context) + query_result = self.get_query_results() + row_list = list(filter(None, query_result.split(ROW_DELIM))) + record_list = self.results_parser_callable(row_list) + return record_list + + def get_query_results(self) -> Optional[str]: + """Get Qubole query result""" + if self.cmd is not None: + cmd_id = self.cmd.id + self.log.info("command id: %d", cmd_id) + query_result_buffer = StringIO() + self.cmd.get_results(fp=query_result_buffer, inline=True, delim=COL_DELIM) + query_result = query_result_buffer.getvalue() + query_result_buffer.close() + return query_result + else: + self.log.error("Qubole command not found") + 
return None diff --git a/reference/providers/qubole/operators/__init__.py b/reference/providers/qubole/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/qubole/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/qubole/operators/qubole.py b/reference/providers/qubole/operators/qubole.py new file mode 100644 index 0000000..ee5eb7d --- /dev/null +++ b/reference/providers/qubole/operators/qubole.py @@ -0,0 +1,292 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Qubole operator""" +import re +from datetime import datetime +from typing import Iterable, Optional + +from airflow.hooks.base import BaseHook +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.models.taskinstance import TaskInstance +from airflow.providers.qubole.hooks.qubole import ( + COMMAND_ARGS, + HYPHEN_ARGS, + POSITIONAL_ARGS, + QuboleHook, + flatten_list, +) +from airflow.utils.decorators import apply_defaults + + +class QDSLink(BaseOperatorLink): + """Link to QDS""" + + name = "Go to QDS" + + def get_link(self, operator: BaseOperator, dttm: datetime) -> str: + """ + Get link to qubole command result page. + + :param operator: operator + :param dttm: datetime + :return: url link + """ + ti = TaskInstance(task=operator, execution_date=dttm) + conn = BaseHook.get_connection( + getattr(operator, "qubole_conn_id", None) + or operator.kwargs["qubole_conn_id"] # type: ignore[attr-defined] + ) + if conn and conn.host: + host = re.sub(r"api$", "v2/analyze?command_id=", conn.host) + else: + host = "https://api.qubole.com/v2/analyze?command_id=" + qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key="qbol_cmd_id") + url = host + str(qds_command_id) if qds_command_id else "" + return url + + +class QuboleOperator(BaseOperator): + """ + Execute tasks (commands) on QDS (https://qubole.com). 
+ + :param qubole_conn_id: Connection id which consists of qds auth_token + :type qubole_conn_id: str + + kwargs: + :command_type: type of command to be executed, e.g. hivecmd, shellcmd, hadoopcmd + :tags: array of tags to be assigned with the command + :cluster_label: cluster label on which the command will be executed + :name: name to be given to command + :notify: whether to send email on command completion or not (default is False) + + **Arguments specific to command types** + + hivecmd: + :query: inline query statement + :script_location: s3 location containing query statement + :sample_size: size of sample in bytes on which to run query + :macros: macro values which were used in query + :sample_size: size of sample in bytes on which to run query + :hive-version: Specifies the hive version to be used. eg: 0.13,1.2,etc. + prestocmd: + :query: inline query statement + :script_location: s3 location containing query statement + :macros: macro values which were used in query + hadoopcmd: + :sub_commnad: must be one these ["jar", "s3distcp", "streaming"] followed by + 1 or more args + shellcmd: + :script: inline command with args + :script_location: s3 location containing query statement + :files: list of files in s3 bucket as file1,file2 format. These files will be + copied into the working directory where the qubole command is being + executed. + :archives: list of archives in s3 bucket as archive1,archive2 format. These + will be unarchived into the working directory where the qubole command is + being executed + :parameters: any extra args which need to be passed to script (only when + script_location is supplied) + pigcmd: + :script: inline query statement (latin_statements) + :script_location: s3 location containing pig query + :parameters: any extra args which need to be passed to script (only when + script_location is supplied + sparkcmd: + :program: the complete Spark Program in Scala, R, or Python + :cmdline: spark-submit command line, all required arguments must be specify + in cmdline itself. + :sql: inline sql query + :script_location: s3 location containing query statement + :language: language of the program, Scala, R, or Python + :app_id: ID of an Spark job server app + :arguments: spark-submit command line arguments. + If `cmdline` is selected, this should not be used because all + required arguments and configurations are to be passed in the `cmdline` itself. + :user_program_arguments: arguments that the user program takes in + :macros: macro values which were used in query + :note_id: Id of the Notebook to run + dbtapquerycmd: + :db_tap_id: data store ID of the target database, in Qubole. + :query: inline query statement + :macros: macro values which were used in query + dbexportcmd: + :mode: Can be 1 for Hive export or 2 for HDFS/S3 export + :schema: Db schema name assumed accordingly by database if not specified + :hive_table: Name of the hive table + :partition_spec: partition specification for Hive table. + :dbtap_id: data store ID of the target database, in Qubole. + :db_table: name of the db table + :db_update_mode: allowinsert or updateonly + :db_update_keys: columns used to determine the uniqueness of rows + :export_dir: HDFS/S3 location from which data will be exported. 
+ :fields_terminated_by: hex of the char used as column separator in the dataset + :use_customer_cluster: To use cluster to run command + :customer_cluster_label: the label of the cluster to run the command on + :additional_options: Additional Sqoop options which are needed enclose options in + double or single quotes e.g. '--map-column-hive id=int,data=string' + dbimportcmd: + :mode: 1 (simple), 2 (advance) + :hive_table: Name of the hive table + :schema: Db schema name assumed accordingly by database if not specified + :hive_serde: Output format of the Hive Table + :dbtap_id: data store ID of the target database, in Qubole. + :db_table: name of the db table + :where_clause: where clause, if any + :parallelism: number of parallel db connections to use for extracting data + :extract_query: SQL query to extract data from db. $CONDITIONS must be part + of the where clause. + :boundary_query: Query to be used get range of row IDs to be extracted + :split_column: Column used as row ID to split data into ranges (mode 2) + :use_customer_cluster: To use cluster to run command + :customer_cluster_label: the label of the cluster to run the command on + :additional_options: Additional Sqoop options which are needed enclose options in + double or single quotes + jupytercmd: + :path: Path including name of the Jupyter notebook to be run with extension. + :arguments: Valid JSON to be sent to the notebook. Specify the parameters in notebooks and pass + the parameter value using the JSON format. key is the parameter’s name and value is + the parameter’s value. Supported types in parameters are string, integer, float and boolean. + + .. note: + + Following fields are template-supported : ``query``, ``script_location``, + ``sub_command``, ``script``, ``files``, ``archives``, ``program``, ``cmdline``, + ``sql``, ``where_clause``, ``extract_query``, ``boundary_query``, ``macros``, + ``tags``, ``name``, ``parameters``, ``dbtap_id``, ``hive_table``, ``db_table``, + ``split_column``, ``note_id``, ``db_update_keys``, ``export_dir``, + ``partition_spec``, ``qubole_conn_id``, ``arguments``, ``user_program_arguments``. + You can also use ``.txt`` files for template driven use cases. + + .. note: + + In QuboleOperator there is a default handler for task failures and retries, + which generally kills the command running at QDS for the corresponding task + instance. You can override this behavior by providing your own failure and retry + handler in task definition. 
+ """ + + template_fields: Iterable[str] = ( + "query", + "script_location", + "sub_command", + "script", + "files", + "archives", + "program", + "cmdline", + "sql", + "where_clause", + "tags", + "extract_query", + "boundary_query", + "macros", + "name", + "parameters", + "dbtap_id", + "hive_table", + "db_table", + "split_column", + "note_id", + "db_update_keys", + "export_dir", + "partition_spec", + "qubole_conn_id", + "arguments", + "user_program_arguments", + "cluster_label", + ) + + template_ext: Iterable[str] = (".txt",) + ui_color = "#3064A1" + ui_fgcolor = "#fff" + qubole_hook_allowed_args_list = ["command_type", "qubole_conn_id", "fetch_logs"] + + operator_extra_links = (QDSLink(),) + + @apply_defaults + def __init__(self, *, qubole_conn_id: str = "qubole_default", **kwargs) -> None: + self.kwargs = kwargs + self.kwargs["qubole_conn_id"] = qubole_conn_id + self.hook: Optional[QuboleHook] = None + filtered_base_kwargs = self._get_filtered_args(kwargs) + super().__init__(**filtered_base_kwargs) + + if self.on_failure_callback is None: + self.on_failure_callback = QuboleHook.handle_failure_retry + + if self.on_retry_callback is None: + self.on_retry_callback = QuboleHook.handle_failure_retry + + def _get_filtered_args(self, all_kwargs) -> dict: + qubole_args = ( + flatten_list(COMMAND_ARGS.values()) + + HYPHEN_ARGS + + flatten_list(POSITIONAL_ARGS.values()) + + self.qubole_hook_allowed_args_list + ) + return { + key: value for key, value in all_kwargs.items() if key not in qubole_args + } + + def execute(self, context) -> None: + return self.get_hook().execute(context) + + def on_kill(self, ti=None) -> None: + if self.hook: + self.hook.kill(ti) + else: + self.get_hook().kill(ti) + + def get_results( + self, ti=None, fp=None, inline: bool = True, delim=None, fetch: bool = True + ) -> str: + """get_results from Qubole""" + return self.get_hook().get_results(ti, fp, inline, delim, fetch) + + def get_log(self, ti) -> None: + """get_log from Qubole""" + return self.get_hook().get_log(ti) + + def get_jobs_id(self, ti) -> None: + """Get jobs_id from Qubole""" + return self.get_hook().get_jobs_id(ti) + + def get_hook(self) -> QuboleHook: + """Reinitialising the hook, as some template fields might have changed""" + return QuboleHook(**self.kwargs) + + def __getattribute__(self, name: str) -> str: + if name in _get_template_fields(self): + if name in self.kwargs: + return self.kwargs[name] + else: + return "" + else: + return object.__getattribute__(self, name) + + def __setattr__(self, name: str, value: str) -> None: + if name in _get_template_fields(self): + self.kwargs[name] = value + else: + object.__setattr__(self, name, value) + + +def _get_template_fields(obj: BaseOperator) -> dict: + class_ = object.__getattribute__(obj, "__class__") + template_fields = object.__getattribute__(class_, "template_fields") + return template_fields diff --git a/reference/providers/qubole/operators/qubole_check.py b/reference/providers/qubole/operators/qubole_check.py new file mode 100644 index 0000000..9fad1b6 --- /dev/null +++ b/reference/providers/qubole/operators/qubole_check.py @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Callable, Iterable, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.operators.sql import SQLCheckOperator, SQLValueCheckOperator +from airflow.providers.qubole.hooks.qubole_check import QuboleCheckHook +from airflow.providers.qubole.operators.qubole import QuboleOperator +from airflow.utils.decorators import apply_defaults + + +class _QuboleCheckOperatorMixin: + """This is a Mixin for Qubole related check operators""" + + def execute(self, context=None) -> None: + """Execute a check operation against Qubole""" + try: + self._hook_context = context + super().execute(context=context) + except AirflowException as e: + handle_airflow_exception(e, self.get_hook()) + + def get_db_hook(self) -> QuboleCheckHook: + """Get QuboleCheckHook""" + return self.get_hook() + + # this overwrite the original QuboleOperator.get_hook() which returns a QuboleHook. + def get_hook(self) -> QuboleCheckHook: + """Reinitialising the hook, as some template fields might have changed""" + return QuboleCheckHook( + context=self._hook_context, + results_parser_callable=self.results_parser_callable, + **self.kwargs, + ) + + +# pylint: disable=too-many-ancestors +class QuboleCheckOperator(_QuboleCheckOperatorMixin, SQLCheckOperator, QuboleOperator): + """ + Performs checks against Qubole Commands. ``QuboleCheckOperator`` expects + a command that will be executed on QDS. + By default, each value on first row of the result of this Qubole Command + is evaluated using python ``bool`` casting. If any of the + values return ``False``, the check is failed and errors out. + + Note that Python bool casting evals the following as ``False``: + + * ``False`` + * ``0`` + * Empty string (``""``) + * Empty list (``[]``) + * Empty dictionary or set (``{}``) + + Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if + the count ``== 0``. You can craft much more complex query that could, + for instance, check that the table has the same number of rows as + the source table upstream, or that the count of today's partition is + greater than yesterday's partition, or that a set of metrics are less + than 3 standard deviation for the 7 day average. + + This operator can be used as a data quality check in your pipeline, and + depending on where you put it in your DAG, you have the choice to + stop the critical path, preventing from + publishing dubious data, or on the side and receive email alerts + without stopping the progress of the DAG. + + :param qubole_conn_id: Connection id which consists of qds auth_token + :type qubole_conn_id: str + + kwargs: + + Arguments specific to Qubole command can be referred from QuboleOperator docs. + + :results_parser_callable: This is an optional parameter to + extend the flexibility of parsing the results of Qubole + command to the users. This is a python callable which + can hold the logic to parse list of rows returned by Qubole command. + By default, only the values on first row are used for performing checks. + This callable should return a list of records on + which the checks have to be performed. + + .. 
note:: All fields in common with template fields of + QuboleOperator and SQLCheckOperator are template-supported. + + """ + + template_fields: Iterable[str] = set(QuboleOperator.template_fields) | set( + SQLCheckOperator.template_fields + ) + template_ext = QuboleOperator.template_ext + ui_fgcolor = "#000" + + @apply_defaults + def __init__( + self, + *, + qubole_conn_id: str = "qubole_default", + results_parser_callable: Callable = None, + **kwargs, + ) -> None: + sql = get_sql_from_qbol_cmd(kwargs) + kwargs.pop("sql", None) + super().__init__(qubole_conn_id=qubole_conn_id, sql=sql, **kwargs) + self.results_parser_callable = results_parser_callable + self.on_failure_callback = QuboleCheckHook.handle_failure_retry + self.on_retry_callback = QuboleCheckHook.handle_failure_retry + self._hook_context = None + + +# TODO(xinbinhuang): refactor to reduce levels of inheritance +# pylint: disable=too-many-ancestors +class QuboleValueCheckOperator( + _QuboleCheckOperatorMixin, SQLValueCheckOperator, QuboleOperator +): + """ + Performs a simple value check using Qubole command. + By default, each value on the first row of this + Qubole command is compared with a pre-defined value. + The check fails and errors out if the output of the command + is not within the permissible limit of expected value. + + :param qubole_conn_id: Connection id which consists of qds auth_token + :type qubole_conn_id: str + + :param pass_value: Expected value of the query results. + :type pass_value: str or int or float + + :param tolerance: Defines the permissible pass_value range, for example if + tolerance is 2, the Qubole command output can be anything between + -2*pass_value and 2*pass_value, without the operator erring out. + + :type tolerance: int or float + + + kwargs: + + Arguments specific to Qubole command can be referred from QuboleOperator docs. + + :results_parser_callable: This is an optional parameter to + extend the flexibility of parsing the results of Qubole + command to the users. This is a python callable which + can hold the logic to parse list of rows returned by Qubole command. + By default, only the values on first row are used for performing checks. + This callable should return a list of records on + which the checks have to be performed. + + + .. note:: All fields in common with template fields of + QuboleOperator and SQLValueCheckOperator are template-supported. 
+ """ + + template_fields = set(QuboleOperator.template_fields) | set( + SQLValueCheckOperator.template_fields + ) + template_ext = QuboleOperator.template_ext + ui_fgcolor = "#000" + + @apply_defaults + def __init__( + self, + *, + pass_value: Union[str, int, float], + tolerance: Optional[Union[int, float]] = None, + results_parser_callable: Callable = None, + qubole_conn_id: str = "qubole_default", + **kwargs, + ) -> None: + sql = get_sql_from_qbol_cmd(kwargs) + kwargs.pop("sql", None) + super().__init__( + qubole_conn_id=qubole_conn_id, + sql=sql, + pass_value=pass_value, + tolerance=tolerance, + **kwargs, + ) + self.results_parser_callable = results_parser_callable + self.on_failure_callback = QuboleCheckHook.handle_failure_retry + self.on_retry_callback = QuboleCheckHook.handle_failure_retry + self._hook_context = None + + +def get_sql_from_qbol_cmd(params) -> str: + """Get Qubole sql from Qubole command""" + sql = "" + if "query" in params: + sql = params["query"] + elif "sql" in params: + sql = params["sql"] + return sql + + +def handle_airflow_exception(airflow_exception, hook: QuboleCheckHook): + """Qubole check handle Airflow exception""" + cmd = hook.cmd + if cmd is not None: + if cmd.is_success(cmd.status): + qubole_command_results = hook.get_query_results() + qubole_command_id = cmd.id + exception_message = ( + "\nQubole Command Id: {qubole_command_id}" + "\nQubole Command Results:" + "\n{qubole_command_results}".format( + qubole_command_id=qubole_command_id, + qubole_command_results=qubole_command_results, + ) + ) + raise AirflowException(str(airflow_exception) + exception_message) + raise AirflowException(str(airflow_exception)) diff --git a/reference/providers/qubole/provider.yaml b/reference/providers/qubole/provider.yaml new file mode 100644 index 0000000..bd4cd58 --- /dev/null +++ b/reference/providers/qubole/provider.yaml @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
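A hedged usage sketch of the two check operators defined above. The DAG id, queries, pass_value, and tolerance are placeholders; "qubole_default" is the default connection id both operators declare.

# Hypothetical DAG using the Qubole check operators above (not part of the diff).
from datetime import datetime

from airflow import DAG
from airflow.providers.qubole.operators.qubole_check import (
    QuboleCheckOperator,
    QuboleValueCheckOperator,
)

with DAG(
    dag_id="example_qubole_checks",            # placeholder DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    # Errors out if any value in the first result row is falsy (0, "", [], {}).
    row_count_check = QuboleCheckOperator(
        task_id="row_count_check",
        command_type="hivecmd",
        query="SELECT COUNT(*) FROM my_partitioned_table",
    )

    # Errors out if the first value falls outside the band defined by
    # pass_value and tolerance.
    value_check = QuboleValueCheckOperator(
        task_id="value_check",
        command_type="prestocmd",
        query="SELECT COUNT(*) FROM my_partitioned_table",
        pass_value=100,
        tolerance=0.1,
    )

    row_count_check >> value_check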
+ +--- +package-name: apache-airflow-providers-qubole +name: Qubole +description: | + `Qubole `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Qubole + external-doc-url: https://www.qubole.com/ + logo: /integration-logos/qubole/Qubole.png + tags: [service] + +operators: + - integration-name: Qubole + python-modules: + - airflow.providers.qubole.operators.qubole + - airflow.providers.qubole.operators.qubole_check + +sensors: + - integration-name: Qubole + python-modules: + - airflow.providers.qubole.sensors.qubole + +hooks: + - integration-name: Qubole + python-modules: + - airflow.providers.qubole.hooks.qubole + - airflow.providers.qubole.hooks.qubole_check + +hook-class-names: + - airflow.providers.qubole.hooks.qubole.QuboleHook + +extra-links: + - airflow.providers.qubole.operators.qubole.QDSLink diff --git a/reference/providers/qubole/sensors/__init__.py b/reference/providers/qubole/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/qubole/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/qubole/sensors/qubole.py b/reference/providers/qubole/sensors/qubole.py new file mode 100644 index 0000000..9624527 --- /dev/null +++ b/reference/providers/qubole/sensors/qubole.py @@ -0,0 +1,113 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
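One aside before the sensor code: the extra-links entry above registers QDSLink, whose get_link (shown earlier in operators/qubole.py) rewrites a connection host ending in "api" into the v2 analyze page and appends the command id pulled from XCom. A standalone sketch of just that URL rewriting, with a hypothetical host:

# Standalone illustration of the URL rewriting performed by QDSLink.get_link()
# (not the operator link class itself).
import re


def qds_result_url(conn_host: str, command_id: int) -> str:
    if conn_host:
        host = re.sub(r"api$", "v2/analyze?command_id=", conn_host)
    else:
        host = "https://api.qubole.com/v2/analyze?command_id="
    return host + str(command_id)


# With a hypothetical endpoint:
# qds_result_url("https://us.qubole.com/api", 12345)
#   -> "https://us.qubole.com/v2/analyze?command_id=12345"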
+ +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from qds_sdk.qubole import Qubole +from qds_sdk.sensors import FileSensor, PartitionSensor + + +class QuboleSensor(BaseSensorOperator): + """Base class for all Qubole Sensors""" + + template_fields = ("data", "qubole_conn_id") + + template_ext = (".txt",) + + @apply_defaults + def __init__( + self, *, data, qubole_conn_id: str = "qubole_default", **kwargs + ) -> None: + self.data = data + self.qubole_conn_id = qubole_conn_id + + if "poke_interval" in kwargs and kwargs["poke_interval"] < 5: + raise AirflowException( + "Sorry, poke_interval can't be less than 5 sec for " + "task '{}' in dag '{}'.".format(kwargs["task_id"], kwargs["dag"].dag_id) + ) + + super().__init__(**kwargs) + + def poke(self, context: dict) -> bool: + + conn = BaseHook.get_connection(self.qubole_conn_id) + Qubole.configure(api_token=conn.password, api_url=conn.host) + + self.log.info("Poking: %s", self.data) + + status = False + try: + status = self.sensor_class.check( # type: ignore[attr-defined] # pylint: disable=no-member + self.data + ) + except Exception as e: # pylint: disable=broad-except + self.log.exception(e) + status = False + + self.log.info("Status of this Poke: %s", status) + + return status + + +class QuboleFileSensor(QuboleSensor): + """ + Wait for a file or folder to be present in cloud storage + and check for its presence via QDS APIs + + :param qubole_conn_id: Connection id which consists of qds auth_token + :type qubole_conn_id: str + :param data: a JSON object containing payload, whose presence needs to be checked + Check this `example `_ for sample payload + structure. + :type data: dict + + .. note:: Both ``data`` and ``qubole_conn_id`` fields support templating. You can + also use ``.txt`` files for template-driven use cases. + """ + + @apply_defaults + def __init__(self, **kwargs) -> None: + self.sensor_class = FileSensor + super().__init__(**kwargs) + + +class QubolePartitionSensor(QuboleSensor): + """ + Wait for a Hive partition to show up in QHS (Qubole Hive Service) + and check for its presence via QDS APIs + + :param qubole_conn_id: Connection id which consists of qds auth_token + :type qubole_conn_id: str + :param data: a JSON object containing payload, whose presence needs to be checked. + Check this `example `_ for sample payload + structure. + :type data: dict + + .. note:: Both ``data`` and ``qubole_conn_id`` fields support templating. You can + also use ``.txt`` files for template-driven use cases. + """ + + @apply_defaults + def __init__(self, **kwargs) -> None: + self.sensor_class = PartitionSensor + super().__init__(**kwargs) diff --git a/reference/providers/redis/CHANGELOG.rst b/reference/providers/redis/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/redis/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
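
A hedged sketch of the two sensors defined in sensors/qubole.py above. The payload keys used for ``data`` (``files`` for the file sensor; ``schema``/``table``/``columns`` for the partition sensor) follow the QDS sensor API and are illustrative, as are the bucket, table, and connection names. Note that ``poke_interval`` must be at least 5 seconds, as enforced in ``QuboleSensor.__init__``.

from datetime import datetime

from airflow import DAG
from airflow.providers.qubole.sensors.qubole import QuboleFileSensor, QubolePartitionSensor

with DAG(
    dag_id="qubole_sensor_example",           # hypothetical DAG
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Wait for a file to land in cloud storage, checked through QDS.
    wait_for_file = QuboleFileSensor(
        task_id="wait_for_file",
        qubole_conn_id="qubole_default",
        poke_interval=60,
        data={"files": ["s3://my-bucket/exports/{{ ds }}/part-00000"]},
    )

    # Wait for a Hive partition to appear in the Qubole Hive metastore.
    wait_for_partition = QubolePartitionSensor(
        task_id="wait_for_partition",
        qubole_conn_id="qubole_default",
        poke_interval=60,
        data={
            "schema": "default",
            "table": "events",
            "columns": [{"column": "ds", "values": ["{{ ds }}"]}],
        },
    )

    wait_for_file >> wait_for_partition
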
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/redis/__init__.py b/reference/providers/redis/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/redis/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/redis/hooks/__init__.py b/reference/providers/redis/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/redis/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/redis/hooks/redis.py b/reference/providers/redis/hooks/redis.py new file mode 100644 index 0000000..ecbe947 --- /dev/null +++ b/reference/providers/redis/hooks/redis.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""RedisHook module""" +from airflow.hooks.base import BaseHook +from redis import Redis + + +class RedisHook(BaseHook): + """ + Wrapper for connection to interact with Redis in-memory data structure store + + You can set your db in the extra field of your connection as ``{"db": 3}``. + Also you can set ssl parameters as: + ``{"ssl": true, "ssl_cert_reqs": "require", "ssl_cert_file": "/path/to/cert.pem", etc}``. + """ + + conn_name_attr = "redis_conn_id" + default_conn_name = "redis_default" + conn_type = "redis" + hook_name = "Redis" + + def __init__(self, redis_conn_id: str = default_conn_name) -> None: + """ + Prepares hook to connect to a Redis database. + + :param conn_id: the name of the connection that has the parameters + we need to connect to Redis. + """ + super().__init__() + self.redis_conn_id = redis_conn_id + self.redis = None + self.host = None + self.port = None + self.password = None + self.db = None + + def get_conn(self): + """Returns a Redis connection.""" + conn = self.get_connection(self.redis_conn_id) + self.host = conn.host + self.port = conn.port + self.password = ( + None + if str(conn.password).lower() in ["none", "false", ""] + else conn.password + ) + self.db = conn.extra_dejson.get("db") + + # check for ssl parameters in conn.extra + ssl_arg_names = [ + "ssl", + "ssl_cert_reqs", + "ssl_ca_certs", + "ssl_keyfile", + "ssl_cert_file", + "ssl_check_hostname", + ] + ssl_args = { + name: val + for name, val in conn.extra_dejson.items() + if name in ssl_arg_names + } + + if not self.redis: + self.log.debug( + 'Initializing redis object for conn_id "%s" on %s:%s:%s', + self.redis_conn_id, + self.host, + self.port, + self.db, + ) + self.redis = Redis( + host=self.host, + port=self.port, + password=self.password, + db=self.db, + **ssl_args + ) + + return self.redis diff --git a/reference/providers/redis/operators/__init__.py b/reference/providers/redis/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/redis/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/redis/operators/redis_publish.py b/reference/providers/redis/operators/redis_publish.py new file mode 100644 index 0000000..e7e2094 --- /dev/null +++ b/reference/providers/redis/operators/redis_publish.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict + +from airflow.models import BaseOperator +from airflow.providers.redis.hooks.redis import RedisHook +from airflow.utils.decorators import apply_defaults + + +class RedisPublishOperator(BaseOperator): + """ + Publish a message to Redis. + + :param channel: redis channel to which the message is published (templated) + :type channel: str + :param message: the message to publish (templated) + :type message: str + :param redis_conn_id: redis connection to use + :type redis_conn_id: str + """ + + template_fields = ("channel", "message") + + @apply_defaults + def __init__( + self, + *, + channel: str, + message: str, + redis_conn_id: str = "redis_default", + **kwargs + ) -> None: + + super().__init__(**kwargs) + self.redis_conn_id = redis_conn_id + self.channel = channel + self.message = message + + def execute(self, context: Dict) -> None: + """ + Publish the message to Redis channel + + :param context: the context object + :type context: dict + """ + redis_hook = RedisHook(redis_conn_id=self.redis_conn_id) + + self.log.info( + "Sending message %s to Redis on channel %s", self.message, self.channel + ) + + result = redis_hook.get_conn().publish( + channel=self.channel, message=self.message + ) + + self.log.info("Result of publishing %s", result) diff --git a/reference/providers/redis/provider.yaml b/reference/providers/redis/provider.yaml new file mode 100644 index 0000000..e2d60c3 --- /dev/null +++ b/reference/providers/redis/provider.yaml @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
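
A hedged usage sketch for the RedisPublishOperator defined above; the DAG id, channel, and message are illustrative. Both ``channel`` and ``message`` are templated, per ``template_fields``.

from datetime import datetime

from airflow import DAG
from airflow.providers.redis.operators.redis_publish import RedisPublishOperator

with DAG(
    dag_id="redis_publish_example",           # hypothetical DAG
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    publish = RedisPublishOperator(
        task_id="publish_run_marker",
        redis_conn_id="redis_default",
        channel="airflow-events",
        message="run {{ ds }} finished",
    )
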
+ +--- +package-name: apache-airflow-providers-redis +name: Redis +description: | + `Redis `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Redis + external-doc-url: https://redis.io/ + logo: /integration-logos/redis/Redis.png + tags: [software] + +operators: + - integration-name: Redis + python-modules: + - airflow.providers.redis.operators.redis_publish + +sensors: + - integration-name: Redis + python-modules: + - airflow.providers.redis.sensors.redis_key + - airflow.providers.redis.sensors.redis_pub_sub + +hooks: + - integration-name: Redis + python-modules: + - airflow.providers.redis.hooks.redis + +hook-class-names: + - airflow.providers.redis.hooks.redis.RedisHook diff --git a/reference/providers/redis/sensors/__init__.py b/reference/providers/redis/sensors/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/redis/sensors/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/redis/sensors/redis_key.py b/reference/providers/redis/sensors/redis_key.py new file mode 100644 index 0000000..1158ad6 --- /dev/null +++ b/reference/providers/redis/sensors/redis_key.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
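
The provider.yaml above registers two Redis sensors, redis_key and redis_pub_sub, which are implemented in the modules that follow. As a forward-looking sketch (key name, connection id, and timings are assumptions), the key sensor can be used like this:

from datetime import datetime

from airflow import DAG
from airflow.providers.redis.sensors.redis_key import RedisKeySensor

with DAG(
    dag_id="redis_key_sensor_example",        # hypothetical DAG
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Succeeds as soon as the (templated) key exists in Redis.
    wait_for_flag = RedisKeySensor(
        task_id="wait_for_flag",
        redis_conn_id="redis_default",
        key="batch-ready-{{ ds }}",
        poke_interval=30,
        timeout=60 * 60,
    )
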
+from typing import Dict + +from airflow.providers.redis.hooks.redis import RedisHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class RedisKeySensor(BaseSensorOperator): + """Checks for the existence of a key in a Redis""" + + template_fields = ("key",) + ui_color = "#f0eee4" + + @apply_defaults + def __init__(self, *, key: str, redis_conn_id: str, **kwargs) -> None: + super().__init__(**kwargs) + self.redis_conn_id = redis_conn_id + self.key = key + + def poke(self, context: Dict) -> bool: + self.log.info("Sensor checks for existence of key: %s", self.key) + return RedisHook(self.redis_conn_id).get_conn().exists(self.key) diff --git a/reference/providers/redis/sensors/redis_pub_sub.py b/reference/providers/redis/sensors/redis_pub_sub.py new file mode 100644 index 0000000..9d8438e --- /dev/null +++ b/reference/providers/redis/sensors/redis_pub_sub.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, List, Union + +from airflow.providers.redis.hooks.redis import RedisHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class RedisPubSubSensor(BaseSensorOperator): + """ + Redis sensor for reading a message from pub sub channels + + :param channels: The channels to be subscribed to (templated) + :type channels: str or list of str + :param redis_conn_id: the redis connection id + :type redis_conn_id: str + """ + + template_fields = ("channels",) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, *, channels: Union[List[str], str], redis_conn_id: str, **kwargs + ) -> None: + super().__init__(**kwargs) + self.channels = channels + self.redis_conn_id = redis_conn_id + self.pubsub = RedisHook(redis_conn_id=self.redis_conn_id).get_conn().pubsub() + self.pubsub.subscribe(self.channels) + + def poke(self, context: Dict) -> bool: + """ + Check for message on subscribed channels and write to xcom the message with key ``message`` + + An example of message ``{'type': 'message', 'pattern': None, 'channel': b'test', 'data': b'hello'}`` + + :param context: the context object + :type context: dict + :return: ``True`` if message (with type 'message') is available or ``False`` if not + """ + self.log.info( + "RedisPubSubSensor checking for message on channels: %s", self.channels + ) + + message = self.pubsub.get_message() + self.log.info("Message %s from channel %s", message, self.channels) + + # Process only message types + if message and message["type"] == "message": + + context["ti"].xcom_push(key="message", value=message) + self.pubsub.unsubscribe(self.channels) + + return True + + return False diff --git a/reference/providers/salesforce/CHANGELOG.rst 
b/reference/providers/salesforce/CHANGELOG.rst new file mode 100644 index 0000000..6680d0a --- /dev/null +++ b/reference/providers/salesforce/CHANGELOG.rst @@ -0,0 +1,47 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +2.0.0 +..... + +Tableau provider moved to separate 'tableau' provider + +Things done: + + - Tableau classes imports classes from 'tableau' provider with deprecation warning + +Breaking changes +~~~~~~~~~~~~~~~~ + +You need to install ``apache-airflow-providers-tableau`` provider additionally to get +Tableau integration working. + + +1.0.1 +..... + +Updated documentation and readme files. + + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/salesforce/__init__.py b/reference/providers/salesforce/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/salesforce/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/salesforce/example_dags/__init__.py b/reference/providers/salesforce/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/salesforce/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/reference/providers/salesforce/hooks/__init__.py b/reference/providers/salesforce/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/salesforce/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/salesforce/hooks/salesforce.py b/reference/providers/salesforce/hooks/salesforce.py new file mode 100644 index 0000000..8346cda --- /dev/null +++ b/reference/providers/salesforce/hooks/salesforce.py @@ -0,0 +1,350 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +This module contains a Salesforce Hook which allows you to connect to your Salesforce instance, +retrieve data from it, and write that data to a file for other uses. + +.. note:: this hook also relies on the simple_salesforce package: + https://github.com/simple-salesforce/simple-salesforce +""" +import logging +import time +from typing import Iterable, List, Optional + +import pandas as pd +from airflow.hooks.base import BaseHook +from simple_salesforce import Salesforce, api + +log = logging.getLogger(__name__) + + +class SalesforceHook(BaseHook): + """ + Create new connection to Salesforce and allows you to pull data out of SFDC and save it to a file. + + You can then use that file with other Airflow operators to move the data into another data source. + + :param conn_id: the name of the connection that has the parameters we need to connect to Salesforce. + The connection should be type `http` and include a user's security token in the `Extras` field. + :type conn_id: str + + .. note:: + For the HTTP connection type, you can include a + JSON structure in the `Extras` field. + We need a user's security token to connect to Salesforce. 
+ So we define it in the `Extras` field as `{"security_token":"YOUR_SECURITY_TOKEN"}` + + For sandbox mode, add `{"domain":"test"}` in the `Extras` field + + """ + + def __init__(self, conn_id: str) -> None: + super().__init__() + self.conn_id = conn_id + self.conn = None + + def get_conn(self) -> api.Salesforce: + """Sign into Salesforce, only if we are not already signed in.""" + if not self.conn: + connection = self.get_connection(self.conn_id) + extras = connection.extra_dejson + self.conn = Salesforce( + username=connection.login, + password=connection.password, + security_token=extras["security_token"], + instance_url=connection.host, + domain=extras.get("domain"), + ) + return self.conn + + def make_query( + self, + query: str, + include_deleted: bool = False, + query_params: Optional[dict] = None, + ) -> dict: + """ + Make a query to Salesforce. + + :param query: The query to make to Salesforce. + :type query: str + :param include_deleted: True if the query should include deleted records. + :type include_deleted: bool + :param query_params: Additional optional arguments + :type query_params: dict + :return: The query result. + :rtype: dict + """ + conn = self.get_conn() + + self.log.info("Querying for all objects") + query_params = query_params or {} + query_results = conn.query_all( + query, include_deleted=include_deleted, **query_params + ) + + self.log.info( + "Received results: Total size: %s; Done: %s", + query_results["totalSize"], + query_results["done"], + ) + + return query_results + + def describe_object(self, obj: str) -> dict: + """ + Get the description of an object from Salesforce. + This description is the object's schema and + some extra metadata that Salesforce stores for each object. + + :param obj: The name of the Salesforce object that we are getting a description of. + :type obj: str + :return: the description of the Salesforce object. + :rtype: dict + """ + conn = self.get_conn() + + return conn.__getattr__(obj).describe() + + def get_available_fields(self, obj: str) -> List[str]: + """ + Get a list of all available fields for an object. + + :param obj: The name of the Salesforce object that we are getting a description of. + :type obj: str + :return: the names of the fields. + :rtype: list(str) + """ + self.get_conn() + + obj_description = self.describe_object(obj) + + return [field["name"] for field in obj_description["fields"]] + + def get_object_from_salesforce(self, obj: str, fields: Iterable[str]) -> dict: + """ + Get all instances of the `object` from Salesforce. + For each model, only get the fields specified in fields. + + All we really do underneath the hood is run: + SELECT FROM ; + + :param obj: The object name to get from Salesforce. + :type obj: str + :param fields: The fields to get from the object. + :type fields: iterable + :return: all instances of the object from Salesforce. + :rtype: dict + """ + query = f"SELECT {','.join(fields)} FROM {obj}" + + self.log.info( + "Making query to Salesforce: %s", + query if len(query) < 30 else " ... ".join([query[:15], query[-15:]]), + ) + + return self.make_query(query) + + @classmethod + def _to_timestamp(cls, column: pd.Series) -> pd.Series: + """ + Convert a column of a dataframe to UNIX timestamps if applicable + + :param column: A Series object representing a column of a dataframe. 
+ :type column: pandas.Series + :return: a new series that maintains the same index as the original + :rtype: pandas.Series + """ + # try and convert the column to datetimes + # the column MUST have a four digit year somewhere in the string + # there should be a better way to do this, + # but just letting pandas try and convert every column without a format + # caused it to convert floats as well + # For example, a column of integers + # between 0 and 10 are turned into timestamps + # if the column cannot be converted, + # just return the original column untouched + try: + column = pd.to_datetime(column) + except ValueError: + log.error("Could not convert field to timestamps: %s", column.name) + return column + + # now convert the newly created datetimes into timestamps + # we have to be careful here + # because NaT cannot be converted to a timestamp + # so we have to return NaN + converted = [] + for value in column: + try: + converted.append(value.timestamp()) + except (ValueError, AttributeError): + converted.append(pd.np.NaN) + + return pd.Series(converted, index=column.index) + + def write_object_to_file( + self, + query_results: List[dict], + filename: str, + fmt: str = "csv", + coerce_to_timestamp: bool = False, + record_time_added: bool = False, + ) -> pd.DataFrame: + """ + Write query results to file. + + Acceptable formats are: + - csv: + comma-separated-values file. This is the default format. + - json: + JSON array. Each element in the array is a different row. + - ndjson: + JSON array but each element is new-line delimited instead of comma delimited like in `json` + + This requires a significant amount of cleanup. + Pandas doesn't handle output to CSV and json in a uniform way. + This is especially painful for datetime types. + Pandas wants to write them as strings in CSV, but as millisecond Unix timestamps. + + By default, this function will try and leave all values as they are represented in Salesforce. + You use the `coerce_to_timestamp` flag to force all datetimes to become Unix timestamps (UTC). + This is can be greatly beneficial as it will make all of your datetime fields look the same, + and makes it easier to work with in other database environments + + :param query_results: the results from a SQL query + :type query_results: list of dict + :param filename: the name of the file where the data should be dumped to + :type filename: str + :param fmt: the format you want the output in. Default: 'csv' + :type fmt: str + :param coerce_to_timestamp: True if you want all datetime fields to be converted into Unix timestamps. + False if you want them to be left in the same format as they were in Salesforce. + Leaving the value as False will result in datetimes being strings. Default: False + :type coerce_to_timestamp: bool + :param record_time_added: True if you want to add a Unix timestamp field + to the resulting data that marks when the data was fetched from Salesforce. Default: False + :type record_time_added: bool + :return: the dataframe that gets written to the file. + :rtype: pandas.Dataframe + """ + fmt = fmt.lower() + if fmt not in ["csv", "json", "ndjson"]: + raise ValueError(f"Format value is not recognized: {fmt}") + + df = self.object_to_df( + query_results=query_results, + coerce_to_timestamp=coerce_to_timestamp, + record_time_added=record_time_added, + ) + + # write the CSV or JSON file depending on the option + # NOTE: + # datetimes here are an issue. 
+ # There is no good way to manage the difference + # for to_json, the options are an epoch or a ISO string + # but for to_csv, it will be a string output by datetime + # For JSON we decided to output the epoch timestamp in seconds + # (as is fairly standard for JavaScript) + # And for csv, we do a string + if fmt == "csv": + # there are also a ton of newline objects that mess up our ability to write to csv + # we remove these newlines so that the output is a valid CSV format + self.log.info("Cleaning data and writing to CSV") + possible_strings = df.columns[df.dtypes == "object"] + df[possible_strings] = ( + df[possible_strings] + .astype(str) + .apply(lambda x: x.str.replace("\r\n", "").str.replace("\n", "")) + ) + # write the dataframe + df.to_csv(filename, index=False) + elif fmt == "json": + df.to_json(filename, "records", date_unit="s") + elif fmt == "ndjson": + df.to_json(filename, "records", lines=True, date_unit="s") + + return df + + def object_to_df( + self, + query_results: List[dict], + coerce_to_timestamp: bool = False, + record_time_added: bool = False, + ) -> pd.DataFrame: + """ + Export query results to dataframe. + + By default, this function will try and leave all values as they are represented in Salesforce. + You use the `coerce_to_timestamp` flag to force all datetimes to become Unix timestamps (UTC). + This is can be greatly beneficial as it will make all of your datetime fields look the same, + and makes it easier to work with in other database environments + + :param query_results: the results from a SQL query + :type query_results: list of dict + :param coerce_to_timestamp: True if you want all datetime fields to be converted into Unix timestamps. + False if you want them to be left in the same format as they were in Salesforce. + Leaving the value as False will result in datetimes being strings. Default: False + :type coerce_to_timestamp: bool + :param record_time_added: True if you want to add a Unix timestamp field + to the resulting data that marks when the data was fetched from Salesforce. Default: False + :type record_time_added: bool + :return: the dataframe. 
+ :rtype: pandas.Dataframe + """ + # this line right here will convert all integers to floats + # if there are any None/np.nan values in the column + # that's because None/np.nan cannot exist in an integer column + # we should write all of our timestamps as FLOATS in our final schema + df = pd.DataFrame.from_records(query_results, exclude=["attributes"]) + + df.columns = [column.lower() for column in df.columns] + + # convert columns with datetime strings to datetimes + # not all strings will be datetimes, so we ignore any errors that occur + # we get the object's definition at this point and only consider + # features that are DATE or DATETIME + if coerce_to_timestamp and df.shape[0] > 0: + # get the object name out of the query results + # it's stored in the "attributes" dictionary + # for each returned record + object_name = query_results[0]["attributes"]["type"] + + self.log.info("Coercing timestamps for: %s", object_name) + + schema = self.describe_object(object_name) + + # possible columns that can be converted to timestamps + # are the ones that are either date or datetime types + # strings are too general and we risk unintentional conversion + possible_timestamp_cols = [ + field["name"].lower() + for field in schema["fields"] + if field["type"] in ["date", "datetime"] + and field["name"].lower() in df.columns + ] + df[possible_timestamp_cols] = df[possible_timestamp_cols].apply( + self._to_timestamp + ) + + if record_time_added: + fetched_time = time.time() + df["time_fetched_from_salesforce"] = fetched_time + + return df diff --git a/reference/providers/salesforce/hooks/tableau.py b/reference/providers/salesforce/hooks/tableau.py new file mode 100644 index 0000000..e53a859 --- /dev/null +++ b/reference/providers/salesforce/hooks/tableau.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import warnings + +# pylint: disable=unused-import +from airflow.providers.tableau.hooks.tableau import ( # noqa + TableauHook, + TableauJobFinishCode, +) + +warnings.warn( + "This module is deprecated. Please use `airflow.providers.tableau.hooks.tableau`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/reference/providers/salesforce/operators/__init__.py b/reference/providers/salesforce/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/salesforce/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
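
A minimal sketch tying together the SalesforceHook methods defined above (``make_query``, ``get_object_from_salesforce``, ``write_object_to_file``). The connection id, SOQL, and output path are assumptions; the connection's Extra field is expected to carry ``{"security_token": "..."}`` (plus ``{"domain": "test"}`` for a sandbox), as described in the hook docstring.

from airflow.providers.salesforce.hooks.salesforce import SalesforceHook

hook = SalesforceHook(conn_id="salesforce_default")

# Run a SOQL query; make_query() wraps simple_salesforce's query_all().
results = hook.make_query("SELECT Id, Name, LastModifiedDate FROM Account")
print(results["totalSize"], results["done"])

# Or pull whole objects and dump them, coercing date/datetime fields to UTC
# Unix timestamps so the CSV and JSON outputs look the same.
records = hook.get_object_from_salesforce("Account", fields=["Id", "Name", "LastModifiedDate"])
hook.write_object_to_file(
    query_results=records["records"],
    filename="/tmp/accounts.csv",
    fmt="csv",
    coerce_to_timestamp=True,
    record_time_added=True,
)
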
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/salesforce/operators/tableau_refresh_workbook.py b/reference/providers/salesforce/operators/tableau_refresh_workbook.py new file mode 100644 index 0000000..309af33 --- /dev/null +++ b/reference/providers/salesforce/operators/tableau_refresh_workbook.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import warnings + +# pylint: disable=unused-import +from airflow.providers.tableau.operators.tableau_refresh_workbook import ( # noqa + TableauRefreshWorkbookOperator, +) + +warnings.warn( + "This module is deprecated. Please use `airflow.providers.tableau.operators.tableau_refresh_workbook`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/reference/providers/salesforce/provider.yaml b/reference/providers/salesforce/provider.yaml new file mode 100644 index 0000000..12696d8 --- /dev/null +++ b/reference/providers/salesforce/provider.yaml @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
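
The two Tableau modules above are thin compatibility shims: they re-export the classes from the new ``tableau`` provider and emit a DeprecationWarning at import time. A small sketch of what that means in practice (the old path still resolves to the very same class):

# Deprecated location, kept as a shim; importing it warns and points at the new module.
from airflow.providers.salesforce.hooks.tableau import TableauHook as OldPathTableauHook

# Preferred location after the 2.0.0 split into apache-airflow-providers-tableau.
from airflow.providers.tableau.hooks.tableau import TableauHook

assert OldPathTableauHook is TableauHook   # same class, just re-exported
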
+ +--- +package-name: apache-airflow-providers-salesforce +name: Salesforce +description: | + `Salesforce `__ + +versions: + - 2.0.0 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Salesforce + external-doc-url: https://www.salesforce.com/ + logo: /integration-logos/salesforce/Salesforce.png + tags: [service] + +operators: + - integration-name: Salesforce + python-modules: + - airflow.providers.salesforce.operators.tableau_refresh_workbook + +sensors: + - integration-name: Salesforce + python-modules: + - airflow.providers.salesforce.sensors.tableau_job_status + +hooks: + - integration-name: Tableau + python-modules: + - airflow.providers.salesforce.hooks.tableau + - integration-name: Salesforce + python-modules: + - airflow.providers.salesforce.hooks.salesforce + +hook-class-names: + - airflow.providers.salesforce.hooks.tableau.TableauHook diff --git a/reference/providers/salesforce/sensors/__init__.py b/reference/providers/salesforce/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/salesforce/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/salesforce/sensors/tableau_job_status.py b/reference/providers/salesforce/sensors/tableau_job_status.py new file mode 100644 index 0000000..076159e --- /dev/null +++ b/reference/providers/salesforce/sensors/tableau_job_status.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import warnings + +# pylint: disable=unused-import +from airflow.providers.tableau.sensors.tableau_job_status import ( # noqa + TableauJobFailedException, + TableauJobStatusSensor, +) + +warnings.warn( + "This module is deprecated. Please use `airflow.providers.tableau.sensors.tableau_job_status`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/reference/providers/samba/CHANGELOG.rst b/reference/providers/samba/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/samba/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. 
Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/samba/__init__.py b/reference/providers/samba/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/samba/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/samba/hooks/__init__.py b/reference/providers/samba/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/samba/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/samba/hooks/samba.py b/reference/providers/samba/hooks/samba.py new file mode 100644 index 0000000..61e5223 --- /dev/null +++ b/reference/providers/samba/hooks/samba.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow.hooks.base import BaseHook +from smbclient import SambaClient + + +class SambaHook(BaseHook): + """Allows for interaction with an samba server.""" + + conn_name_attr = "samba_conn_id" + default_conn_name = "samba_default" + conn_type = "samba" + hook_name = "Samba" + + def __init__(self, samba_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn = self.get_connection(samba_conn_id) + + def get_conn(self) -> SambaClient: + samba = SambaClient( + server=self.conn.host, + share=self.conn.schema, + username=self.conn.login, + ip=self.conn.host, + password=self.conn.password, + ) + return samba + + def push_from_local(self, destination_filepath: str, local_filepath: str) -> None: + """Push local file to samba server""" + samba = self.get_conn() + if samba.exists(destination_filepath): + if samba.isfile(destination_filepath): + samba.remove(destination_filepath) + else: + folder = os.path.dirname(destination_filepath) + if not samba.exists(folder): + samba.mkdir(folder) + samba.upload(local_filepath, destination_filepath) diff --git a/reference/providers/samba/provider.yaml b/reference/providers/samba/provider.yaml new file mode 100644 index 0000000..4efc1dd --- /dev/null +++ b/reference/providers/samba/provider.yaml @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-samba +name: Samba +description: | + `Samba `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Samba + external-doc-url: https://www.samba.org/ + logo: /integration-logos/samba/Samba.png + tags: [protocol] + +hooks: + - integration-name: Samba + python-modules: + - airflow.providers.samba.hooks.samba + +hook-class-names: + - airflow.providers.samba.hooks.samba.SambaHook diff --git a/reference/providers/segment/CHANGELOG.rst b/reference/providers/segment/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/segment/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
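
A hedged sketch of the SambaHook defined above, assuming a ``samba_default`` connection where ``host`` is the server, ``schema`` is the share name, and ``login``/``password`` are the credentials (mirroring ``get_conn``). The file paths are illustrative.

from airflow.providers.samba.hooks.samba import SambaHook

hook = SambaHook(samba_conn_id="samba_default")

# push_from_local() replaces an existing remote file of the same name and
# creates the destination folder on the share if it is missing.
hook.push_from_local(
    destination_filepath="reports/2021-01-01/daily.csv",
    local_filepath="/tmp/daily.csv",
)
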
You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/segment/__init__.py b/reference/providers/segment/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/segment/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/segment/hooks/__init__.py b/reference/providers/segment/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/segment/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/segment/hooks/segment.py b/reference/providers/segment/hooks/segment.py new file mode 100644 index 0000000..bba15a1 --- /dev/null +++ b/reference/providers/segment/hooks/segment.py @@ -0,0 +1,92 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +This module contains a Segment Hook +which allows you to connect to your Segment account, +retrieve data from it or write to that file. + +NOTE: this hook also relies on the Segment analytics package: + https://github.com/segmentio/analytics-python +""" +import analytics +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class SegmentHook(BaseHook): + """ + Create new connection to Segment + and allows you to pull data out of Segment or write to it. + + You can then use that file with other + Airflow operators to move the data around or interact with segment. + + :param segment_conn_id: the name of the connection that has the parameters + we need to connect to Segment. The connection should be type `json` and include a + write_key security token in the `Extras` field. + :type segment_conn_id: str + :param segment_debug_mode: Determines whether Segment should run in debug mode. + Defaults to False + :type segment_debug_mode: bool + + .. note:: + You must include a JSON structure in the `Extras` field. + We need a user's security token to connect to Segment. + So we define it in the `Extras` field as: + `{"write_key":"YOUR_SECURITY_TOKEN"}` + """ + + conn_name_attr = "segment_conn_id" + default_conn_name = "segment_default" + conn_type = "segment" + hook_name = "Segment" + + def __init__( + self, + segment_conn_id: str = "segment_default", + segment_debug_mode: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__() + self.segment_conn_id = segment_conn_id + self.segment_debug_mode = segment_debug_mode + self._args = args + self._kwargs = kwargs + + # get the connection parameters + self.connection = self.get_connection(self.segment_conn_id) + self.extras = self.connection.extra_dejson + self.write_key = self.extras.get("write_key") + if self.write_key is None: + raise AirflowException("No Segment write key provided") + + def get_conn(self) -> analytics: + self.log.info("Setting write key for Segment analytics connection") + analytics.debug = self.segment_debug_mode + if self.segment_debug_mode: + self.log.info("Setting Segment analytics connection to debug mode") + analytics.on_error = self.on_error + analytics.write_key = self.write_key + return analytics + + def on_error(self, error: str, items: str) -> None: + """Handles error callbacks when using Segment with segment_debug_mode set to True""" + self.log.error("Encountered Segment error: %s with items: %s", error, items) + raise AirflowException(f"Segment error: {error}") diff --git a/reference/providers/segment/operators/__init__.py b/reference/providers/segment/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/segment/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
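
A minimal sketch of the SegmentHook defined above. It assumes a ``segment_default`` connection whose Extra field is ``{"write_key": "YOUR_SECURITY_TOKEN"}``; ``get_conn()`` returns the configured ``analytics`` module itself, so any analytics-python call can be made through it.

from airflow.providers.segment.hooks.segment import SegmentHook

hook = SegmentHook(segment_conn_id="segment_default", segment_debug_mode=True)

analytics = hook.get_conn()   # the analytics-python module, with write_key set
analytics.identify(user_id="user-42", traits={"plan": "enterprise"})
analytics.flush()
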
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/segment/operators/segment_track_event.py b/reference/providers/segment/operators/segment_track_event.py new file mode 100644 index 0000000..e282165 --- /dev/null +++ b/reference/providers/segment/operators/segment_track_event.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Dict, Optional + +from airflow.models import BaseOperator +from airflow.providers.segment.hooks.segment import SegmentHook +from airflow.utils.decorators import apply_defaults + + +class SegmentTrackEventOperator(BaseOperator): + """ + Send Track Event to Segment for a specified user_id and event + + :param user_id: The ID for this user in your database. (templated) + :type user_id: str + :param event: The name of the event you're tracking. (templated) + :type event: str + :param properties: A dictionary of properties for the event. (templated) + :type properties: dict + :param segment_conn_id: The connection ID to use when connecting to Segment. + :type segment_conn_id: str + :param segment_debug_mode: Determines whether Segment should run in debug mode. 
+ Defaults to False + :type segment_debug_mode: bool + """ + + template_fields = ("user_id", "event", "properties") + ui_color = "#ffd700" + + @apply_defaults + def __init__( + self, + *, + user_id: str, + event: str, + properties: Optional[dict] = None, + segment_conn_id: str = "segment_default", + segment_debug_mode: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.user_id = user_id + self.event = event + properties = properties or {} + self.properties = properties + self.segment_debug_mode = segment_debug_mode + self.segment_conn_id = segment_conn_id + + def execute(self, context: Dict) -> None: + hook = SegmentHook( + segment_conn_id=self.segment_conn_id, + segment_debug_mode=self.segment_debug_mode, + ) + + self.log.info( + "Sending track event (%s) for user id: %s with properties: %s", + self.event, + self.user_id, + self.properties, + ) + + # pylint: disable=no-member + hook.track(user_id=self.user_id, event=self.event, properties=self.properties) # type: ignore diff --git a/reference/providers/segment/provider.yaml b/reference/providers/segment/provider.yaml new file mode 100644 index 0000000..faf0f9b --- /dev/null +++ b/reference/providers/segment/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-segment +name: Segment +description: | + `Segment `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Segment + external-doc-url: https://segment.com/docs/ + logo: /integration-logos/segment/Segment.png + tags: [service] + +operators: + - integration-name: Segment + python-modules: + - airflow.providers.segment.operators.segment_track_event + +hooks: + - integration-name: Segment + python-modules: + - airflow.providers.segment.hooks.segment + +hook-class-names: + - airflow.providers.segment.hooks.segment.SegmentHook diff --git a/reference/providers/sendgrid/CHANGELOG.rst b/reference/providers/sendgrid/CHANGELOG.rst new file mode 100644 index 0000000..d8e7a0b --- /dev/null +++ b/reference/providers/sendgrid/CHANGELOG.rst @@ -0,0 +1,40 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. 
See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + +1.0.1 +..... + +Updated documentation and readme files. + +* ``Deprecate email credentials from environment variables. (#13601)`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/sendgrid/__init__.py b/reference/providers/sendgrid/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/sendgrid/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/sendgrid/provider.yaml b/reference/providers/sendgrid/provider.yaml new file mode 100644 index 0000000..aab6f52 --- /dev/null +++ b/reference/providers/sendgrid/provider.yaml @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-sendgrid +name: Sendgrid +description: | + `Sendgrid `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 diff --git a/reference/providers/sendgrid/utils/__init__.py b/reference/providers/sendgrid/utils/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/sendgrid/utils/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/sendgrid/utils/emailer.py b/reference/providers/sendgrid/utils/emailer.py new file mode 100644 index 0000000..caf23e8 --- /dev/null +++ b/reference/providers/sendgrid/utils/emailer.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Airflow module for email backend using sendgrid""" + +import base64 +import logging +import mimetypes +import os +import warnings +from typing import Dict, Iterable, Optional, Union + +import sendgrid +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.utils.email import get_email_address_list +from sendgrid.helpers.mail import ( + Attachment, + Category, + Content, + CustomArg, + Email, + Mail, + MailSettings, + Personalization, + SandBoxMode, +) + +log = logging.getLogger(__name__) + +AddressesType = Union[str, Iterable[str]] + + +def send_email( # pylint: disable=too-many-locals + to: AddressesType, + subject: str, + html_content: str, + files: Optional[AddressesType] = None, + cc: Optional[AddressesType] = None, + bcc: Optional[AddressesType] = None, + sandbox_mode: bool = False, + conn_id: str = "sendgrid_default", + **kwargs, +) -> None: + """ + Send an email with html content using `Sendgrid `__. + + .. note:: + For more information, see :ref:`email-configuration-sendgrid` + """ + if files is None: + files = [] + + mail = Mail() + from_email = kwargs.get("from_email") or os.environ.get("SENDGRID_MAIL_FROM") + from_name = kwargs.get("from_name") or os.environ.get("SENDGRID_MAIL_SENDER") + mail.from_email = Email(from_email, from_name) + mail.subject = subject + mail.mail_settings = MailSettings() + + if sandbox_mode: + mail.mail_settings.sandbox_mode = SandBoxMode(enable=True) + + # Add the recipient list of to emails. + personalization = Personalization() + to = get_email_address_list(to) + for to_address in to: + personalization.add_to(Email(to_address)) + if cc: + cc = get_email_address_list(cc) + for cc_address in cc: + personalization.add_cc(Email(cc_address)) + if bcc: + bcc = get_email_address_list(bcc) + for bcc_address in bcc: + personalization.add_bcc(Email(bcc_address)) + + # Add custom_args to personalization if present + pers_custom_args = kwargs.get("personalization_custom_args") + if isinstance(pers_custom_args, dict): + for key in pers_custom_args.keys(): + personalization.add_custom_arg(CustomArg(key, pers_custom_args[key])) + + mail.add_personalization(personalization) + mail.add_content(Content("text/html", html_content)) + + categories = kwargs.get("categories", []) + for cat in categories: + mail.add_category(Category(cat)) + + # Add email attachment. 
+ for fname in files: + basename = os.path.basename(fname) + + with open(fname, "rb") as file: + content = base64.b64encode(file.read()).decode("utf-8") + + attachment = Attachment( + file_content=content, + file_type=mimetypes.guess_type(basename)[0], + file_name=basename, + disposition="attachment", + content_id=f"<{basename}>", + ) + + mail.add_attachment(attachment) + _post_sendgrid_mail(mail.get(), conn_id) + + +def _post_sendgrid_mail(mail_data: Dict, conn_id: str = "sendgrid_default") -> None: + api_key = None + try: + conn = BaseHook.get_connection(conn_id) + api_key = conn.password + except AirflowException: + pass + if api_key is None: + warnings.warn( + "Fetching Sendgrid credentials from environment variables will be deprecated in a future " + "release. Please set credentials using a connection instead.", + PendingDeprecationWarning, + stacklevel=2, + ) + api_key = os.environ.get("SENDGRID_API_KEY") + sendgrid_client = sendgrid.SendGridAPIClient(api_key=api_key) + response = sendgrid_client.client.mail.send.post(request_body=mail_data) + # 2xx status code. + if 200 <= response.status_code < 300: + log.info( + "Email with subject %s is successfully sent to recipients: %s", + mail_data["subject"], + mail_data["personalizations"], + ) + else: + log.error( + "Failed to send out email with subject %s, status code: %s", + mail_data["subject"], + response.status_code, + ) diff --git a/reference/providers/sftp/CHANGELOG.rst b/reference/providers/sftp/CHANGELOG.rst new file mode 100644 index 0000000..53b3079 --- /dev/null +++ b/reference/providers/sftp/CHANGELOG.rst @@ -0,0 +1,45 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.1.1 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + + +1.1.0 +..... + +Updated documentation and readme files. + +Features +~~~~~~~~ + +* ``Add retryer to SFTP hook connection (#13065)`` + + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/sftp/__init__.py b/reference/providers/sftp/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/sftp/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/sftp/hooks/__init__.py b/reference/providers/sftp/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/sftp/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/sftp/hooks/sftp.py b/reference/providers/sftp/hooks/sftp.py new file mode 100644 index 0000000..061efed --- /dev/null +++ b/reference/providers/sftp/hooks/sftp.py @@ -0,0 +1,328 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains SFTP hook.""" +import datetime +import stat +from typing import Dict, List, Optional, Tuple + +import pysftp +import tenacity +from airflow.providers.ssh.hooks.ssh import SSHHook +from paramiko import SSHException + + +class SFTPHook(SSHHook): + """ + This hook is inherited from SSH hook. Please refer to SSH hook for the input + arguments. + + Interact with SFTP. Aims to be interchangeable with FTPHook. + + :Pitfalls:: + + - In contrast with FTPHook describe_directory only returns size, type and + modify. It doesn't return unix.owner, unix.mode, perm, unix.group and + unique. + - retrieve_file and store_file only take a local full path and not a + buffer. + - If no mode is passed to create_directory it will be created with 777 + permissions. + + Errors that may occur throughout but should be handled downstream. 
+ """ + + conn_name_attr = "ftp_conn_id" + default_conn_name = "sftp_default" + conn_type = "sftp" + hook_name = "SFTP" + + @staticmethod + def get_ui_field_behaviour() -> Dict: + return { + "hidden_fields": ["schema"], + "relabeling": { + "login": "Username", + }, + } + + def __init__(self, ftp_conn_id: str = "sftp_default", *args, **kwargs) -> None: + kwargs["ssh_conn_id"] = ftp_conn_id + super().__init__(*args, **kwargs) + + self.conn = None + self.private_key_pass = None + self.ciphers = None + + # Fail for unverified hosts, unless this is explicitly allowed + self.no_host_key_check = False + + if self.ssh_conn_id is not None: + conn = self.get_connection(self.ssh_conn_id) + if conn.extra is not None: + extra_options = conn.extra_dejson + if "private_key_pass" in extra_options: + self.private_key_pass = extra_options.get("private_key_pass") + + # For backward compatibility + # TODO: remove in Airflow 2.1 + import warnings + + if "ignore_hostkey_verification" in extra_options: + warnings.warn( + "Extra option `ignore_hostkey_verification` is deprecated." + "Please use `no_host_key_check` instead." + "This option will be removed in Airflow 2.1", + DeprecationWarning, + stacklevel=2, + ) + self.no_host_key_check = ( + str(extra_options["ignore_hostkey_verification"]).lower() + == "true" + ) + + if "no_host_key_check" in extra_options: + self.no_host_key_check = ( + str(extra_options["no_host_key_check"]).lower() == "true" + ) + + if "ciphers" in extra_options: + self.ciphers = extra_options["ciphers"] + + if "private_key" in extra_options: + warnings.warn( + "Extra option `private_key` is deprecated." + "Please use `key_file` instead." + "This option will be removed in Airflow 2.1", + DeprecationWarning, + stacklevel=2, + ) + self.key_file = extra_options.get("private_key") + + @tenacity.retry( + stop=tenacity.stop_after_delay(10), + wait=tenacity.wait_exponential(multiplier=1, max=10), + retry=tenacity.retry_if_exception_type(SSHException), + reraise=True, + ) + def get_conn(self) -> pysftp.Connection: + """Returns an SFTP connection object""" + if self.conn is None: + cnopts = pysftp.CnOpts() + if self.no_host_key_check: + cnopts.hostkeys = None + else: + if self.host_key is not None: + cnopts.hostkeys.add(self.remote_host, "ssh-rsa", self.host_key) + else: + pass # will fallback to system host keys if none explicitly specified in conn extra + + cnopts.compression = self.compress + cnopts.ciphers = self.ciphers + conn_params = { + "host": self.remote_host, + "port": self.port, + "username": self.username, + "cnopts": cnopts, + } + if self.password and self.password.strip(): + conn_params["password"] = self.password + if self.key_file: + conn_params["private_key"] = self.key_file + if self.private_key_pass: + conn_params["private_key_pass"] = self.private_key_pass + + self.conn = pysftp.Connection(**conn_params) + return self.conn + + def close_conn(self) -> None: + """Closes the connection""" + if self.conn is not None: + self.conn.close() + self.conn = None + + def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]: + """ + Returns a dictionary of {filename: {attributes}} for all files + on the remote system (where the MLSD command is supported). 
+ + :param path: full path to the remote directory + :type path: str + """ + conn = self.get_conn() + flist = conn.listdir_attr(path) + files = {} + for f in flist: + modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime( + "%Y%m%d%H%M%S" + ) + files[f.filename] = { + "size": f.st_size, + "type": "dir" if stat.S_ISDIR(f.st_mode) else "file", + "modify": modify, + } + return files + + def list_directory(self, path: str) -> List[str]: + """ + Returns a list of files on the remote system. + + :param path: full path to the remote directory to list + :type path: str + """ + conn = self.get_conn() + files = conn.listdir(path) + return files + + def create_directory(self, path: str, mode: int = 777) -> None: + """ + Creates a directory on the remote system. + + :param path: full path to the remote directory to create + :type path: str + :param mode: int representation of octal mode for directory + """ + conn = self.get_conn() + conn.makedirs(path, mode) + + def delete_directory(self, path: str) -> None: + """ + Deletes a directory on the remote system. + + :param path: full path to the remote directory to delete + :type path: str + """ + conn = self.get_conn() + conn.rmdir(path) + + def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None: + """ + Transfers the remote file to a local location. + If local_full_path is a string path, the file will be put + at that location + + :param remote_full_path: full path to the remote file + :type remote_full_path: str + :param local_full_path: full path to the local file + :type local_full_path: str + """ + conn = self.get_conn() + self.log.info("Retrieving file from FTP: %s", remote_full_path) + conn.get(remote_full_path, local_full_path) + self.log.info("Finished retrieving file from FTP: %s", remote_full_path) + + def store_file(self, remote_full_path: str, local_full_path: str) -> None: + """ + Transfers a local file to the remote location. + If local_full_path_or_buffer is a string path, the file will be read + from that location + + :param remote_full_path: full path to the remote file + :type remote_full_path: str + :param local_full_path: full path to the local file + :type local_full_path: str + """ + conn = self.get_conn() + conn.put(local_full_path, remote_full_path) + + def delete_file(self, path: str) -> None: + """ + Removes a file on the FTP Server + + :param path: full path to the remote file + :type path: str + """ + conn = self.get_conn() + conn.remove(path) + + def get_mod_time(self, path: str) -> str: + """ + Returns modification time. + + :param path: full path to the remote file + :type path: str + """ + conn = self.get_conn() + ftp_mdtm = conn.stat(path).st_mtime + return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") + + def path_exists(self, path: str) -> bool: + """ + Returns True if a remote entity exists + + :param path: full path to the remote file or directory + :type path: str + """ + conn = self.get_conn() + return conn.exists(path) + + @staticmethod + def _is_path_match( + path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None + ) -> bool: + """ + Return True if given path starts with prefix (if set) and ends with delimiter (if set). 
+ + :param path: path to be checked + :type path: str + :param prefix: if set path will be checked is starting with prefix + :type prefix: str + :param delimiter: if set path will be checked is ending with suffix + :type delimiter: str + :return: bool + """ + if prefix is not None and not path.startswith(prefix): + return False + if delimiter is not None and not path.endswith(delimiter): + return False + return True + + def get_tree_map( + self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None + ) -> Tuple[List[str], List[str], List[str]]: + """ + Return tuple with recursive lists of files, directories and unknown paths from given path. + It is possible to filter results by giving prefix and/or delimiter parameters. + + :param path: path from which tree will be built + :type path: str + :param prefix: if set paths will be added if start with prefix + :type prefix: str + :param delimiter: if set paths will be added if end with delimiter + :type delimiter: str + :return: tuple with list of files, dirs and unknown items + :rtype: Tuple[List[str], List[str], List[str]] + """ + conn = self.get_conn() + files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str] + + def append_matching_path_callback(list_): + return ( + lambda item: list_.append(item) + if self._is_path_match(item, prefix, delimiter) + else None + ) + + conn.walktree( + remotepath=path, + fcallback=append_matching_path_callback(files), + dcallback=append_matching_path_callback(dirs), + ucallback=append_matching_path_callback(unknowns), + recurse=True, + ) + + return files, dirs, unknowns diff --git a/reference/providers/sftp/operators/__init__.py b/reference/providers/sftp/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/sftp/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/sftp/operators/sftp.py b/reference/providers/sftp/operators/sftp.py new file mode 100644 index 0000000..64ff6c8 --- /dev/null +++ b/reference/providers/sftp/operators/sftp.py @@ -0,0 +1,191 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains the SFTP operator.""" +import os +from pathlib import Path +from typing import Any + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.utils.decorators import apply_defaults + + +class SFTPOperation: + """Operation that can be used with SFTP.""" + + PUT = "put" + GET = "get" + + +class SFTPOperator(BaseOperator): + """ + SFTPOperator for transferring files from a remote host to local or vice versa. + This operator uses ssh_hook to open an SFTP transport channel that serves as the basis + for file transfer. + + :param ssh_hook: predefined ssh_hook to use for remote execution. + Either `ssh_hook` or `ssh_conn_id` needs to be provided. + :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook + :param ssh_conn_id: connection id from airflow Connections. + `ssh_conn_id` will be ignored if `ssh_hook` is provided. + :type ssh_conn_id: str + :param remote_host: remote host to connect (templated) + Nullable. If provided, it will replace the `remote_host` which was + defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. + :type remote_host: str + :param local_filepath: local file path to get or put. (templated) + :type local_filepath: str + :param remote_filepath: remote file path to get or put. (templated) + :type remote_filepath: str + :param operation: specify operation 'get' or 'put', defaults to put + :type operation: str + :param confirm: specify if the SFTP operation should be confirmed, defaults to True + :type confirm: bool + :param create_intermediate_dirs: create missing intermediate directories when + copying from remote to local and vice-versa. Default is False. + + Example: The following task would copy ``file.txt`` to the remote host + at ``/tmp/tmp1/tmp2/`` while creating ``tmp``, ``tmp1`` and ``tmp2`` if they + don't exist. If the parameter is not passed it would error as the directory + does not exist.
:: + + put_file = SFTPOperator( + task_id="test_sftp", + ssh_conn_id="ssh_default", + local_filepath="/tmp/file.txt", + remote_filepath="/tmp/tmp1/tmp2/file.txt", + operation="put", + create_intermediate_dirs=True, + dag=dag + ) + + :type create_intermediate_dirs: bool + """ + + template_fields = ("local_filepath", "remote_filepath", "remote_host") + + @apply_defaults + def __init__( + self, + *, + ssh_hook=None, + ssh_conn_id=None, + remote_host=None, + local_filepath=None, + remote_filepath=None, + operation=SFTPOperation.PUT, + confirm=True, + create_intermediate_dirs=False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.ssh_hook = ssh_hook + self.ssh_conn_id = ssh_conn_id + self.remote_host = remote_host + self.local_filepath = local_filepath + self.remote_filepath = remote_filepath + self.operation = operation + self.confirm = confirm + self.create_intermediate_dirs = create_intermediate_dirs + if not ( + self.operation.lower() == SFTPOperation.GET + or self.operation.lower() == SFTPOperation.PUT + ): + raise TypeError( + "unsupported operation value {}, expected {} or {}".format( + self.operation, SFTPOperation.GET, SFTPOperation.PUT + ) + ) + + def execute(self, context: Any) -> str: + file_msg = None + try: + if self.ssh_conn_id: + if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): + self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") + else: + self.log.info( + "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook." + ) + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) + + if not self.ssh_hook: + raise AirflowException( + "Cannot operate without ssh_hook or ssh_conn_id." + ) + + if self.remote_host is not None: + self.log.info( + "remote_host is provided explicitly. " + "It will replace the remote_host which was defined " + "in ssh_hook or predefined in connection of ssh_conn_id." + ) + self.ssh_hook.remote_host = self.remote_host + + with self.ssh_hook.get_conn() as ssh_client: + sftp_client = ssh_client.open_sftp() + if self.operation.lower() == SFTPOperation.GET: + local_folder = os.path.dirname(self.local_filepath) + if self.create_intermediate_dirs: + Path(local_folder).mkdir(parents=True, exist_ok=True) + file_msg = f"from {self.remote_filepath} to {self.local_filepath}" + self.log.info("Starting to transfer %s", file_msg) + sftp_client.get(self.remote_filepath, self.local_filepath) + else: + remote_folder = os.path.dirname(self.remote_filepath) + if self.create_intermediate_dirs: + _make_intermediate_dirs( + sftp_client=sftp_client, + remote_directory=remote_folder, + ) + file_msg = f"from {self.local_filepath} to {self.remote_filepath}" + self.log.info("Starting to transfer file %s", file_msg) + sftp_client.put( + self.local_filepath, self.remote_filepath, confirm=self.confirm + ) + + except Exception as e: + raise AirflowException( + f"Error while transferring {file_msg}, error: {str(e)}" + ) + + return self.local_filepath + + +def _make_intermediate_dirs(sftp_client, remote_directory) -> None: + """ + Create all the intermediate directories in a remote host + + :param sftp_client: A Paramiko SFTP client. 
+ :param remote_directory: Absolute Path of the directory containing the file + :return: + """ + if remote_directory == "/": + sftp_client.chdir("/") + return + if remote_directory == "": + return + try: + sftp_client.chdir(remote_directory) + except OSError: + dirname, basename = os.path.split(remote_directory.rstrip("/")) + _make_intermediate_dirs(sftp_client, dirname) + sftp_client.mkdir(basename) + sftp_client.chdir(basename) + return diff --git a/reference/providers/sftp/provider.yaml b/reference/providers/sftp/provider.yaml new file mode 100644 index 0000000..ba0049e --- /dev/null +++ b/reference/providers/sftp/provider.yaml @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-sftp +name: SFTP +description: | + `SSH File Transfer Protocol (SFTP) `__ + +versions: + - 1.1.1 + - 1.1.0 + - 1.0.0 + +integrations: + - integration-name: SSH File Transfer Protocol (SFTP) + external-doc-url: https://tools.ietf.org/wg/secsh/draft-ietf-secsh-filexfer/ + logo: /integration-logos/sftp/SFTP.png + tags: [protocol] + +operators: + - integration-name: SSH File Transfer Protocol (SFTP) + python-modules: + - airflow.providers.sftp.operators.sftp + +sensors: + - integration-name: SSH File Transfer Protocol (SFTP) + python-modules: + - airflow.providers.sftp.sensors.sftp + +hooks: + - integration-name: SSH File Transfer Protocol (SFTP) + python-modules: + - airflow.providers.sftp.hooks.sftp + +hook-class-names: + - airflow.providers.sftp.hooks.sftp.SFTPHook diff --git a/reference/providers/sftp/sensors/__init__.py b/reference/providers/sftp/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/sftp/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
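A minimal sketch of how the SFTPHook and SFTPOperator above might be combined in a DAG, assuming the provider is importable under the standard ``airflow.providers.sftp`` path and that an ``sftp_default`` connection is configured; the DAG id, task ids and file paths are hypothetical::

    from airflow import DAG
    from airflow.operators.python import PythonOperator
    from airflow.providers.sftp.hooks.sftp import SFTPHook
    from airflow.providers.sftp.operators.sftp import SFTPOperator
    from airflow.utils.dates import days_ago


    def list_uploaded(path: str) -> None:
        # SFTPHook builds on SSHHook; ftp_conn_id is passed through as ssh_conn_id.
        hook = SFTPHook(ftp_conn_id="sftp_default")
        try:
            # describe_directory() returns {filename: {"size", "type", "modify"}}
            print(hook.describe_directory(path))
        finally:
            hook.close_conn()


    with DAG(
        "sftp_transfer_sketch",
        start_date=days_ago(1),
        schedule_interval=None,
    ) as dag:
        upload = SFTPOperator(
            task_id="upload_report",
            ssh_conn_id="sftp_default",
            local_filepath="/tmp/report.csv",                # hypothetical path
            remote_filepath="/incoming/reports/report.csv",  # hypothetical path
            operation="put",
            create_intermediate_dirs=True,  # create /incoming/reports if missing
        )
        inspect = PythonOperator(
            task_id="list_remote_dir",
            python_callable=list_uploaded,
            op_args=["/incoming/reports"],
        )
        upload >> inspect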
diff --git a/reference/providers/sftp/sensors/sftp.py b/reference/providers/sftp/sensors/sftp.py new file mode 100644 index 0000000..8438d15 --- /dev/null +++ b/reference/providers/sftp/sensors/sftp.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains SFTP sensor.""" +from typing import Optional + +from airflow.providers.sftp.hooks.sftp import SFTPHook +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from paramiko import SFTP_NO_SUCH_FILE + + +class SFTPSensor(BaseSensorOperator): + """ + Waits for a file or directory to be present on SFTP. + + :param path: Remote file or directory path + :type path: str + :param sftp_conn_id: The connection to run the sensor against + :type sftp_conn_id: str + """ + + template_fields = ("path",) + + @apply_defaults + def __init__( + self, *, path: str, sftp_conn_id: str = "sftp_default", **kwargs + ) -> None: + super().__init__(**kwargs) + self.path = path + self.hook: Optional[SFTPHook] = None + self.sftp_conn_id = sftp_conn_id + + def poke(self, context: dict) -> bool: + self.hook = SFTPHook(self.sftp_conn_id) + self.log.info("Poking for %s", self.path) + try: + self.hook.get_mod_time(self.path) + except OSError as e: + if e.errno != SFTP_NO_SUCH_FILE: + raise e + return False + self.hook.close_conn() + return True diff --git a/reference/providers/singularity/CHANGELOG.rst b/reference/providers/singularity/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/singularity/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. 
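The SFTPSensor shown above simply polls ``SFTPHook.get_mod_time()`` until the remote path exists. A minimal sketch of pairing it with the SFTPOperator, under the same assumptions as the previous example (an ``sftp_default`` connection, hypothetical DAG id, task ids and paths)::

    from airflow import DAG
    from airflow.providers.sftp.operators.sftp import SFTPOperator
    from airflow.providers.sftp.sensors.sftp import SFTPSensor
    from airflow.utils.dates import days_ago

    with DAG(
        "sftp_wait_and_fetch_sketch",
        start_date=days_ago(1),
        schedule_interval=None,
    ) as dag:
        # Poke every 60 seconds, give up after one hour (BaseSensorOperator arguments).
        wait_for_drop = SFTPSensor(
            task_id="wait_for_drop",
            sftp_conn_id="sftp_default",
            path="/incoming/drop/data.csv",  # hypothetical remote path
            poke_interval=60,
            timeout=60 * 60,
        )
        fetch = SFTPOperator(
            task_id="fetch_drop",
            ssh_conn_id="sftp_default",
            remote_filepath="/incoming/drop/data.csv",
            local_filepath="/tmp/data.csv",  # hypothetical local path
            operation="get",
            create_intermediate_dirs=True,
        )
        wait_for_drop >> fetch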
diff --git a/reference/providers/singularity/__init__.py b/reference/providers/singularity/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/singularity/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/singularity/example_dags/__init__.py b/reference/providers/singularity/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/singularity/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/singularity/example_dags/example_singularity.py b/reference/providers/singularity/example_dags/example_singularity.py new file mode 100644 index 0000000..f3a7dc7 --- /dev/null +++ b/reference/providers/singularity/example_dags/example_singularity.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from datetime import timedelta + +from airflow import DAG +from airflow.operators.bash_operator import BashOperator +from airflow.providers.singularity.operators.singularity import SingularityOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), +} + +with DAG( + "singularity_sample", + default_args=default_args, + schedule_interval=timedelta(minutes=10), + start_date=days_ago(0), +) as dag: + + t1 = BashOperator(task_id="print_date", bash_command="date", dag=dag) + + t2 = BashOperator(task_id="sleep", bash_command="sleep 5", retries=3, dag=dag) + + t3 = SingularityOperator( + command="/bin/sleep 30", + image="docker://busybox:1.30.1", + task_id="singularity_op_tester", + dag=dag, + ) + + t4 = BashOperator( + task_id="print_hello", bash_command='echo "hello world!!!"', dag=dag + ) + + t1 >> [t2, t3] + t3 >> t4 diff --git a/reference/providers/singularity/operators/__init__.py b/reference/providers/singularity/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/singularity/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/singularity/operators/singularity.py b/reference/providers/singularity/operators/singularity.py new file mode 100644 index 0000000..60c906e --- /dev/null +++ b/reference/providers/singularity/operators/singularity.py @@ -0,0 +1,189 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import ast +import os +import shutil +from typing import Any, Dict, List, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from spython.main import Client + + +class SingularityOperator(BaseOperator): + """ + Execute a command inside a Singularity container + + Singularity has more seamless connection to the host than Docker, so + no special binds are needed to ensure binding content in the user $HOME + and temporary directories. If the user needs custom binds, this can + be done with --volumes + + :param image: Singularity image or URI from which to create the container. + :type image: str + :param auto_remove: Delete the container when the process exits + The default is False. + :type auto_remove: bool + :param command: Command to be run in the container. (templated) + :type command: str or list + :param start_command: start command to pass to the container instance + :type start_command: string or list + :param environment: Environment variables to set in the container. (templated) + :type environment: dict + :param working_dir: Set a working directory for the instance. + :type working_dir: str + :param force_pull: Pull the image on every run. Default is False. + :type force_pull: bool + :param volumes: List of volumes to mount into the container, e.g. + ``['/host/path:/container/path', '/host/path2:/container/path2']``. + :param options: other flags (list) to provide to the instance start + :type options: list + :param working_dir: Working directory to + set on the container (equivalent to the -w switch the docker client) + :type working_dir: str + """ + + template_fields = ( + "command", + "environment", + ) + template_ext = ( + ".sh", + ".bash", + ) + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + image: str, + command: Union[str, ast.AST], + start_command: Optional[Union[str, List[str]]] = None, + environment: Optional[Dict[str, Any]] = None, + pull_folder: Optional[str] = None, + working_dir: Optional[str] = None, + force_pull: Optional[bool] = False, + volumes: Optional[List[str]] = None, + options: Optional[List[str]] = None, + auto_remove: Optional[bool] = False, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.auto_remove = auto_remove + self.command = command + self.start_command = start_command + self.environment = environment or {} + self.force_pull = force_pull + self.image = image + self.instance = None + self.options = options or [] + self.pull_folder = pull_folder + self.volumes = volumes or [] + self.working_dir = working_dir + self.cli = None + self.container = None + + def execute(self, context) -> None: + + self.log.info("Preparing Singularity container %s", self.image) + self.cli = Client + + if not self.command: + raise AirflowException("You must define a command.") + + # Pull the container if asked, and ensure not a binary file + if self.force_pull and not os.path.exists(self.image): + self.log.info("Pulling container %s", self.image) + image = self.cli.pull( # type: ignore[attr-defined] + self.image, stream=True, pull_folder=self.pull_folder + ) + + # If we need to stream result for the user, returns lines + if isinstance(image, list): + lines = image.pop() + image = image[0] + for line in lines: + self.log.info(line) + + # Update the image to be a filepath on the system + self.image = image + + # Prepare list of binds + for bind in self.volumes: + self.options += ["--bind", bind] + + # Does the 
user want a custom working directory? + if self.working_dir is not None: + self.options += ["--workdir", self.working_dir] + + # Export environment before instance is run + for enkey, envar in self.environment.items(): + self.log.debug("Exporting %s=%s", envar, enkey) + os.putenv(enkey, envar) + os.environ[enkey] = envar + + # Create a container instance + self.log.debug("Options include: %s", self.options) + self.instance = self.cli.instance( # type: ignore[attr-defined] + self.image, options=self.options, args=self.start_command, start=False + ) + + self.instance.start() # type: ignore[attr-defined] + self.log.info(self.instance.cmd) # type: ignore[attr-defined] + self.log.info("Created instance %s from %s", self.instance, self.image) + + self.log.info("Running command %s", self._get_command()) + self.cli.quiet = True # type: ignore[attr-defined] + result = self.cli.execute( # type: ignore[attr-defined] + self.instance, self._get_command(), return_result=True + ) + + # Stop the instance + self.log.info("Stopping instance %s", self.instance) + self.instance.stop() # type: ignore[attr-defined] + + if self.auto_remove is True: + if self.auto_remove and os.path.exists(self.image): + shutil.rmtree(self.image) + + # If the container failed, raise the exception + if result["return_code"] != 0: + message = result["message"] + raise AirflowException(f"Singularity failed: {message}") + + self.log.info("Output from command %s", result["message"]) + + def _get_command(self) -> Optional[Any]: + if self.command is not None and self.command.strip().find("[") == 0: # type: ignore + commands = ast.literal_eval(self.command) + else: + commands = self.command + return commands + + def on_kill(self) -> None: + if self.instance is not None: + self.log.info("Stopping Singularity instance") + self.instance.stop() + + # If an image exists, clean it up + if self.auto_remove is True: + if self.auto_remove and os.path.exists(self.image): + shutil.rmtree(self.image) diff --git a/reference/providers/singularity/provider.yaml b/reference/providers/singularity/provider.yaml new file mode 100644 index 0000000..89d7c14 --- /dev/null +++ b/reference/providers/singularity/provider.yaml @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +package-name: apache-airflow-providers-singularity +name: Singularity +description: | + `Singularity `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Singularity + external-doc-url: https://sylabs.io/guides/latest/user-guide/ + logo: /integration-logos/singularity/Singularity.png + tags: [software] + +operators: + - integration-name: Singularity + python-modules: + - airflow.providers.singularity.operators.singularity diff --git a/reference/providers/slack/CHANGELOG.rst b/reference/providers/slack/CHANGELOG.rst new file mode 100644 index 0000000..08cc5c1 --- /dev/null +++ b/reference/providers/slack/CHANGELOG.rst @@ -0,0 +1,49 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* ``Don't allow SlackHook.call method accept *args (#14289)`` + + +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +We updated the support for ``slack_sdk`` from ``>=2.0.0,<3.0.0`` to ``>=3.0.0,<4.0.0``. In most cases, +this doesn't mean any breaking changes to the DAG files, but if you used this library directly +then you have to make the changes. For details, see +`the Migration Guide `_ +for Python Slack SDK. + +* ``Upgrade slack_sdk to v3 (#13745)`` + + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/slack/__init__.py b/reference/providers/slack/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/slack/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/slack/hooks/__init__.py b/reference/providers/slack/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/slack/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/slack/hooks/slack.py b/reference/providers/slack/hooks/slack.py new file mode 100644 index 0000000..916c48a --- /dev/null +++ b/reference/providers/slack/hooks/slack.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Slack""" +from typing import Any, Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from slack_sdk import WebClient + + +class SlackHook(BaseHook): # noqa + """ + Creates a Slack connection, to be used for calls. Takes both Slack API token directly and + connection that has Slack API token. If both supplied, Slack API token will be used. + Exposes also the rest of slack.WebClient args + Examples: + + .. code-block:: python + + # Create hook + slack_hook = SlackHook(token="xxx") # or slack_hook = SlackHook(slack_conn_id="slack") + + # Call generic API with parameters (errors are handled by hook) + # For more details check https://api.slack.com/methods/chat.postMessage + slack_hook.call("chat.postMessage", json={"channel": "#random", "text": "Hello world!"}) + + # Call method from Slack SDK (you have to handle errors yourself) + # For more details check https://slack.dev/python-slack-sdk/web/index.html#messaging + slack_hook.client.chat_postMessage(channel="#random", text="Hello world!") + + :param token: Slack API token + :type token: str + :param slack_conn_id: connection that has Slack API token in the password field + :type slack_conn_id: str + :param use_session: A boolean specifying if the client should take advantage of + connection pooling. Default is True. + :type use_session: bool + :param base_url: A string representing the Slack API base URL. Default is + ``https://www.slack.com/api/`` + :type base_url: str + :param timeout: The maximum number of seconds the client will wait + to connect and receive a response from Slack. Default is 30 seconds. 
+ :type timeout: int + """ + + def __init__( + self, + token: Optional[str] = None, + slack_conn_id: Optional[str] = None, + **client_args: Any, + ) -> None: + super().__init__() + self.token = self.__get_token(token, slack_conn_id) + self.client = WebClient(self.token, **client_args) + + def __get_token(self, token: Any, slack_conn_id: Any) -> str: + if token is not None: + return token + + if slack_conn_id is not None: + conn = self.get_connection(slack_conn_id) + + if not getattr(conn, "password", None): + raise AirflowException("Missing token(password) in Slack connection") + return conn.password + + raise AirflowException( + "Cannot get token: No valid Slack token nor slack_conn_id supplied." + ) + + def call(self, api_method: str, **kwargs) -> None: + """ + Calls Slack WebClient `WebClient.api_call` with given arguments. + + :param api_method: The target Slack API method. e.g. 'chat.postMessage'. Required. + :type api_method: str + :param http_verb: HTTP Verb. Optional (defaults to 'POST') + :type http_verb: str + :param files: Files to multipart upload. e.g. {imageORfile: file_objectORfile_path} + :type files: dict + :param data: The body to attach to the request. If a dictionary is provided, + form-encoding will take place. Optional. + :type data: dict or aiohttp.FormData + :param params: The URL parameters to append to the URL. Optional. + :type params: dict + :param json: JSON for the body to attach to the request. Optional. + :type json: dict + """ + self.client.api_call(api_method, **kwargs) diff --git a/reference/providers/slack/hooks/slack_webhook.py b/reference/providers/slack/hooks/slack_webhook.py new file mode 100644 index 0000000..03054b1 --- /dev/null +++ b/reference/providers/slack/hooks/slack_webhook.py @@ -0,0 +1,169 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import json +import warnings +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.http.hooks.http import HttpHook + + +class SlackWebhookHook(HttpHook): + """ + This hook allows you to post messages to Slack using incoming webhooks. + Takes both Slack webhook token directly and connection that has Slack webhook token. + If both supplied, http_conn_id will be used as base_url, + and webhook_token will be taken as endpoint, the relative path of the url. + + Each Slack webhook token can be pre-configured to use a specific channel, username and + icon. You can override these defaults in this hook. + + :param http_conn_id: connection that has Slack webhook token in the extra field + :type http_conn_id: str + :param webhook_token: Slack webhook token + :type webhook_token: str + :param message: The message you want to send on Slack + :type message: str + :param attachments: The attachments to send on Slack. 
Should be a list of + dictionaries representing Slack attachments. + :type attachments: list + :param blocks: The blocks to send on Slack. Should be a list of + dictionaries representing Slack blocks. + :type blocks: list + :param channel: The channel the message should be posted to + :type channel: str + :param username: The username to post to slack with + :type username: str + :param icon_emoji: The emoji to use as icon for the user posting to Slack + :type icon_emoji: str + :param icon_url: The icon image URL string to use in place of the default icon. + :type icon_url: str + :param link_names: Whether or not to find and link channel and usernames in your + message + :type link_names: bool + :param proxy: Proxy to use to make the Slack webhook call + :type proxy: str + """ + + # pylint: disable=too-many-arguments + def __init__( + self, + http_conn_id=None, + webhook_token=None, + message="", + attachments=None, + blocks=None, + channel=None, + username=None, + icon_emoji=None, + icon_url=None, + link_names=False, + proxy=None, + *args, + **kwargs, + ): + super().__init__(http_conn_id=http_conn_id, *args, **kwargs) + self.webhook_token = self._get_token(webhook_token, http_conn_id) + self.message = message + self.attachments = attachments + self.blocks = blocks + self.channel = channel + self.username = username + self.icon_emoji = icon_emoji + self.icon_url = icon_url + self.link_names = link_names + self.proxy = proxy + + def _get_token(self, token: str, http_conn_id: Optional[str]) -> str: + """ + Given either a manually set token or a conn_id, return the webhook_token to use. + + :param token: The manually provided token + :type token: str + :param http_conn_id: The conn_id provided + :type http_conn_id: str + :return: webhook_token to use + :rtype: str + """ + if token: + return token + elif http_conn_id: + conn = self.get_connection(http_conn_id) + + if getattr(conn, "password", None): + return conn.password + else: + extra = conn.extra_dejson + web_token = extra.get("webhook_token", "") + + if web_token: + warnings.warn( + "'webhook_token' in 'extra' is deprecated. Please use 'password' field", + DeprecationWarning, + stacklevel=2, + ) + + return web_token + else: + raise AirflowException( + "Cannot get token: No valid Slack webhook token nor conn_id supplied" + ) + + def _build_slack_message(self) -> str: + """ + Construct the Slack message. All relevant parameters are combined here to a valid + Slack json message. 
+ + :return: Slack message to send + :rtype: str + """ + cmd = {} + + if self.channel: + cmd["channel"] = self.channel + if self.username: + cmd["username"] = self.username + if self.icon_emoji: + cmd["icon_emoji"] = self.icon_emoji + if self.icon_url: + cmd["icon_url"] = self.icon_url + if self.link_names: + cmd["link_names"] = 1 + if self.attachments: + cmd["attachments"] = self.attachments + if self.blocks: + cmd["blocks"] = self.blocks + + cmd["text"] = self.message + return json.dumps(cmd) + + def execute(self) -> None: + """Remote Popen (actually execute the slack webhook call)""" + proxies = {} + if self.proxy: + # we only need https proxy for Slack, as the endpoint is https + proxies = {"https": self.proxy} + + slack_message = self._build_slack_message() + self.run( + endpoint=self.webhook_token, + data=slack_message, + headers={"Content-type": "application/json"}, + extra_options={"proxies": proxies, "check_response": True}, + ) diff --git a/reference/providers/slack/operators/__init__.py b/reference/providers/slack/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/slack/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/slack/operators/slack.py b/reference/providers/slack/operators/slack.py new file mode 100644 index 0000000..e3d82e4 --- /dev/null +++ b/reference/providers/slack/operators/slack.py @@ -0,0 +1,218 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from typing import Any, Dict, List, Optional + +from airflow.models import BaseOperator +from airflow.providers.slack.hooks.slack import SlackHook +from airflow.utils.decorators import apply_defaults + + +class SlackAPIOperator(BaseOperator): + """ + Base Slack Operator + The SlackAPIPostOperator is derived from this operator. + In the future additional Slack API Operators will be derived from this class as well + Only one of `slack_conn_id` and `token` is required. 
+ + :param slack_conn_id: Slack connection ID whose password is the Slack API token. Optional + :type slack_conn_id: str + :param token: Slack API token (https://api.slack.com/web). Optional + :type token: str + :param method: The Slack API Method to Call (https://api.slack.com/methods). Optional + :type method: str + :param api_params: API Method call parameters (https://api.slack.com/methods). Optional + :type api_params: dict + :param client_args: Slack Hook parameters. Optional. Check airflow.providers.slack.hooks.SlackHook + :type client_args: dict + """ + + @apply_defaults + def __init__( + self, + *, + slack_conn_id: Optional[str] = None, + token: Optional[str] = None, + method: Optional[str] = None, + api_params: Optional[Dict] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.token = token # type: Optional[str] + self.slack_conn_id = slack_conn_id # type: Optional[str] + + self.method = method + self.api_params = api_params + + def construct_api_call_params(self) -> Any: + """ + Used by the execute function. Allows templating on the source fields + of the api_call_params dict before construction + + Override in child classes. + Each SlackAPIOperator child class is responsible for + having a construct_api_call_params function + which sets self.api_call_params with a dict of + API call parameters (https://api.slack.com/methods) + """ + raise NotImplementedError( + "SlackAPIOperator should not be used directly. Choose one of the subclasses instead" + ) + + def execute(self, **kwargs): # noqa: D403 + """ + SlackAPIOperator calls will not fail even if the call is unsuccessful. + It should not prevent a DAG from completing successfully + """ + if not self.api_params: + self.construct_api_call_params() + slack = SlackHook(token=self.token, slack_conn_id=self.slack_conn_id) + slack.call(self.method, json=self.api_params) + + +class SlackAPIPostOperator(SlackAPIOperator): + """ + Posts messages to a slack channel + Examples: + + .. code-block:: python + + slack = SlackAPIPostOperator( + task_id="post_hello", + dag=dag, + token="XXX", + text="hello there!", + channel="#random", + ) + + :param channel: channel in which to post message on slack name (#general) or + ID (C12318391). (templated) + :type channel: str + :param username: Username that airflow will be posting to Slack as. (templated) + :type username: str + :param text: message to send to slack. (templated) + :type text: str + :param icon_url: url to icon used for this message + :type icon_url: str + :param attachments: extra formatting details. (templated) + - see https://api.slack.com/docs/attachments. + :type attachments: list of hashes + :param blocks: extra block layouts. (templated) + - see https://api.slack.com/reference/block-kit/blocks.
+ :type blocks: list of hashes + """ + + template_fields = ("username", "text", "attachments", "blocks", "channel") + ui_color = "#FFBA40" + + @apply_defaults + def __init__( + self, + channel: str = "#general", + username: str = "Airflow", + text: str = "No message has been set.\n" + "Here is a cat video instead\n" + "https://www.youtube.com/watch?v=J---aiyznGQ", + icon_url: str = "https://raw.githubusercontent.com/apache/" + "airflow/master/airflow/www/static/pin_100.png", + attachments: Optional[List] = None, + blocks: Optional[List] = None, + **kwargs, + ) -> None: + self.method = "chat.postMessage" + self.channel = channel + self.username = username + self.text = text + self.icon_url = icon_url + self.attachments = attachments or [] + self.blocks = blocks or [] + super().__init__(method=self.method, **kwargs) + + def construct_api_call_params(self) -> Any: + self.api_params = { + "channel": self.channel, + "username": self.username, + "text": self.text, + "icon_url": self.icon_url, + "attachments": json.dumps(self.attachments), + "blocks": json.dumps(self.blocks), + } + + +class SlackAPIFileOperator(SlackAPIOperator): + """ + Send a file to a slack channel + Examples: + + .. code-block:: python + + slack = SlackAPIFileOperator( + task_id="slack_file_upload", + dag=dag, + slack_conn_id="slack", + channel="#general", + initial_comment="Hello World!", + filename="hello_world.csv", + filetype="csv", + content="hello,world,csv,file", + ) + + :param channel: channel in which to sent file on slack name (templated) + :type channel: str + :param initial_comment: message to send to slack. (templated) + :type initial_comment: str + :param filename: name of the file (templated) + :type filename: str + :param filetype: slack filetype. (templated) + - see https://api.slack.com/types/file + :type filetype: str + :param content: file content. (templated) + :type content: str + """ + + template_fields = ("channel", "initial_comment", "filename", "filetype", "content") + ui_color = "#44BEDF" + + @apply_defaults + def __init__( + self, + channel: str = "#general", + initial_comment: str = "No message has been set!", + filename: str = "default_name.csv", + filetype: str = "csv", + content: str = "default,content,csv,file", + **kwargs, + ) -> None: + self.method = "files.upload" + self.channel = channel + self.initial_comment = initial_comment + self.filename = filename + self.filetype = filetype + self.content = content + super().__init__(method=self.method, **kwargs) + + def construct_api_call_params(self) -> Any: + self.api_params = { + "channels": self.channel, + "content": self.content, + "filename": self.filename, + "filetype": self.filetype, + "initial_comment": self.initial_comment, + } diff --git a/reference/providers/slack/operators/slack_webhook.py b/reference/providers/slack/operators/slack_webhook.py new file mode 100644 index 0000000..1b3b7c1 --- /dev/null +++ b/reference/providers/slack/operators/slack_webhook.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Any, Dict, Optional + +from airflow.providers.http.operators.http import SimpleHttpOperator +from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook +from airflow.utils.decorators import apply_defaults + + +class SlackWebhookOperator(SimpleHttpOperator): + """ + This operator allows you to post messages to Slack using incoming webhooks. + Takes both Slack webhook token directly and connection that has Slack webhook token. + If both supplied, http_conn_id will be used as base_url, + and webhook_token will be taken as endpoint, the relative path of the url. + + Each Slack webhook token can be pre-configured to use a specific channel, username and + icon. You can override these defaults in this hook. + + :param http_conn_id: connection that has Slack webhook token in the extra field + :type http_conn_id: str + :param webhook_token: Slack webhook token + :type webhook_token: str + :param message: The message you want to send on Slack + :type message: str + :param attachments: The attachments to send on Slack. Should be a list of + dictionaries representing Slack attachments. + :type attachments: list + :param blocks: The blocks to send on Slack. Should be a list of + dictionaries representing Slack blocks. + :type blocks: list + :param channel: The channel the message should be posted to + :type channel: str + :param username: The username to post to slack with + :type username: str + :param icon_emoji: The emoji to use as icon for the user posting to Slack + :type icon_emoji: str + :param icon_url: The icon image URL string to use in place of the default icon. 
+ :type icon_url: str + :param link_names: Whether or not to find and link channel and usernames in your + message + :type link_names: bool + :param proxy: Proxy to use to make the Slack webhook call + :type proxy: str + """ + + template_fields = [ + "webhook_token", + "message", + "attachments", + "blocks", + "channel", + "username", + "proxy", + ] + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__( + self, + *, + http_conn_id: str, + webhook_token: Optional[str] = None, + message: str = "", + attachments: Optional[list] = None, + blocks: Optional[list] = None, + channel: Optional[str] = None, + username: Optional[str] = None, + icon_emoji: Optional[str] = None, + icon_url: Optional[str] = None, + link_names: bool = False, + proxy: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(endpoint=webhook_token, **kwargs) + self.http_conn_id = http_conn_id + self.webhook_token = webhook_token + self.message = message + self.attachments = attachments + self.blocks = blocks + self.channel = channel + self.username = username + self.icon_emoji = icon_emoji + self.icon_url = icon_url + self.link_names = link_names + self.proxy = proxy + self.hook: Optional[SlackWebhookHook] = None + + def execute(self, context: Dict[str, Any]) -> None: + """Call the SlackWebhookHook to post the provided Slack message""" + self.hook = SlackWebhookHook( + self.http_conn_id, + self.webhook_token, + self.message, + self.attachments, + self.blocks, + self.channel, + self.username, + self.icon_emoji, + self.icon_url, + self.link_names, + self.proxy, + ) + self.hook.execute() diff --git a/reference/providers/slack/provider.yaml b/reference/providers/slack/provider.yaml new file mode 100644 index 0000000..f6cb003 --- /dev/null +++ b/reference/providers/slack/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-slack +name: Slack +description: | + `Slack `__ + +versions: + - 3.0.0 + - 2.0.0 + - 1.0.0 + +integrations: + - integration-name: Slack + external-doc-url: https://slack.com/ + logo: /integration-logos/slack/Slack.png + tags: [service] + +operators: + - integration-name: Slack + python-modules: + - airflow.providers.slack.operators.slack + - airflow.providers.slack.operators.slack_webhook + +hooks: + - integration-name: Slack + python-modules: + - airflow.providers.slack.hooks.slack + - airflow.providers.slack.hooks.slack_webhook diff --git a/reference/providers/snowflake/CHANGELOG.rst b/reference/providers/snowflake/CHANGELOG.rst new file mode 100644 index 0000000..359f614 --- /dev/null +++ b/reference/providers/snowflake/CHANGELOG.rst @@ -0,0 +1,50 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.2.0 +..... + +Features +~~~~~~~~ + +* ``Bumped snowflake-connector-python library to >=2.4.1 version and get rid of pytz library pinning`` + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` +* ``Prepare to release the next wave of providers: (#14487)`` + +1.1.0 +..... + +Updated documentation and readme files. + +Features +~~~~~~~~ + +* ``Fix S3ToSnowflakeOperator to support uploading all files in the specified stage (#12505)`` +* ``Add connection arguments in S3ToSnowflakeOperator (#12564)`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/snowflake/__init__.py b/reference/providers/snowflake/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/snowflake/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/snowflake/example_dags/__init__.py b/reference/providers/snowflake/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/snowflake/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
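The example DAG in the next file chains several of the Snowflake operators defined in this provider. As a quick orientation, here is a minimal, self-contained sketch of the basic SnowflakeOperator pattern that DAG builds on; the connection id, table name and SQL below are illustrative placeholders, not values taken from this repository.

from airflow import DAG
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
from airflow.utils.dates import days_ago

SNOWFLAKE_CONN_ID = "snowflake_default"  # assumed connection id

with DAG(
    dag_id="snowflake_operator_sketch",
    start_date=days_ago(1),
    schedule_interval=None,
    tags=["example"],
) as dag:
    create_table = SnowflakeOperator(
        task_id="create_table",
        snowflake_conn_id=SNOWFLAKE_CONN_ID,
        sql="CREATE TABLE IF NOT EXISTS sample_table (id INT, name VARCHAR(250))",
    )
    insert_rows = SnowflakeOperator(
        task_id="insert_rows",
        snowflake_conn_id=SNOWFLAKE_CONN_ID,
        # Parameters are bound by the Snowflake connector (pyformat style).
        sql="INSERT INTO sample_table VALUES (%(id)s, %(name)s)",
        parameters={"id": 1, "name": "airflow"},
    )
    create_table >> insert_rows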
diff --git a/reference/providers/snowflake/example_dags/example_snowflake.py b/reference/providers/snowflake/example_dags/example_snowflake.py new file mode 100644 index 0000000..90f96f0 --- /dev/null +++ b/reference/providers/snowflake/example_dags/example_snowflake.py @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example use of Snowflake related operators. +""" +from airflow import DAG +from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator +from airflow.providers.snowflake.transfers.s3_to_snowflake import S3ToSnowflakeOperator +from airflow.providers.snowflake.transfers.snowflake_to_slack import ( + SnowflakeToSlackOperator, +) +from airflow.utils.dates import days_ago + +SNOWFLAKE_CONN_ID = "my_snowflake_conn" +SLACK_CONN_ID = "my_slack_conn" +# TODO: should be able to rely on connection's schema, but currently param required by S3ToSnowflakeTransfer +SNOWFLAKE_SCHEMA = "schema_name" +SNOWFLAKE_STAGE = "stage_name" +SNOWFLAKE_WAREHOUSE = "warehouse_name" +SNOWFLAKE_DATABASE = "database_name" +SNOWFLAKE_ROLE = "role_name" +SNOWFLAKE_SAMPLE_TABLE = "sample_table" +S3_FILE_PATH = "> [ + snowflake_op_with_params, + snowflake_op_sql_list, + snowflake_op_template_file, + copy_into_table, +] >> slack_report diff --git a/reference/providers/snowflake/hooks/__init__.py b/reference/providers/snowflake/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/snowflake/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/snowflake/hooks/snowflake.py b/reference/providers/snowflake/hooks/snowflake.py new file mode 100644 index 0000000..b97041f --- /dev/null +++ b/reference/providers/snowflake/hooks/snowflake.py @@ -0,0 +1,254 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, Optional, Tuple + +from airflow.hooks.dbapi import DbApiHook +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + +# pylint: disable=no-name-in-module +from snowflake import connector +from snowflake.connector import SnowflakeConnection + + +class SnowflakeHook(DbApiHook): + """ + A client to interact with Snowflake. + + This hook requires the snowflake_conn_id connection. The snowflake host, login, + and, password field must be setup in the connection. Other inputs can be defined + in the connection or hook instantiation. If used with the S3ToSnowflakeOperator + add 'aws_access_key_id' and 'aws_secret_access_key' to extra field in the connection. + + :param account: snowflake account name + :type account: Optional[str] + :param authenticator: authenticator for Snowflake. + 'snowflake' (default) to use the internal Snowflake authenticator + 'externalbrowser' to authenticate using your web browser and + Okta, ADFS or any other SAML 2.0-compliant identify provider + (IdP) that has been defined for your account + 'https://.okta.com' to authenticate + through native Okta. + :type authenticator: Optional[str] + :param warehouse: name of snowflake warehouse + :type warehouse: Optional[str] + :param database: name of snowflake database + :type database: Optional[str] + :param region: name of snowflake region + :type region: Optional[str] + :param role: name of snowflake role + :type role: Optional[str] + :param schema: name of snowflake schema + :type schema: Optional[str] + :param session_parameters: You can set session-level parameters at + the time you connect to Snowflake + :type session_parameters: Optional[dict] + + .. note:: + get_sqlalchemy_engine() depends on snowflake-sqlalchemy + + .. 
seealso:: + For more information on how to use this Snowflake connection, take a look at the guide: + :ref:`howto/operator:SnowflakeOperator` + """ + + conn_name_attr = "snowflake_conn_id" + default_conn_name = "snowflake_default" + conn_type = "snowflake" + hook_name = "Snowflake" + supports_autocommit = True + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import ( + BS3PasswordFieldWidget, + BS3TextFieldWidget, + ) + from flask_babel import lazy_gettext + from wtforms import PasswordField, StringField + + return { + "extra__snowflake__account": StringField( + lazy_gettext("Account"), widget=BS3TextFieldWidget() + ), + "extra__snowflake__warehouse": StringField( + lazy_gettext("Warehouse"), widget=BS3TextFieldWidget() + ), + "extra__snowflake__database": StringField( + lazy_gettext("Database"), widget=BS3TextFieldWidget() + ), + "extra__snowflake__region": StringField( + lazy_gettext("Region"), widget=BS3TextFieldWidget() + ), + "extra__snowflake__aws_access_key_id": StringField( + lazy_gettext("AWS Access Key"), widget=BS3TextFieldWidget() + ), + "extra__snowflake__aws_secret_access_key": PasswordField( + lazy_gettext("AWS Secret Key"), widget=BS3PasswordFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + import json + + return { + "hidden_fields": ["port"], + "relabeling": {}, + "placeholders": { + "extra": json.dumps( + { + "role": "snowflake role", + "authenticator": "snowflake oauth", + "private_key_file": "private key", + "session_parameters": "session parameters", + }, + indent=1, + ), + "host": "snowflake hostname", + "schema": "snowflake schema", + "login": "snowflake username", + "password": "snowflake password", + "extra__snowflake__account": "snowflake account name", + "extra__snowflake__warehouse": "snowflake warehouse name", + "extra__snowflake__database": "snowflake db name", + "extra__snowflake__region": "snowflake hosted region", + "extra__snowflake__aws_access_key_id": "aws access key id (S3ToSnowflakeOperator)", + "extra__snowflake__aws_secret_access_key": "aws secret access key (S3ToSnowflakeOperator)", + }, + } + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.account = kwargs.pop("account", None) + self.warehouse = kwargs.pop("warehouse", None) + self.database = kwargs.pop("database", None) + self.region = kwargs.pop("region", None) + self.role = kwargs.pop("role", None) + self.schema = kwargs.pop("schema", None) + self.authenticator = kwargs.pop("authenticator", None) + self.session_parameters = kwargs.pop("session_parameters", None) + + def _get_conn_params(self) -> Dict[str, Optional[str]]: + """ + One method to fetch connection params as a dict + used in get_uri() and get_connection() + """ + conn = self.get_connection( + self.snowflake_conn_id # type: ignore[attr-defined] # pylint: disable=no-member + ) + account = conn.extra_dejson.get( + "extra__snowflake__account", "" + ) or conn.extra_dejson.get("account", "") + warehouse = conn.extra_dejson.get( + "extra__snowflake__warehouse", "" + ) or conn.extra_dejson.get("warehouse", "") + database = conn.extra_dejson.get( + "extra__snowflake__database", "" + ) or conn.extra_dejson.get("database", "") + region = conn.extra_dejson.get( + "extra__snowflake__region", "" + ) or conn.extra_dejson.get("region", "") + role = conn.extra_dejson.get("role", "") + schema = conn.schema or "" + 
authenticator = conn.extra_dejson.get("authenticator", "snowflake") + session_parameters = conn.extra_dejson.get("session_parameters") + + conn_config = { + "user": conn.login, + "password": conn.password or "", + "schema": self.schema or schema, + "database": self.database or database, + "account": self.account or account, + "warehouse": self.warehouse or warehouse, + "region": self.region or region, + "role": self.role or role, + "authenticator": self.authenticator or authenticator, + "session_parameters": self.session_parameters or session_parameters, + } + + # If private_key_file is specified in the extra json, load the contents of the file as a private + # key and specify that in the connection configuration. The connection password then becomes the + # passphrase for the private key. If your private key file is not encrypted (not recommended), then + # leave the password empty. + + private_key_file = conn.extra_dejson.get("private_key_file") + if private_key_file: + with open(private_key_file, "rb") as key: + passphrase = None + if conn.password: + passphrase = conn.password.strip().encode() + + p_key = serialization.load_pem_private_key( + key.read(), password=passphrase, backend=default_backend() + ) + + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + conn_config["private_key"] = pkb + conn_config.pop("password", None) + + return conn_config + + def get_uri(self) -> str: + """Override DbApiHook get_uri method for get_sqlalchemy_engine()""" + conn_config = self._get_conn_params() + uri = ( + "snowflake://{user}:{password}@{account}/{database}/{schema}" + "?warehouse={warehouse}&role={role}&authenticator={authenticator}" + ) + return uri.format(**conn_config) + + def get_conn(self) -> SnowflakeConnection: + """Returns a snowflake.connection object""" + conn_config = self._get_conn_params() + conn = connector.connect(**conn_config) + return conn + + def _get_aws_credentials(self) -> Tuple[Optional[Any], Optional[Any]]: + """ + Returns aws_access_key_id, aws_secret_access_key + from extra + + intended to be used by external import and export statements + """ + if self.snowflake_conn_id: # type: ignore[attr-defined] # pylint: disable=no-member + connection_object = self.get_connection( + self.snowflake_conn_id # type: ignore[attr-defined] # pylint: disable=no-member + ) + if "aws_secret_access_key" in connection_object.extra_dejson: + aws_access_key_id = connection_object.extra_dejson.get( + "aws_access_key_id" + ) or connection_object.extra_dejson.get("aws_access_key_id") + aws_secret_access_key = connection_object.extra_dejson.get( + "aws_secret_access_key" + ) or connection_object.extra_dejson.get("aws_secret_access_key") + return aws_access_key_id, aws_secret_access_key + + def set_autocommit(self, conn, autocommit: Any) -> None: + conn.autocommit(autocommit) + conn.autocommit_mode = autocommit + + def get_autocommit(self, conn): + return getattr(conn, "autocommit_mode", False) diff --git a/reference/providers/snowflake/operators/__init__.py b/reference/providers/snowflake/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/snowflake/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/snowflake/operators/snowflake.py b/reference/providers/snowflake/operators/snowflake.py new file mode 100644 index 0000000..0011e8a --- /dev/null +++ b/reference/providers/snowflake/operators/snowflake.py @@ -0,0 +1,121 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Optional + +from airflow.models import BaseOperator +from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook +from airflow.utils.decorators import apply_defaults + + +class SnowflakeOperator(BaseOperator): + """ + Executes SQL code in a Snowflake database + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SnowflakeOperator` + + :param snowflake_conn_id: reference to specific snowflake connection id + :type snowflake_conn_id: str + :param sql: the sql code to be executed. (templated) + :type sql: Can receive a str representing a sql statement, + a list of str (sql statements), or reference to a template file. + Template reference are recognized by str ending in '.sql' + :param autocommit: if True, each command is automatically committed. + (default value: True) + :type autocommit: bool + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: dict or iterable + :param warehouse: name of warehouse (will overwrite any warehouse + defined in the connection's extra JSON) + :type warehouse: str + :param database: name of database (will overwrite database defined + in connection) + :type database: str + :param schema: name of schema (will overwrite schema defined in + connection) + :type schema: str + :param role: name of role (will overwrite any role defined in + connection's extra JSON) + :type role: str + :param authenticator: authenticator for Snowflake. + 'snowflake' (default) to use the internal Snowflake authenticator + 'externalbrowser' to authenticate using your web browser and + Okta, ADFS or any other SAML 2.0-compliant identify provider + (IdP) that has been defined for your account + 'https://.okta.com' to authenticate + through native Okta. 
+ :type authenticator: str + :param session_parameters: You can set session-level parameters at + the time you connect to Snowflake + :type session_parameters: dict + """ + + template_fields = ("sql",) + template_ext = (".sql",) + ui_color = "#ededed" + + @apply_defaults + def __init__( + self, + *, + sql: Any, + snowflake_conn_id: str = "snowflake_default", + parameters: Optional[dict] = None, + autocommit: bool = True, + warehouse: Optional[str] = None, + database: Optional[str] = None, + role: Optional[str] = None, + schema: Optional[str] = None, + authenticator: Optional[str] = None, + session_parameters: Optional[dict] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.snowflake_conn_id = snowflake_conn_id + self.sql = sql + self.autocommit = autocommit + self.parameters = parameters + self.warehouse = warehouse + self.database = database + self.role = role + self.schema = schema + self.authenticator = authenticator + self.session_parameters = session_parameters + + def get_hook(self) -> SnowflakeHook: + """ + Create and return SnowflakeHook. + :return: a SnowflakeHook instance. + :rtype: SnowflakeHook + """ + return SnowflakeHook( + snowflake_conn_id=self.snowflake_conn_id, + warehouse=self.warehouse, + database=self.database, + role=self.role, + schema=self.schema, + authenticator=self.authenticator, + session_parameters=self.session_parameters, + ) + + def execute(self, context: Any) -> None: + """Run query on snowflake""" + self.log.info("Executing: %s", self.sql) + hook = self.get_hook() + hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) diff --git a/reference/providers/snowflake/provider.yaml b/reference/providers/snowflake/provider.yaml new file mode 100644 index 0000000..0eb230f --- /dev/null +++ b/reference/providers/snowflake/provider.yaml @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +package-name: apache-airflow-providers-snowflake +name: Snowflake +description: | + `Snowflake `__ + +versions: + - 1.1.1 + - 1.1.0 + - 1.0.0 + +integrations: + - integration-name: Snowflake + external-doc-url: https://snowflake.com/ + how-to-guide: + - /docs/apache-airflow-providers-snowflake/operators/snowflake.rst + logo: /integration-logos/snowflake/Snowflake.png + tags: [service] + +operators: + - integration-name: Snowflake + python-modules: + - airflow.providers.snowflake.operators.snowflake + +hooks: + - integration-name: Snowflake + python-modules: + - airflow.providers.snowflake.hooks.snowflake + +transfers: + - source-integration-name: Amazon Simple Storage Service (S3) + target-integration-name: Snowflake + python-module: airflow.providers.snowflake.transfers.s3_to_snowflake + how-to-guide: /docs/apache-airflow-providers-snowflake/operators/s3_to_snowflake.rst + - source-integration-name: Snowflake + target-integration-name: Slack + python-module: airflow.providers.snowflake.transfers.snowflake_to_slack + how-to-guide: /docs/apache-airflow-providers-snowflake/operators/snowflake_to_slack.rst + +hook-class-names: + - airflow.providers.snowflake.hooks.snowflake.SnowflakeHook diff --git a/reference/providers/snowflake/transfers/__init__.py b/reference/providers/snowflake/transfers/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/snowflake/transfers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/snowflake/transfers/s3_to_snowflake.py b/reference/providers/snowflake/transfers/s3_to_snowflake.py new file mode 100644 index 0000000..165a652 --- /dev/null +++ b/reference/providers/snowflake/transfers/s3_to_snowflake.py @@ -0,0 +1,154 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains AWS S3 to Snowflake operator.""" +from typing import Any, Optional + +from airflow.models import BaseOperator +from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook +from airflow.utils.decorators import apply_defaults + + +class S3ToSnowflakeOperator(BaseOperator): + """ + Executes an COPY command to load files from s3 to Snowflake + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:S3ToSnowflakeOperator` + + :param s3_keys: reference to a list of S3 keys + :type s3_keys: list + :param table: reference to a specific table in snowflake database + :type table: str + :param schema: name of schema (will overwrite schema defined in + connection) + :type schema: str + :param stage: reference to a specific snowflake stage. If the stage's schema is not the same as the + table one, it must be specified + :type stage: str + :param prefix: cloud storage location specified to limit the set of files to load + :type prefix: str + :param file_format: reference to a specific file format + :type file_format: str + :param warehouse: name of warehouse (will overwrite any warehouse + defined in the connection's extra JSON) + :type warehouse: str + :param database: reference to a specific database in Snowflake connection + :type database: str + :param columns_array: reference to a specific columns array in snowflake database + :type columns_array: list + :param snowflake_conn_id: reference to a specific snowflake connection + :type snowflake_conn_id: str + :param role: name of role (will overwrite any role defined in + connection's extra JSON) + :type role: str + :param authenticator: authenticator for Snowflake. + 'snowflake' (default) to use the internal Snowflake authenticator + 'externalbrowser' to authenticate using your web browser and + Okta, ADFS or any other SAML 2.0-compliant identify provider + (IdP) that has been defined for your account + 'https://.okta.com' to authenticate + through native Okta. 
+ :type authenticator: str + :param session_parameters: You can set session-level parameters at + the time you connect to Snowflake + :type session_parameters: dict + """ + + @apply_defaults + def __init__( + self, + *, + s3_keys: Optional[list] = None, + table: str, + stage: str, + prefix: Optional[str] = None, + file_format: str, + schema: str, # TODO: shouldn't be required, rely on session/user defaults + columns_array: Optional[list] = None, + warehouse: Optional[str] = None, + database: Optional[str] = None, + autocommit: bool = True, + snowflake_conn_id: str = "snowflake_default", + role: Optional[str] = None, + authenticator: Optional[str] = None, + session_parameters: Optional[dict] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.s3_keys = s3_keys + self.table = table + self.warehouse = warehouse + self.database = database + self.stage = stage + self.prefix = prefix + self.file_format = file_format + self.schema = schema + self.columns_array = columns_array + self.autocommit = autocommit + self.snowflake_conn_id = snowflake_conn_id + self.role = role + self.authenticator = authenticator + self.session_parameters = session_parameters + + def execute(self, context: Any) -> None: + snowflake_hook = SnowflakeHook( + snowflake_conn_id=self.snowflake_conn_id, + warehouse=self.warehouse, + database=self.database, + role=self.role, + schema=self.schema, + authenticator=self.authenticator, + session_parameters=self.session_parameters, + ) + + files = "" + if self.s3_keys: + files = "files=({})".format(", ".join(f"'{key}'" for key in self.s3_keys)) + + # we can extend this based on stage + base_sql = """ + FROM @{stage}/{prefix} + {files} + file_format={file_format} + """.format( + stage=self.stage, + prefix=(self.prefix if self.prefix else ""), + files=files, + file_format=self.file_format, + ) + + if self.columns_array: + copy_query = """ + COPY INTO {schema}.{table}({columns}) {base_sql} + """.format( + schema=self.schema, + table=self.table, + columns=",".join(self.columns_array), + base_sql=base_sql, + ) + else: + copy_query = f""" + COPY INTO {self.schema}.{self.table} {base_sql} + """ + copy_query = "\n".join(line.strip() for line in copy_query.splitlines()) + + self.log.info("Executing COPY command...") + snowflake_hook.run(copy_query, self.autocommit) + self.log.info("COPY command completed") diff --git a/reference/providers/snowflake/transfers/snowflake_to_slack.py b/reference/providers/snowflake/transfers/snowflake_to_slack.py new file mode 100644 index 0000000..c89e376 --- /dev/null +++ b/reference/providers/snowflake/transfers/snowflake_to_slack.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Iterable, Mapping, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook +from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook +from airflow.utils.decorators import apply_defaults +from pandas import DataFrame +from tabulate import tabulate + + +class SnowflakeToSlackOperator(BaseOperator): + """ + Executes an SQL statement in Snowflake and sends the results to Slack. The results of the query are + rendered into the 'slack_message' parameter as a Pandas dataframe using a JINJA variable called '{{ + results_df }}'. The 'results_df' variable name can be changed by specifying a different + 'results_df_name' parameter. The Tabulate library is added to the JINJA environment as a filter to + allow the dataframe to be rendered nicely. For example, set 'slack_message' to {{ results_df | + tabulate(tablefmt="pretty", headers="keys") }} to send the results to Slack as an ascii rendered table. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SnowflakeToSlackOperator` + + :param sql: The SQL statement to execute on Snowflake (templated) + :type sql: str + :param slack_message: The templated Slack message to send with the data returned from Snowflake. + You can use the default JINJA variable {{ results_df }} to access the pandas dataframe containing the + SQL results + :type slack_message: str + :param snowflake_conn_id: The Snowflake connection id + :type snowflake_conn_id: str + :param slack_conn_id: The connection id for Slack + :type slack_conn_id: str + :param results_df_name: The name of the JINJA template's dataframe variable, default is 'results_df' + :type results_df_name: str + :param parameters: The parameters to pass to the SQL query + :type parameters: Optional[Union[Iterable, Mapping]] + :param warehouse: The Snowflake virtual warehouse to use to run the SQL query + :type warehouse: Optional[str] + :param database: The Snowflake database to use for the SQL query + :type database: Optional[str] + :param schema: The schema to run the SQL against in Snowflake + :type schema: Optional[str] + :param role: The role to use when connecting to Snowflake + :type role: Optional[str] + :param slack_token: The token to use to authenticate to Slack. 
If this is not provided, the + 'webhook_token' attribute needs to be specified in the 'Extra' JSON field against the slack_conn_id + :type slack_token: Optional[str] + """ + + template_fields = ["sql", "slack_message"] + template_ext = [".sql", ".jinja", ".j2"] + times_rendered = 0 + + @apply_defaults + def __init__( # pylint: disable=too-many-arguments + self, + *, + sql: str, + slack_message: str, + snowflake_conn_id: str = "snowflake_default", + slack_conn_id: str = "slack_default", + results_df_name: str = "results_df", + parameters: Optional[Union[Iterable, Mapping]] = None, + warehouse: Optional[str] = None, + database: Optional[str] = None, + schema: Optional[str] = None, + role: Optional[str] = None, + slack_token: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.snowflake_conn_id = snowflake_conn_id + self.sql = sql + self.parameters = parameters + self.warehouse = warehouse + self.database = database + self.schema = schema + self.role = role + self.slack_conn_id = slack_conn_id + self.slack_token = slack_token + self.slack_message = slack_message + self.results_df_name = results_df_name + + def _get_query_results(self) -> DataFrame: + snowflake_hook = self._get_snowflake_hook() + + self.log.info("Running SQL query: %s", self.sql) + df = snowflake_hook.get_pandas_df(self.sql, parameters=self.parameters) + return df + + def _render_and_send_slack_message(self, context, df) -> None: + # Put the dataframe into the context and render the JINJA template fields + context[self.results_df_name] = df + self.render_template_fields(context) + + slack_hook = self._get_slack_hook() + self.log.info("Sending slack message: %s", self.slack_message) + slack_hook.execute() + + def _get_snowflake_hook(self) -> SnowflakeHook: + return SnowflakeHook( + snowflake_conn_id=self.snowflake_conn_id, + warehouse=self.warehouse, + database=self.database, + role=self.role, + schema=self.schema, + ) + + def _get_slack_hook(self) -> SlackWebhookHook: + return SlackWebhookHook( + http_conn_id=self.slack_conn_id, + message=self.slack_message, + webhook_token=self.slack_token, + ) + + def render_template_fields(self, context, jinja_env=None) -> None: + # If this is the first render of the template fields, exclude slack_message from rendering since + # the snowflake results haven't been retrieved yet. 
+ if self.times_rendered == 0: + fields_to_render: Iterable[str] = filter( + lambda x: x != "slack_message", self.template_fields + ) + else: + fields_to_render = self.template_fields + + if not jinja_env: + jinja_env = self.get_template_env() + + # Add the tabulate library into the JINJA environment + jinja_env.filters["tabulate"] = tabulate + + self._do_render_template_fields( + self, fields_to_render, context, jinja_env, set() + ) + self.times_rendered += 1 + + def execute(self, context) -> None: + if not isinstance(self.sql, str): + raise AirflowException("Expected 'sql' parameter should be a string.") + if self.sql is None or self.sql.strip() == "": + raise AirflowException("Expected 'sql' parameter is missing.") + if self.slack_message is None or self.slack_message.strip() == "": + raise AirflowException("Expected 'slack_message' parameter is missing.") + + df = self._get_query_results() + self._render_and_send_slack_message(context, df) + + self.log.debug("Finished sending Snowflake data to Slack") diff --git a/reference/providers/sqlite/CHANGELOG.rst b/reference/providers/sqlite/CHANGELOG.rst new file mode 100644 index 0000000..da7155a --- /dev/null +++ b/reference/providers/sqlite/CHANGELOG.rst @@ -0,0 +1,41 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Corrections in docs and tools after releasing provider RCs (#14082)`` + + +1.0.1 +..... + +Updated documentation and readme files. + +* ``Add example DAG & how-to guide for sqlite (#13196)`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/sqlite/__init__.py b/reference/providers/sqlite/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/sqlite/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
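A hedged usage sketch for the SnowflakeToSlackOperator above. The DAG scaffolding, the SQL, and the connection ids are illustrative assumptions; the results_df variable and the tabulate Jinja filter come from the operator itself, which renders the query results into slack_message before posting.

from airflow import DAG
from airflow.providers.snowflake.transfers.snowflake_to_slack import SnowflakeToSlackOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_snowflake_to_slack",        # hypothetical DAG, for illustration only
    start_date=days_ago(1),
    schedule_interval="@daily",
) as dag:
    daily_report = SnowflakeToSlackOperator(
        task_id="daily_order_counts_to_slack",
        snowflake_conn_id="snowflake_default",
        slack_conn_id="slack_default",
        sql="SELECT state, COUNT(*) AS orders FROM orders GROUP BY state",  # illustrative query
        slack_message=(
            "Order counts by state:\n"
            "{{ results_df | tabulate(tablefmt='pretty', headers='keys') }}"
        ),
    )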
diff --git a/reference/providers/sqlite/example_dags/__init__.py b/reference/providers/sqlite/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/sqlite/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/sqlite/example_dags/example_sqlite.py b/reference/providers/sqlite/example_dags/example_sqlite.py new file mode 100644 index 0000000..d00c62c --- /dev/null +++ b/reference/providers/sqlite/example_dags/example_sqlite.py @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example DAG for the use of the SqliteOperator. +In this example, we create two tasks that execute in sequence. +The first task calls an sql command, defined in the SQLite operator, +which when triggered, is performed on the connected sqlite database. +The second task is similar but instead calls the SQL command from an external file. +""" + +from airflow import DAG +from airflow.providers.sqlite.operators.sqlite import SqliteOperator +from airflow.utils.dates import days_ago + +default_args = {"owner": "airflow"} + +dag = DAG( + dag_id="example_sqlite", + default_args=default_args, + schedule_interval="@daily", + start_date=days_ago(2), + tags=["example"], +) + +# [START howto_operator_sqlite] + +# Example of creating a task that calls a common CREATE TABLE sql command. +create_table_sqlite_task = SqliteOperator( + task_id="create_table_sqlite", + sqlite_conn_id="sqlite_conn_id", + sql=r""" + CREATE TABLE table_name ( + column_1 string, + column_2 string, + column_3 string + ); + """, + dag=dag, +) + +# [END howto_operator_sqlite] + +# [START howto_operator_sqlite_external_file] + +# Example of creating a task that calls an sql command from an external file. 
+external_create_table_sqlite_task = SqliteOperator( + task_id="create_table_sqlite_external_file", + sqlite_conn_id="sqlite_conn_id", + sql="/scripts/create_table.sql", + dag=dag, +) + +# [END howto_operator_sqlite_external_file] + +create_table_sqlite_task >> external_create_table_sqlite_task diff --git a/reference/providers/sqlite/hooks/__init__.py b/reference/providers/sqlite/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/sqlite/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/sqlite/hooks/sqlite.py b/reference/providers/sqlite/hooks/sqlite.py new file mode 100644 index 0000000..18d8cdd --- /dev/null +++ b/reference/providers/sqlite/hooks/sqlite.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sqlite3 + +from airflow.hooks.dbapi import DbApiHook + + +class SqliteHook(DbApiHook): + """Interact with SQLite.""" + + conn_name_attr = "sqlite_conn_id" + default_conn_name = "sqlite_default" + conn_type = "sqlite" + hook_name = "Sqlite" + + def get_conn(self) -> sqlite3.dbapi2.Connection: + """Returns a sqlite connection object""" + conn_id = getattr(self, self.conn_name_attr) + airflow_conn = self.get_connection(conn_id) + conn = sqlite3.connect(airflow_conn.host) + return conn diff --git a/reference/providers/sqlite/operators/__init__.py b/reference/providers/sqlite/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/sqlite/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/sqlite/operators/sqlite.py b/reference/providers/sqlite/operators/sqlite.py new file mode 100644 index 0000000..6bffa69 --- /dev/null +++ b/reference/providers/sqlite/operators/sqlite.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Iterable, Mapping, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.sqlite.hooks.sqlite import SqliteHook +from airflow.utils.decorators import apply_defaults + + +class SqliteOperator(BaseOperator): + """ + Executes sql code in a specific Sqlite database + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SqliteOperator` + + :param sql: the sql code to be executed. Can receive a str representing a + sql statement, a list of str (sql statements), or reference to a template file. + Template reference are recognized by str ending in '.sql' + (templated) + :type sql: str or list[str] + :param sqlite_conn_id: reference to a specific sqlite database + :type sqlite_conn_id: str + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: dict or iterable + """ + + template_fields = ("sql",) + template_ext = (".sql",) + ui_color = "#cdaaed" + + @apply_defaults + def __init__( + self, + *, + sql: str, + sqlite_conn_id: str = "sqlite_default", + parameters: Optional[Union[Mapping, Iterable]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sqlite_conn_id = sqlite_conn_id + self.sql = sql + self.parameters = parameters or [] + + def execute(self, context: Mapping[Any, Any]) -> None: + self.log.info("Executing: %s", self.sql) + hook = SqliteHook(sqlite_conn_id=self.sqlite_conn_id) + hook.run(self.sql, parameters=self.parameters) diff --git a/reference/providers/sqlite/provider.yaml b/reference/providers/sqlite/provider.yaml new file mode 100644 index 0000000..28bfa80 --- /dev/null +++ b/reference/providers/sqlite/provider.yaml @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-sqlite +name: SQLite +description: | + `SQLite `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: SQLite + external-doc-url: https://www.sqlite.org/index.html + how-to-guide: + - /docs/apache-airflow-providers-sqlite/operators.rst + logo: /integration-logos/sqlite/SQLite.png + tags: [software] + +operators: + - integration-name: SQLite + + python-modules: + - airflow.providers.sqlite.operators.sqlite + +hooks: + - integration-name: SQLite + python-modules: + - airflow.providers.sqlite.hooks.sqlite + +hook-class-names: + - airflow.providers.sqlite.hooks.sqlite.SqliteHook diff --git a/reference/providers/ssh/CHANGELOG.rst b/reference/providers/ssh/CHANGELOG.rst new file mode 100644 index 0000000..8db7a03 --- /dev/null +++ b/reference/providers/ssh/CHANGELOG.rst @@ -0,0 +1,43 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.2.0 +..... + +Features +~~~~~~~~ + +* ``Added support for DSS, ECDSA, and Ed25519 private keys in SSHHook (#12467)`` + +1.1.0 +..... + +Updated documentation and readme files. + +Features +~~~~~~~~ + +* ``[AIRFLOW-7044] Host key can be specified via SSH connection extras. (#12944)`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/ssh/__init__.py b/reference/providers/ssh/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/ssh/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/ssh/hooks/__init__.py b/reference/providers/ssh/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/ssh/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/ssh/hooks/ssh.py b/reference/providers/ssh/hooks/ssh.py new file mode 100644 index 0000000..26a5c9c --- /dev/null +++ b/reference/providers/ssh/hooks/ssh.py @@ -0,0 +1,364 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for SSH connections.""" +import getpass +import os +import warnings +from base64 import decodebytes +from io import StringIO +from typing import Dict, Optional, Tuple, Union + +import paramiko +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from paramiko.config import SSH_PORT +from sshtunnel import SSHTunnelForwarder + + +class SSHHook(BaseHook): # pylint: disable=too-many-instance-attributes + """ + Hook for ssh remote execution using Paramiko. + ref: https://github.com/paramiko/paramiko + This hook also lets you create ssh tunnel and serve as basis for SFTP file transfer + + :param ssh_conn_id: connection id from airflow Connections from where all the required + parameters can be fetched like username, password or key_file. + Thought the priority is given to the param passed during init + :type ssh_conn_id: str + :param remote_host: remote host to connect + :type remote_host: str + :param username: username to connect to the remote_host + :type username: str + :param password: password of the username to connect to the remote_host + :type password: str + :param key_file: path to key file to use to connect to the remote_host + :type key_file: str + :param port: port of remote host to connect (Default is paramiko SSH_PORT) + :type port: int + :param timeout: timeout for the attempt to connect to the remote_host. 
+ :type timeout: int + :param keepalive_interval: send a keepalive packet to remote host every + keepalive_interval seconds + :type keepalive_interval: int + """ + + # key type name to paramiko PKey class + _default_pkey_mappings = { + "dsa": paramiko.DSSKey, + "ecdsa": paramiko.ECDSAKey, + "ed25519": paramiko.Ed25519Key, + "rsa": paramiko.RSAKey, + } + + conn_name_attr = "ssh_conn_id" + default_conn_name = "ssh_default" + conn_type = "ssh" + hook_name = "SSH" + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["schema"], + "relabeling": { + "login": "Username", + }, + } + + def __init__( # pylint: disable=too-many-statements + self, + ssh_conn_id: Optional[str] = None, + remote_host: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + key_file: Optional[str] = None, + port: Optional[int] = None, + timeout: int = 10, + keepalive_interval: int = 30, + ) -> None: + super().__init__() + self.ssh_conn_id = ssh_conn_id + self.remote_host = remote_host + self.username = username + self.password = password + self.key_file = key_file + self.pkey = None + self.port = port + self.timeout = timeout + self.keepalive_interval = keepalive_interval + + # Default values, overridable from Connection + self.compress = True + self.no_host_key_check = True + self.allow_host_key_change = False + self.host_proxy = None + self.host_key = None + self.look_for_keys = True + + # Placeholder for deprecated __enter__ + self.client = None + + # Use connection to override defaults + if self.ssh_conn_id is not None: + conn = self.get_connection(self.ssh_conn_id) + if self.username is None: + self.username = conn.login + if self.password is None: + self.password = conn.password + if self.remote_host is None: + self.remote_host = conn.host + if self.port is None: + self.port = conn.port + if conn.extra is not None: + extra_options = conn.extra_dejson + if "key_file" in extra_options and self.key_file is None: + self.key_file = extra_options.get("key_file") + + private_key = extra_options.get("private_key") + private_key_passphrase = extra_options.get("private_key_passphrase") + if private_key: + self.pkey = self._pkey_from_private_key( + private_key, passphrase=private_key_passphrase + ) + if "timeout" in extra_options: + self.timeout = int(extra_options["timeout"], 10) + + if ( + "compress" in extra_options + and str(extra_options["compress"]).lower() == "false" + ): + self.compress = False + if ( + "no_host_key_check" in extra_options + and str(extra_options["no_host_key_check"]).lower() == "false" + ): + self.no_host_key_check = False + if ( + "allow_host_key_change" in extra_options + and str(extra_options["allow_host_key_change"]).lower() == "true" + ): + self.allow_host_key_change = True + if ( + "look_for_keys" in extra_options + and str(extra_options["look_for_keys"]).lower() == "false" + ): + self.look_for_keys = False + if "host_key" in extra_options and self.no_host_key_check is False: + decoded_host_key = decodebytes( + extra_options["host_key"].encode("utf-8") + ) + self.host_key = paramiko.RSAKey(data=decoded_host_key) + if self.pkey and self.key_file: + raise AirflowException( + "Params key_file and private_key both provided. Must provide no more than one." 
+ ) + + if not self.remote_host: + raise AirflowException("Missing required param: remote_host") + + # Auto detecting username values from system + if not self.username: + self.log.debug( + "username to ssh to host: %s is not specified for connection id" + " %s. Using system's default provided by getpass.getuser()", + self.remote_host, + self.ssh_conn_id, + ) + self.username = getpass.getuser() + + user_ssh_config_filename = os.path.expanduser("~/.ssh/config") + if os.path.isfile(user_ssh_config_filename): + ssh_conf = paramiko.SSHConfig() + with open(user_ssh_config_filename) as config_fd: + ssh_conf.parse(config_fd) + host_info = ssh_conf.lookup(self.remote_host) + if host_info and host_info.get("proxycommand"): + self.host_proxy = paramiko.ProxyCommand(host_info.get("proxycommand")) + + if not (self.password or self.key_file): + if host_info and host_info.get("identityfile"): + self.key_file = host_info.get("identityfile")[0] + + self.port = self.port or SSH_PORT + + def get_conn(self) -> paramiko.SSHClient: + """ + Opens a ssh connection to the remote host. + + :rtype: paramiko.client.SSHClient + """ + self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id) + client = paramiko.SSHClient() + + if not self.allow_host_key_change: + self.log.warning( + "Remote Identification Change is not verified. " + "This wont protect against Man-In-The-Middle attacks" + ) + client.load_system_host_keys() + + if self.no_host_key_check: + self.log.warning( + "No Host Key Verification. This wont protect against Man-In-The-Middle attacks" + ) + # Default is RejectPolicy + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + else: + if self.host_key is not None: + client_host_keys = client.get_host_keys() + client_host_keys.add(self.remote_host, "ssh-rsa", self.host_key) + else: + pass # will fallback to system host keys if none explicitly specified in conn extra + + connect_kwargs = dict( + hostname=self.remote_host, + username=self.username, + timeout=self.timeout, + compress=self.compress, + port=self.port, + sock=self.host_proxy, + look_for_keys=self.look_for_keys, + ) + + if self.password: + password = self.password.strip() + connect_kwargs.update(password=password) + + if self.pkey: + connect_kwargs.update(pkey=self.pkey) + + if self.key_file: + connect_kwargs.update(key_filename=self.key_file) + + client.connect(**connect_kwargs) + + if self.keepalive_interval: + client.get_transport().set_keepalive(self.keepalive_interval) + + self.client = client + return client + + def __enter__(self) -> "SSHHook": + warnings.warn( + "The contextmanager of SSHHook is deprecated." + "Please use get_conn() as a contextmanager instead." + "This method will be removed in Airflow 2.0", + category=DeprecationWarning, + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if self.client is not None: + self.client.close() + self.client = None + + def get_tunnel( + self, + remote_port: int, + remote_host: str = "localhost", + local_port: Optional[int] = None, + ) -> SSHTunnelForwarder: + """ + Creates a tunnel between two hosts. Like ssh -L :host:. 
+ + :param remote_port: The remote port to create a tunnel to + :type remote_port: int + :param remote_host: The remote host to create a tunnel to (default localhost) + :type remote_host: str + :param local_port: The local port to attach the tunnel to + :type local_port: int + + :return: sshtunnel.SSHTunnelForwarder object + """ + if local_port: + local_bind_address: Union[Tuple[str, int], Tuple[str]] = ( + "localhost", + local_port, + ) + else: + local_bind_address = ("localhost",) + + tunnel_kwargs = dict( + ssh_port=self.port, + ssh_username=self.username, + ssh_pkey=self.key_file or self.pkey, + ssh_proxy=self.host_proxy, + local_bind_address=local_bind_address, + remote_bind_address=(remote_host, remote_port), + logger=self.log, + ) + + if self.password: + password = self.password.strip() + tunnel_kwargs.update( + ssh_password=password, + ) + else: + tunnel_kwargs.update( + host_pkey_directories=[], + ) + + client = SSHTunnelForwarder(self.remote_host, **tunnel_kwargs) + + return client + + def create_tunnel( + self, local_port: int, remote_port: int, remote_host: str = "localhost" + ) -> SSHTunnelForwarder: + """ + Creates tunnel for SSH connection [Deprecated]. + + :param local_port: local port number + :param remote_port: remote port number + :param remote_host: remote host + :return: + """ + warnings.warn( + "SSHHook.create_tunnel is deprecated, Please" + "use get_tunnel() instead. But please note that the" + "order of the parameters have changed" + "This method will be removed in Airflow 2.0", + category=DeprecationWarning, + ) + + return self.get_tunnel(remote_port, remote_host, local_port) + + def _pkey_from_private_key( + self, private_key: str, passphrase: Optional[str] = None + ) -> paramiko.PKey: + """ + Creates appropriate paramiko key for given private key + + :param private_key: string containing private key + :return: `paramiko.PKey` appropriate for given key + :raises AirflowException: if key cannot be read + """ + allowed_pkey_types = self._default_pkey_mappings.values() + for pkey_type in allowed_pkey_types: + try: + key = pkey_type.from_private_key( + StringIO(private_key), password=passphrase + ) + return key + except paramiko.ssh_exception.SSHException: + continue + raise AirflowException( + "Private key provided cannot be read by paramiko." + "Ensure key provided is valid for one of the following" + "key formats: RSA, DSS, ECDSA, or Ed25519" + ) diff --git a/reference/providers/ssh/operators/__init__.py b/reference/providers/ssh/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/ssh/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
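A brief usage sketch for the SSHHook above, assuming an existing Airflow connection named "ssh_default"; the remote command, host name, and port numbers are illustrative. get_conn() returns a paramiko SSHClient (usable as a context manager, as SSHOperator does below), and get_tunnel() returns an sshtunnel.SSHTunnelForwarder.

from airflow.providers.ssh.hooks.ssh import SSHHook

hook = SSHHook(ssh_conn_id="ssh_default", timeout=10, keepalive_interval=30)

# Run a command over the connection defined by the hook.
with hook.get_conn() as ssh_client:
    _, stdout, _ = ssh_client.exec_command("uptime")
    print(stdout.read().decode("utf-8"))

# Forward a local port to a service reachable only from the remote host
# (host and ports here are assumptions).
with hook.get_tunnel(remote_port=5432, remote_host="db.internal", local_port=15432) as tunnel:
    print("tunnel listening on localhost:", tunnel.local_bind_port)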
diff --git a/reference/providers/ssh/operators/ssh.py b/reference/providers/ssh/operators/ssh.py new file mode 100644 index 0000000..d528b5a --- /dev/null +++ b/reference/providers/ssh/operators/ssh.py @@ -0,0 +1,194 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from base64 import b64encode +from select import select +from typing import Optional, Union + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.utils.decorators import apply_defaults + + +class SSHOperator(BaseOperator): + """ + SSHOperator to execute commands on given remote host using the ssh_hook. + + :param ssh_hook: predefined ssh_hook to use for remote execution. + Either `ssh_hook` or `ssh_conn_id` needs to be provided. + :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook + :param ssh_conn_id: connection id from airflow Connections. + `ssh_conn_id` will be ignored if `ssh_hook` is provided. + :type ssh_conn_id: str + :param remote_host: remote host to connect (templated) + Nullable. If provided, it will replace the `remote_host` which was + defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. + :type remote_host: str + :param command: command to execute on remote host. (templated) + :type command: str + :param timeout: timeout (in seconds) for executing the command. The default is 10 seconds. + :type timeout: int + :param environment: a dict of shell environment variables. Note that the + server will reject them silently if `AcceptEnv` is not set in SSH config. + :type environment: dict + :param get_pty: request a pseudo-terminal from the server. Set to ``True`` + to have the remote process killed upon task timeout. + The default is ``False`` but note that `get_pty` is forced to ``True`` + when the `command` starts with ``sudo``. 
+ :type get_pty: bool + """ + + template_fields = ("command", "remote_host") + template_ext = (".sh",) + + @apply_defaults + def __init__( + self, + *, + ssh_hook: Optional[SSHHook] = None, + ssh_conn_id: Optional[str] = None, + remote_host: Optional[str] = None, + command: Optional[str] = None, + timeout: int = 10, + environment: Optional[dict] = None, + get_pty: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.ssh_hook = ssh_hook + self.ssh_conn_id = ssh_conn_id + self.remote_host = remote_host + self.command = command + self.timeout = timeout + self.environment = environment + self.get_pty = ( + (self.command.startswith("sudo") or get_pty) if self.command else get_pty + ) + + def execute(self, context) -> Union[bytes, str, bool]: + try: + if self.ssh_conn_id: + if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): + self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") + else: + self.log.info( + "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook." + ) + self.ssh_hook = SSHHook( + ssh_conn_id=self.ssh_conn_id, timeout=self.timeout + ) + + if not self.ssh_hook: + raise AirflowException( + "Cannot operate without ssh_hook or ssh_conn_id." + ) + + if self.remote_host is not None: + self.log.info( + "remote_host is provided explicitly. " + "It will replace the remote_host which was defined " + "in ssh_hook or predefined in connection of ssh_conn_id." + ) + self.ssh_hook.remote_host = self.remote_host + + if not self.command: + raise AirflowException("SSH command not specified. Aborting.") + + with self.ssh_hook.get_conn() as ssh_client: + self.log.info("Running command: %s", self.command) + + # set timeout taken as params + stdin, stdout, stderr = ssh_client.exec_command( + command=self.command, + get_pty=self.get_pty, + timeout=self.timeout, + environment=self.environment, + ) + # get channels + channel = stdout.channel + + # closing stdin + stdin.close() + channel.shutdown_write() + + agg_stdout = b"" + agg_stderr = b"" + + # capture any initial output in case channel is closed already + stdout_buffer_length = len(stdout.channel.in_buffer) + + if stdout_buffer_length > 0: + agg_stdout += stdout.channel.recv(stdout_buffer_length) + + # read from both stdout and stderr + while ( + not channel.closed + or channel.recv_ready() + or channel.recv_stderr_ready() + ): + readq, _, _ = select([channel], [], [], self.timeout) + for recv in readq: + if recv.recv_ready(): + line = stdout.channel.recv(len(recv.in_buffer)) + agg_stdout += line + self.log.info(line.decode("utf-8", "replace").strip("\n")) + if recv.recv_stderr_ready(): + line = stderr.channel.recv_stderr( + len(recv.in_stderr_buffer) + ) + agg_stderr += line + self.log.warning( + line.decode("utf-8", "replace").strip("\n") + ) + if ( + stdout.channel.exit_status_ready() + and not stderr.channel.recv_stderr_ready() + and not stdout.channel.recv_ready() + ): + stdout.channel.shutdown_read() + stdout.channel.close() + break + + stdout.close() + stderr.close() + + exit_status = stdout.channel.recv_exit_status() + if exit_status == 0: + enable_pickling = conf.getboolean("core", "enable_xcom_pickling") + if enable_pickling: + return agg_stdout + else: + return b64encode(agg_stdout).decode("utf-8") + + else: + error_msg = agg_stderr.decode("utf-8") + raise AirflowException( + f"error running cmd: {self.command}, error: {error_msg}" + ) + + except Exception as e: + raise AirflowException(f"SSH operator error: {str(e)}") + + return True + + def tunnel(self) -> None: + """Get ssh 
tunnel""" + ssh_client = self.ssh_hook.get_conn() # type: ignore[union-attr] + ssh_client.get_transport() diff --git a/reference/providers/ssh/provider.yaml b/reference/providers/ssh/provider.yaml new file mode 100644 index 0000000..db6d350 --- /dev/null +++ b/reference/providers/ssh/provider.yaml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-ssh +name: SSH +description: | + `Secure Shell (SSH) `__ + +versions: + - 1.2.0 + - 1.1.0 + - 1.0.0 + +integrations: + - integration-name: Secure Shell (SSH) + external-doc-url: https://tools.ietf.org/html/rfc4251 + logo: /integration-logos/ssh/SSH.png + tags: [protocol] + +operators: + - integration-name: Secure Shell (SSH) + python-modules: + - airflow.providers.ssh.operators.ssh + +hooks: + - integration-name: Secure Shell (SSH) + python-modules: + - airflow.providers.ssh.hooks.ssh + +hook-class-names: + - airflow.providers.ssh.hooks.ssh.SSHHook diff --git a/reference/providers/tableau/CHANGELOG.rst b/reference/providers/tableau/CHANGELOG.rst new file mode 100644 index 0000000..cef7dda --- /dev/null +++ b/reference/providers/tableau/CHANGELOG.rst @@ -0,0 +1,25 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/tableau/__init__.py b/reference/providers/tableau/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/tableau/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/tableau/example_dags/__init__.py b/reference/providers/tableau/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/tableau/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/tableau/example_dags/example_tableau_refresh_workbook.py b/reference/providers/tableau/example_dags/example_tableau_refresh_workbook.py new file mode 100644 index 0000000..ec4549b --- /dev/null +++ b/reference/providers/tableau/example_dags/example_tableau_refresh_workbook.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example dag that performs two refresh operations on a Tableau Workbook aka Extract. The first one +waits until it succeeds. The second does not wait since this is an asynchronous operation and we don't know +when the operation actually finishes. That's why we have another task that checks only that. 
+""" +from datetime import timedelta + +from airflow import DAG +from airflow.providers.tableau.operators.tableau_refresh_workbook import ( + TableauRefreshWorkbookOperator, +) +from airflow.providers.tableau.sensors.tableau_job_status import TableauJobStatusSensor +from airflow.utils.dates import days_ago + +DEFAULT_ARGS = { + "owner": "airflow", + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, +} + +with DAG( + dag_id="example_tableau_refresh_workbook", + default_args=DEFAULT_ARGS, + dagrun_timeout=timedelta(hours=2), + schedule_interval=None, + start_date=days_ago(2), + tags=["example"], +) as dag: + # Refreshes a workbook and waits until it succeeds. + task_refresh_workbook_blocking = TableauRefreshWorkbookOperator( + site_id="my_site", + workbook_name="MyWorkbook", + blocking=True, + task_id="refresh_tableau_workbook_blocking", + ) + # Refreshes a workbook and does not wait until it succeeds. + task_refresh_workbook_non_blocking = TableauRefreshWorkbookOperator( + site_id="my_site", + workbook_name="MyWorkbook", + blocking=False, + task_id="refresh_tableau_workbook_non_blocking", + ) + # The following task queries the status of the workbook refresh job until it succeeds. + task_check_job_status = TableauJobStatusSensor( + site_id="my_site", + job_id="{{ ti.xcom_pull(task_ids='refresh_tableau_workbook_non_blocking') }}", + task_id="check_tableau_job_status", + ) + task_refresh_workbook_non_blocking >> task_check_job_status diff --git a/reference/providers/tableau/hooks/__init__.py b/reference/providers/tableau/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/tableau/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/tableau/hooks/tableau.py b/reference/providers/tableau/hooks/tableau.py new file mode 100644 index 0000000..85f66b4 --- /dev/null +++ b/reference/providers/tableau/hooks/tableau.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from enum import Enum +from typing import Any, Optional + +from airflow.hooks.base import BaseHook +from tableauserverclient import Pager, PersonalAccessTokenAuth, Server, TableauAuth +from tableauserverclient.server import Auth + + +class TableauJobFinishCode(Enum): + """ + The finish code indicates the status of the job. + + .. seealso:: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref.htm#query_job + + """ + + PENDING = -1 + SUCCESS = 0 + ERROR = 1 + CANCELED = 2 + + +class TableauHook(BaseHook): + """ + Connects to the Tableau Server Instance and allows to communicate with it. + + .. seealso:: https://tableau.github.io/server-client-python/docs/ + + :param site_id: The id of the site where the workbook belongs to. + It will connect to the default site if you don't provide an id. + :type site_id: Optional[str] + :param tableau_conn_id: The Tableau Connection id containing the credentials + to authenticate to the Tableau Server. + :type tableau_conn_id: str + """ + + conn_name_attr = "tableau_conn_id" + default_conn_name = "tableau_default" + conn_type = "tableau" + hook_name = "Tableau" + + def __init__( + self, site_id: Optional[str] = None, tableau_conn_id: str = default_conn_name + ) -> None: + super().__init__() + self.tableau_conn_id = tableau_conn_id + self.conn = self.get_connection(self.tableau_conn_id) + self.site_id = site_id or self.conn.extra_dejson.get("site_id", "") + self.server = Server(self.conn.host, use_server_version=True) + self.tableau_conn = None + + def __enter__(self): + if not self.tableau_conn: + self.tableau_conn = self.get_conn() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.server.auth.sign_out() + + def get_conn(self) -> Auth.contextmgr: + """ + Signs in to the Tableau Server and automatically signs out if used as ContextManager. + + :return: an authorized Tableau Server Context Manager object. + :rtype: tableauserverclient.server.Auth.contextmgr + """ + if self.conn.login and self.conn.password: + return self._auth_via_password() + if ( + "token_name" in self.conn.extra_dejson + and "personal_access_token" in self.conn.extra_dejson + ): + return self._auth_via_token() + raise NotImplementedError( + "No Authentication method found for given Credentials!" + ) + + def _auth_via_password(self) -> Auth.contextmgr: + tableau_auth = TableauAuth( + username=self.conn.login, password=self.conn.password, site_id=self.site_id + ) + return self.server.auth.sign_in(tableau_auth) + + def _auth_via_token(self) -> Auth.contextmgr: + tableau_auth = PersonalAccessTokenAuth( + token_name=self.conn.extra_dejson["token_name"], + personal_access_token=self.conn.extra_dejson["personal_access_token"], + site_id=self.site_id, + ) + return self.server.auth.sign_in_with_personal_access_token(tableau_auth) + + def get_all(self, resource_name: str) -> Pager: + """ + Get all items of the given resource. + + .. seealso:: https://tableau.github.io/server-client-python/docs/page-through-results + + :param resource_name: The name of the resource to paginate. + For example: jobs or workbooks + :type resource_name: str + :return: all items by returning a Pager. 
+ :rtype: tableauserverclient.Pager + """ + resource = getattr(self.server, resource_name) + return Pager(resource.get) diff --git a/reference/providers/tableau/operators/__init__.py b/reference/providers/tableau/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/tableau/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/tableau/operators/tableau_refresh_workbook.py b/reference/providers/tableau/operators/tableau_refresh_workbook.py new file mode 100644 index 0000000..f9791e5 --- /dev/null +++ b/reference/providers/tableau/operators/tableau_refresh_workbook.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.tableau.hooks.tableau import TableauHook +from airflow.utils.decorators import apply_defaults +from tableauserverclient import WorkbookItem + + +class TableauRefreshWorkbookOperator(BaseOperator): + """ + Refreshes a Tableau Workbook/Extract + + .. seealso:: https://tableau.github.io/server-client-python/docs/api-ref#workbooks + + :param workbook_name: The name of the workbook to refresh. + :type workbook_name: str + :param site_id: The id of the site where the workbook belongs to. + :type site_id: Optional[str] + :param blocking: By default the extract refresh will be blocking means it will wait until it has finished. + :type blocking: bool + :param tableau_conn_id: The Tableau Connection id containing the credentials + to authenticate to the Tableau Server. 
+ :type tableau_conn_id: str + """ + + @apply_defaults + def __init__( + self, + *, + workbook_name: str, + site_id: Optional[str] = None, + blocking: bool = True, + tableau_conn_id: str = "tableau_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.workbook_name = workbook_name + self.site_id = site_id + self.blocking = blocking + self.tableau_conn_id = tableau_conn_id + + def execute(self, context: dict) -> str: + """ + Executes the Tableau Extract Refresh and pushes the job id to xcom. + + :param context: The task context during execution. + :type context: dict + :return: the id of the job that executes the extract refresh + :rtype: str + """ + with TableauHook(self.site_id, self.tableau_conn_id) as tableau_hook: + workbook = self._get_workbook_by_name(tableau_hook) + + job_id = self._refresh_workbook(tableau_hook, workbook.id) + if self.blocking: + from airflow.providers.tableau.sensors.tableau_job_status import ( + TableauJobStatusSensor, + ) + + TableauJobStatusSensor( + job_id=job_id, + site_id=self.site_id, + tableau_conn_id=self.tableau_conn_id, + task_id="wait_until_succeeded", + dag=None, + ).execute(context={}) + self.log.info( + "Workbook %s has been successfully refreshed.", self.workbook_name + ) + return job_id + + def _get_workbook_by_name(self, tableau_hook: TableauHook) -> WorkbookItem: + for workbook in tableau_hook.get_all(resource_name="workbooks"): + if workbook.name == self.workbook_name: + self.log.info("Found matching workbook with id %s", workbook.id) + return workbook + + raise AirflowException(f"Workbook {self.workbook_name} not found!") + + def _refresh_workbook(self, tableau_hook: TableauHook, workbook_id: str) -> str: + job = tableau_hook.server.workbooks.refresh(workbook_id) + self.log.info("Refreshing Workbook %s...", self.workbook_name) + return job.id diff --git a/reference/providers/tableau/provider.yaml b/reference/providers/tableau/provider.yaml new file mode 100644 index 0000000..e777947 --- /dev/null +++ b/reference/providers/tableau/provider.yaml @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
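A hedged sketch of what TableauRefreshWorkbookOperator.execute does under the hood, using the TableauHook directly; the site id, connection id, and workbook name are illustrative assumptions.

from airflow.providers.tableau.hooks.tableau import TableauHook

with TableauHook(site_id="my_site", tableau_conn_id="tableau_default") as tableau_hook:
    # Page through all workbooks and refresh the matching one, as the operator does.
    for workbook in tableau_hook.get_all(resource_name="workbooks"):
        if workbook.name == "MyWorkbook":
            job = tableau_hook.server.workbooks.refresh(workbook.id)
            print("refresh job id:", job.id)
            break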
+ +--- +package-name: apache-airflow-providers-tableau +name: Tableau +description: | + `Tableau `__ + +versions: + - 1.0.0 + +integrations: + - integration-name: Tableau + external-doc-url: https://www.tableau.com/ + logo: /integration-logos/tableau/tableau.png + tags: [service] + +operators: + - integration-name: Tableau + python-modules: + - airflow.providers.tableau.operators.tableau_refresh_workbook + +sensors: + - integration-name: Tableau + python-modules: + - airflow.providers.tableau.sensors.tableau_job_status + +hooks: + - integration-name: Tableau + python-modules: + - airflow.providers.tableau.hooks.tableau + +hook-class-names: + - airflow.providers.tableau.hooks.tableau.TableauHook diff --git a/reference/providers/tableau/sensors/__init__.py b/reference/providers/tableau/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/tableau/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/tableau/sensors/tableau_job_status.py b/reference/providers/tableau/sensors/tableau_job_status.py new file mode 100644 index 0000000..ab64f5e --- /dev/null +++ b/reference/providers/tableau/sensors/tableau_job_status.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.providers.tableau.hooks.tableau import TableauHook, TableauJobFinishCode +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class TableauJobFailedException(AirflowException): + """An exception that indicates that a Job failed to complete.""" + + +class TableauJobStatusSensor(BaseSensorOperator): + """ + Watches the status of a Tableau Server Job. + + .. seealso:: https://tableau.github.io/server-client-python/docs/api-ref#jobs + + :param job_id: The job to watch. + :type job_id: str + :param site_id: The id of the site where the workbook belongs to. 
+ :type site_id: Optional[str] + :param tableau_conn_id: The Tableau Connection id containing the credentials + to authenticate to the Tableau Server. + :type tableau_conn_id: str + """ + + template_fields = ("job_id",) + + @apply_defaults + def __init__( + self, + *, + job_id: str, + site_id: Optional[str] = None, + tableau_conn_id: str = "tableau_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.tableau_conn_id = tableau_conn_id + self.job_id = job_id + self.site_id = site_id + + def poke(self, context: dict) -> bool: + """ + Pokes until the job has successfully finished. + + :param context: The task context during execution. + :type context: dict + :return: True if it succeeded and False if not. + :rtype: bool + """ + with TableauHook(self.site_id, self.tableau_conn_id) as tableau_hook: + finish_code = TableauJobFinishCode( + int(tableau_hook.server.jobs.get_by_id(self.job_id).finish_code) + ) + self.log.info( + "Current finishCode is %s (%s)", finish_code.name, finish_code.value + ) + if finish_code in [ + TableauJobFinishCode.ERROR, + TableauJobFinishCode.CANCELED, + ]: + raise TableauJobFailedException( + "The Tableau Refresh Workbook Job failed!" + ) + return finish_code == TableauJobFinishCode.SUCCESS diff --git a/reference/providers/telegram/CHANGELOG.rst b/reference/providers/telegram/CHANGELOG.rst new file mode 100644 index 0000000..4348fa8 --- /dev/null +++ b/reference/providers/telegram/CHANGELOG.rst @@ -0,0 +1,38 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.2 +..... + +Bug fixes +~~~~~~~~~ + +* ``Fix the AttributeError with text field in TelegramOperator (#13990)`` + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/telegram/__init__.py b/reference/providers/telegram/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/telegram/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
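The Tableau provider pieces above (hook, refresh operator, job-status sensor) are easiest to read end to end as a small DAG. The sketch below is illustrative only: the DAG id, workbook name, and connection id are placeholder values, and it relies on the operator's return value (the job id) being pushed to XCom, as described in its docstring.

    from airflow import DAG
    from airflow.providers.tableau.operators.tableau_refresh_workbook import TableauRefreshWorkbookOperator
    from airflow.providers.tableau.sensors.tableau_job_status import TableauJobStatusSensor
    from airflow.utils.dates import days_ago

    with DAG(
        "example_tableau_refresh_workbook",
        start_date=days_ago(1),
        schedule_interval=None,
        tags=["example"],
    ) as dag:
        # Trigger the extract refresh without waiting for it to finish;
        # execute() returns the job id, which lands in XCom as return_value.
        refresh_workbook = TableauRefreshWorkbookOperator(
            task_id="refresh_tableau_workbook",
            workbook_name="MyWorkbook",  # placeholder
            blocking=False,
            tableau_conn_id="tableau_default",
        )
        # Poll the refresh job until it reaches SUCCESS; job_id is a templated field.
        wait_for_job = TableauJobStatusSensor(
            task_id="wait_for_tableau_job",
            job_id="{{ ti.xcom_pull(task_ids='refresh_tableau_workbook') }}",
            tableau_conn_id="tableau_default",
        )
        refresh_workbook >> wait_for_job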
diff --git a/reference/providers/telegram/example_dags/__init__.py b/reference/providers/telegram/example_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/telegram/example_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/telegram/example_dags/example_telegram.py b/reference/providers/telegram/example_dags/example_telegram.py new file mode 100644 index 0000000..ae26fb9 --- /dev/null +++ b/reference/providers/telegram/example_dags/example_telegram.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example use of Telegram operator. +""" + +from airflow import DAG +from airflow.providers.telegram.operators.telegram import TelegramOperator +from airflow.utils.dates import days_ago + +default_args = { + "owner": "airflow", +} + +dag = DAG( + "example_telegram", + default_args=default_args, + start_date=days_ago(2), + tags=["example"], +) + +# [START howto_operator_telegram] + +send_message_telegram_task = TelegramOperator( + task_id="send_message_telegram", + telegram_conn_id="telegram_conn_id", + chat_id="-3222103937", + text="Hello from Airflow!", + dag=dag, +) + +# [END howto_operator_telegram] diff --git a/reference/providers/telegram/hooks/__init__.py b/reference/providers/telegram/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/telegram/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/telegram/hooks/telegram.py b/reference/providers/telegram/hooks/telegram.py new file mode 100644 index 0000000..653ba12 --- /dev/null +++ b/reference/providers/telegram/hooks/telegram.py @@ -0,0 +1,155 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Telegram""" +from typing import Optional + +import telegram +import tenacity +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class TelegramHook(BaseHook): + """ + This hook allows you to post messages to Telegram using the telegram python-telegram-bot library. + + The library can be found here: https://github.com/python-telegram-bot/python-telegram-bot + It accepts both telegram bot API token directly or connection that has telegram bot API token. + If both supplied, token parameter will be given precedence, otherwise 'password' field in the connection + from telegram_conn_id will be used. + chat_id can also be provided in the connection using 'host' field in connection. + Following is the details of a telegram_connection: + name: 'telegram-connection-name' + conn_type: 'http' + password: 'TELEGRAM_TOKEN' + host: 'chat_id' (optional) + Examples: + .. 
code-block:: python + + # Create hook + telegram_hook = TelegramHook(telegram_conn_id='telegram_default') + # or telegram_hook = TelegramHook(telegram_conn_id='telegram_default', chat_id='-1xxx') + # or telegram_hook = TelegramHook(token='xxx:xxx', chat_id='-1xxx') + + # Call method from telegram bot client + telegram_hook.send_message({"text": "message", "chat_id": "-1xxx"}) + # or telegram_hook.send_message({"text": "message"}) + + :param telegram_conn_id: connection that optionally has Telegram API token in the password field + :type telegram_conn_id: str + :param token: optional telegram API token + :type token: str + :param chat_id: optional chat_id of the telegram chat/channel/group + :type chat_id: str + """ + + def __init__( + self, + telegram_conn_id: Optional[str] = None, + token: Optional[str] = None, + chat_id: Optional[str] = None, + ) -> None: + super().__init__() + self.token = self.__get_token(token, telegram_conn_id) + self.chat_id = self.__get_chat_id(chat_id, telegram_conn_id) + self.connection = self.get_conn() + + def get_conn(self) -> telegram.bot.Bot: + """ + Returns the telegram bot client + + :return: telegram bot client + :rtype: telegram.bot.Bot + """ + return telegram.bot.Bot(token=self.token) + + def __get_token(self, token: Optional[str], telegram_conn_id: str) -> str: + """ + Returns the telegram API token + + :param token: telegram API token + :type token: str + :param telegram_conn_id: telegram connection name + :type telegram_conn_id: str + :return: telegram API token + :rtype: str + """ + if token is not None: + return token + + if telegram_conn_id is not None: + conn = self.get_connection(telegram_conn_id) + + if not conn.password: + raise AirflowException("Missing token(password) in Telegram connection") + + return conn.password + + raise AirflowException( + "Cannot get token: No valid Telegram connection supplied." + ) + + def __get_chat_id( + self, chat_id: Optional[str], telegram_conn_id: str + ) -> Optional[str]: + """ + Returns the telegram chat ID for a chat/channel/group + + :param chat_id: optional chat ID + :type chat_id: str + :param telegram_conn_id: telegram connection name + :type telegram_conn_id: str + :return: telegram chat ID + :rtype: str + """ + if chat_id is not None: + return chat_id + + if telegram_conn_id is not None: + conn = self.get_connection(telegram_conn_id) + return conn.host + + return None + + @tenacity.retry( + retry=tenacity.retry_if_exception_type(telegram.error.TelegramError), + stop=tenacity.stop_after_attempt(5), + wait=tenacity.wait_fixed(1), + ) + def send_message(self, api_params: dict) -> None: + """ + Sends the message to a telegram channel or chat. + + :param api_params: params for telegram_instance.send_message. 
It can also be used to override chat_id + :type api_params: dict + """ + kwargs = { + "chat_id": self.chat_id, + "parse_mode": telegram.parsemode.ParseMode.HTML, + "disable_web_page_preview": True, + } + kwargs.update(api_params) + + if "text" not in kwargs or kwargs["text"] is None: + raise AirflowException("'text' must be provided for telegram message") + + if kwargs["chat_id"] is None: + raise AirflowException("'chat_id' must be provided for telegram message") + + response = self.connection.send_message(**kwargs) + self.log.debug(response) diff --git a/reference/providers/telegram/operators/__init__.py b/reference/providers/telegram/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/telegram/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/telegram/operators/telegram.py b/reference/providers/telegram/operators/telegram.py new file mode 100644 index 0000000..c5048f5 --- /dev/null +++ b/reference/providers/telegram/operators/telegram.py @@ -0,0 +1,85 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operator for Telegram""" +from typing import Optional + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.telegram.hooks.telegram import TelegramHook +from airflow.utils.decorators import apply_defaults + + +class TelegramOperator(BaseOperator): + """ + This operator allows you to post messages to Telegram using Telegram Bot API. + Takes both Telegram Bot API token directly or connection that has Telegram token in password field. + If both supplied, token parameter will be given precedence. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TelegramOperator` + + :param telegram_conn_id: Telegram connection ID which its password is Telegram API token + :type telegram_conn_id: str + :param token: Telegram API Token + :type token: str + :param chat_id: Telegram chat ID for a chat/channel/group + :type chat_id: str + :param text: Message to be sent on telegram + :type text: str + :param telegram_kwargs: Extra args to be passed to telegram client + :type telegram_kwargs: dict + """ + + template_fields = ("text", "chat_id") + ui_color = "#FFBA40" + + @apply_defaults + def __init__( + self, + *, + telegram_conn_id: str = "telegram_default", + token: Optional[str] = None, + chat_id: Optional[str] = None, + text: str = "No message has been set.", + telegram_kwargs: Optional[dict] = None, + **kwargs, + ): + self.chat_id = chat_id + self.token = token + self.telegram_kwargs = telegram_kwargs or {} + self.text = text + + if telegram_conn_id is None: + raise AirflowException("No valid Telegram connection id supplied.") + + self.telegram_conn_id = telegram_conn_id + + super().__init__(**kwargs) + + def execute(self, **kwargs) -> None: + """Calls the TelegramHook to post the provided Telegram message""" + if self.text: + self.telegram_kwargs["text"] = self.text + + telegram_hook = TelegramHook( + telegram_conn_id=self.telegram_conn_id, + token=self.token, + chat_id=self.chat_id, + ) + telegram_hook.send_message(self.telegram_kwargs) diff --git a/reference/providers/telegram/provider.yaml b/reference/providers/telegram/provider.yaml new file mode 100644 index 0000000..79ae470 --- /dev/null +++ b/reference/providers/telegram/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-telegram +name: Telegram +description: | + `Telegram `__ + +versions: + - 1.0.2 + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Telegram + external-doc-url: https://telegram.org/ + how-to-guide: + - /docs/apache-airflow-providers-telegram/operators.rst + logo: /integration-logos/telegram/Telegram.png + tags: [service] + +operators: + - integration-name: Telegram + python-modules: + - airflow.providers.telegram.operators.telegram + +hooks: + - integration-name: Telegram + python-modules: + - airflow.providers.telegram.hooks.telegram diff --git a/reference/providers/vertica/CHANGELOG.rst b/reference/providers/vertica/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/vertica/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/vertica/__init__.py b/reference/providers/vertica/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/vertica/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/vertica/hooks/__init__.py b/reference/providers/vertica/hooks/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/vertica/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/vertica/hooks/vertica.py b/reference/providers/vertica/hooks/vertica.py new file mode 100644 index 0000000..893d07d --- /dev/null +++ b/reference/providers/vertica/hooks/vertica.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from airflow.hooks.dbapi import DbApiHook +from vertica_python import connect + + +class VerticaHook(DbApiHook): + """Interact with Vertica.""" + + conn_name_attr = "vertica_conn_id" + default_conn_name = "vertica_default" + conn_type = "vertica" + hook_name = "Vertica" + supports_autocommit = True + + def get_conn(self) -> connect: + """Return verticaql connection object""" + conn = self.get_connection(self.vertica_conn_id) # type: ignore # pylint: disable=no-member + conn_config = { + "user": conn.login, + "password": conn.password or "", + "database": conn.schema, + "host": conn.host or "localhost", + } + + if not conn.port: + conn_config["port"] = 5433 + else: + conn_config["port"] = int(conn.port) + + conn = connect(**conn_config) + return conn diff --git a/reference/providers/vertica/operators/__init__.py b/reference/providers/vertica/operators/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/reference/providers/vertica/operators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/vertica/operators/vertica.py b/reference/providers/vertica/operators/vertica.py new file mode 100644 index 0000000..5a1381f --- /dev/null +++ b/reference/providers/vertica/operators/vertica.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
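# A minimal usage sketch for the VerticaOperator defined below; illustrative only,
# with a placeholder SQL statement and the default "vertica_default" connection:
#
#     from airflow import DAG
#     from airflow.providers.vertica.operators.vertica import VerticaOperator
#     from airflow.utils.dates import days_ago
#
#     with DAG("example_vertica", start_date=days_ago(1), schedule_interval=None) as dag:
#         create_table = VerticaOperator(
#             task_id="create_table",
#             vertica_conn_id="vertica_default",
#             sql="CREATE TABLE IF NOT EXISTS example_table (id INT)",
#         )
#
# The VerticaHook above builds the vertica_python connection from the same
# connection entry (login, password, schema as database, host, and port, which
# defaults to 5433 when unset).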
+from typing import Any, Dict, List, Union + +from airflow.models import BaseOperator +from airflow.providers.vertica.hooks.vertica import VerticaHook +from airflow.utils.decorators import apply_defaults + + +class VerticaOperator(BaseOperator): + """ + Executes sql code in a specific Vertica database. + + :param vertica_conn_id: reference to a specific Vertica database + :type vertica_conn_id: str + :param sql: the sql code to be executed. (templated) + :type sql: Can receive a str representing a sql statement, + a list of str (sql statements), or reference to a template file. + Template reference are recognized by str ending in '.sql' + """ + + template_fields = ("sql",) + template_ext = (".sql",) + ui_color = "#b4e0ff" + + @apply_defaults + def __init__( + self, + *, + sql: Union[str, List[str]], + vertica_conn_id: str = "vertica_default", + **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.vertica_conn_id = vertica_conn_id + self.sql = sql + + def execute(self, context: Dict[Any, Any]) -> None: + self.log.info("Executing: %s", self.sql) + hook = VerticaHook(vertica_conn_id=self.vertica_conn_id) + hook.run(sql=self.sql) diff --git a/reference/providers/vertica/provider.yaml b/reference/providers/vertica/provider.yaml new file mode 100644 index 0000000..3a222d2 --- /dev/null +++ b/reference/providers/vertica/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-vertica +name: Vertica +description: | + `Vertica `__ + +versions: + - 1.0.1 + - 1.0.0 + +integrations: + - integration-name: Vertica + external-doc-url: https://www.vertica.com/ + logo: /integration-logos/vertica/Vertica.png + tags: [software] + +operators: + - integration-name: Vertica + python-modules: + - airflow.providers.vertica.operators.vertica + +hooks: + - integration-name: Vertica + python-modules: + - airflow.providers.vertica.hooks.vertica + +hook-class-names: + - airflow.providers.vertica.hooks.vertica.VerticaHook diff --git a/reference/providers/yandex/CHANGELOG.rst b/reference/providers/yandex/CHANGELOG.rst new file mode 100644 index 0000000..12fea86 --- /dev/null +++ b/reference/providers/yandex/CHANGELOG.rst @@ -0,0 +1,30 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.1 +..... + +Updated documentation and readme files. + +1.0.0 +..... + +Initial version of the provider. diff --git a/reference/providers/yandex/__init__.py b/reference/providers/yandex/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/yandex/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/yandex/example_dags/__init__.py b/reference/providers/yandex/example_dags/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/yandex/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/yandex/example_dags/example_yandexcloud_dataproc.py b/reference/providers/yandex/example_dags/example_yandexcloud_dataproc.py new file mode 100644 index 0000000..dd9aa09 --- /dev/null +++ b/reference/providers/yandex/example_dags/example_yandexcloud_dataproc.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow import DAG +from airflow.providers.yandex.operators.yandexcloud_dataproc import ( + DataprocCreateClusterOperator, + DataprocCreateHiveJobOperator, + DataprocCreateMapReduceJobOperator, + DataprocCreatePysparkJobOperator, + DataprocCreateSparkJobOperator, + DataprocDeleteClusterOperator, +) +from airflow.utils.dates import days_ago + +# should be filled with appropriate ids + +# Airflow connection with type "yandexcloud" must be created. +# By default connection with id "yandexcloud_default" will be used +CONNECTION_ID = "yandexcloud_default" + +# Name of the datacenter where Dataproc cluster will be created +AVAILABILITY_ZONE_ID = "ru-central1-c" + +# Dataproc cluster jobs will produce logs in specified s3 bucket +S3_BUCKET_NAME_FOR_JOB_LOGS = "" + + +default_args = { + "owner": "airflow", +} + +with DAG( + "example_yandexcloud_dataproc_operator", + default_args=default_args, + schedule_interval=None, + start_date=days_ago(1), + tags=["example"], +) as dag: + create_cluster = DataprocCreateClusterOperator( + task_id="create_cluster", + zone=AVAILABILITY_ZONE_ID, + connection_id=CONNECTION_ID, + s3_bucket=S3_BUCKET_NAME_FOR_JOB_LOGS, + ) + + create_hive_query = DataprocCreateHiveJobOperator( + task_id="create_hive_query", + query="SELECT 1;", + ) + + create_hive_query_from_file = DataprocCreateHiveJobOperator( + task_id="create_hive_query_from_file", + query_file_uri="s3a://data-proc-public/jobs/sources/hive-001/main.sql", + script_variables={ + "CITIES_URI": "s3a://data-proc-public/jobs/sources/hive-001/cities/", + "COUNTRY_CODE": "RU", + }, + ) + + create_mapreduce_job = DataprocCreateMapReduceJobOperator( + task_id="create_mapreduce_job", + main_class="org.apache.hadoop.streaming.HadoopStreaming", + file_uris=[ + "s3a://data-proc-public/jobs/sources/mapreduce-001/mapper.py", + "s3a://data-proc-public/jobs/sources/mapreduce-001/reducer.py", + ], + args=[ + "-mapper", + "mapper.py", + "-reducer", + "reducer.py", + "-numReduceTasks", + "1", + "-input", + "s3a://data-proc-public/jobs/sources/data/cities500.txt.bz2", + "-output", + f"s3a://{S3_BUCKET_NAME_FOR_JOB_LOGS}/dataproc/job/results", + ], + properties={ + "yarn.app.mapreduce.am.resource.mb": "2048", + "yarn.app.mapreduce.am.command-opts": "-Xmx2048m", + "mapreduce.job.maps": "6", + }, + ) + + create_spark_job = DataprocCreateSparkJobOperator( + task_id="create_spark_job", + main_jar_file_uri="s3a://data-proc-public/jobs/sources/java/dataproc-examples-1.0.jar", + main_class="ru.yandex.cloud.dataproc.examples.PopulationSparkJob", + file_uris=[ + "s3a://data-proc-public/jobs/sources/data/config.json", + ], + archive_uris=[ + "s3a://data-proc-public/jobs/sources/data/country-codes.csv.zip", + ], + jar_file_uris=[ + "s3a://data-proc-public/jobs/sources/java/icu4j-61.1.jar", + "s3a://data-proc-public/jobs/sources/java/commons-lang-2.6.jar", + "s3a://data-proc-public/jobs/sources/java/opencsv-4.1.jar", + "s3a://data-proc-public/jobs/sources/java/json-20190722.jar", + ], + args=[ + "s3a://data-proc-public/jobs/sources/data/cities500.txt.bz2", + f"s3a://{S3_BUCKET_NAME_FOR_JOB_LOGS}/dataproc/job/results/${{JOB_ID}}", + ], + properties={ + "spark.submit.deployMode": "cluster", + }, + ) + + create_pyspark_job = DataprocCreatePysparkJobOperator( + task_id="create_pyspark_job", + main_python_file_uri="s3a://data-proc-public/jobs/sources/pyspark-001/main.py", + python_file_uris=[ + 
"s3a://data-proc-public/jobs/sources/pyspark-001/geonames.py", + ], + file_uris=[ + "s3a://data-proc-public/jobs/sources/data/config.json", + ], + archive_uris=[ + "s3a://data-proc-public/jobs/sources/data/country-codes.csv.zip", + ], + args=[ + "s3a://data-proc-public/jobs/sources/data/cities500.txt.bz2", + f"s3a://{S3_BUCKET_NAME_FOR_JOB_LOGS}/jobs/results/${{JOB_ID}}", + ], + jar_file_uris=[ + "s3a://data-proc-public/jobs/sources/java/dataproc-examples-1.0.jar", + "s3a://data-proc-public/jobs/sources/java/icu4j-61.1.jar", + "s3a://data-proc-public/jobs/sources/java/commons-lang-2.6.jar", + ], + properties={ + "spark.submit.deployMode": "cluster", + }, + ) + + delete_cluster = DataprocDeleteClusterOperator( + task_id="delete_cluster", + ) + + create_cluster >> create_mapreduce_job >> create_hive_query >> create_hive_query_from_file + create_hive_query_from_file >> create_spark_job >> create_pyspark_job >> delete_cluster diff --git a/reference/providers/yandex/hooks/__init__.py b/reference/providers/yandex/hooks/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/yandex/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/reference/providers/yandex/hooks/yandex.py b/reference/providers/yandex/hooks/yandex.py new file mode 100644 index 0000000..c1ca7d7 --- /dev/null +++ b/reference/providers/yandex/hooks/yandex.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import warnings +from typing import Any, Dict, Optional, Union + +import yandexcloud +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class YandexCloudBaseHook(BaseHook): + """ + A base hook for Yandex.Cloud related tasks. + + :param connection_id: The connection ID to use when fetching connection info. 
+ :type connection_id: str + """ + + conn_name_attr = "yandex_conn_id" + default_conn_name = "yandexcloud_default" + conn_type = "yandexcloud" + hook_name = "Yandex Cloud" + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import ( + BS3PasswordFieldWidget, + BS3TextFieldWidget, + ) + from flask_babel import lazy_gettext + from wtforms import PasswordField, StringField + + return { + "extra__yandexcloud__service_account_json": PasswordField( + lazy_gettext("Service account auth JSON"), + widget=BS3PasswordFieldWidget(), + description="Service account auth JSON. Looks like " + '{"id", "...", "service_account_id": "...", "private_key": "..."}. ' + "Will be used instead of OAuth token and SA JSON file path field if specified.", + ), + "extra__yandexcloud__service_account_json_path": StringField( + lazy_gettext("Service account auth JSON file path"), + widget=BS3TextFieldWidget(), + description="Service account auth JSON file path. File content looks like " + '{"id", "...", "service_account_id": "...", "private_key": "..."}. ' + "Will be used instead of OAuth token if specified.", + ), + "extra__yandexcloud__oauth": PasswordField( + lazy_gettext("OAuth Token"), + widget=BS3PasswordFieldWidget(), + description="User account OAuth token. " + "Either this or service account JSON must be specified.", + ), + "extra__yandexcloud__folder_id": StringField( + lazy_gettext("Default folder ID"), + widget=BS3TextFieldWidget(), + description="Optional. This folder will be used " + "to create all new clusters and nodes by default", + ), + "extra__yandexcloud__public_ssh_key": StringField( + lazy_gettext("Public SSH key"), + widget=BS3TextFieldWidget(), + description="Optional. This key will be placed to all created Compute nodes" + "to let you have a root shell there", + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], + "relabeling": {}, + } + + def __init__( + self, + # Connection id is deprecated. Use yandex_conn_id instead + connection_id: Optional[str] = None, + yandex_conn_id: Optional[str] = None, + default_folder_id: Union[dict, bool, None] = None, + default_public_ssh_key: Optional[str] = None, + ) -> None: + super().__init__() + if connection_id: + warnings.warn( + "Using `connection_id` is deprecated. Please use `yandex_conn_id` parameter.", + DeprecationWarning, + stacklevel=2, + ) + self.connection_id = yandex_conn_id or connection_id or self.default_conn_name + self.connection = self.get_connection(self.connection_id) + self.extras = self.connection.extra_dejson + credentials = self._get_credentials() + self.sdk = yandexcloud.SDK(**credentials) + self.default_folder_id = default_folder_id or self._get_field( + "folder_id", False + ) + self.default_public_ssh_key = default_public_ssh_key or self._get_field( + "public_ssh_key", False + ) + self.client = self.sdk.client + + def _get_credentials(self) -> Dict[str, Any]: + service_account_json_path = self._get_field("service_account_json_path", False) + service_account_json = self._get_field("service_account_json", False) + oauth_token = self._get_field("oauth", False) + if not (service_account_json or oauth_token or service_account_json_path): + raise AirflowException( + "No credentials are found in connection. 
Specify either service account " + + "authentication JSON or user OAuth token in Yandex.Cloud connection" + ) + if service_account_json_path: + with open(service_account_json_path) as infile: + service_account_json = infile.read() + if service_account_json: + service_account_key = json.loads(service_account_json) + return {"service_account_key": service_account_key} + else: + return {"token": oauth_token} + + def _get_field(self, field_name: str, default: Any = None) -> Any: + """Fetches a field from extras, and returns it.""" + long_f = f"extra__yandexcloud__{field_name}" + if hasattr(self, "extras") and long_f in self.extras: + return self.extras[long_f] + else: + return default diff --git a/reference/providers/yandex/hooks/yandexcloud_dataproc.py b/reference/providers/yandex/hooks/yandexcloud_dataproc.py new file mode 100644 index 0000000..cff2a99 --- /dev/null +++ b/reference/providers/yandex/hooks/yandexcloud_dataproc.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook + + +class DataprocHook(YandexCloudBaseHook): + """ + A base hook for Yandex.Cloud Data Proc. + + :param connection_id: The connection ID to use when fetching connection info. + :type connection_id: str + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = None + self.client = self.sdk.wrappers.Dataproc( + default_folder_id=self.default_folder_id, + default_public_ssh_key=self.default_public_ssh_key, + ) diff --git a/reference/providers/yandex/operators/__init__.py b/reference/providers/yandex/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/reference/providers/yandex/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
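Before the Data Proc operators that follow, a short sketch of how the Yandex.Cloud hooks resolve credentials may help. It is a sketch only: it assumes a "yandexcloud_default" connection whose extras carry either service account JSON or an OAuth token, matching the connection form fields defined above, and it requires the yandexcloud SDK and a reachable Airflow metadata database.

    from airflow.providers.yandex.hooks.yandexcloud_dataproc import DataprocHook

    # _get_credentials() checks the connection extras in this order:
    #   extra__yandexcloud__service_account_json_path  (file is read and parsed)
    #   extra__yandexcloud__service_account_json       (inline JSON key)
    #   extra__yandexcloud__oauth                      (user OAuth token)
    hook = DataprocHook(yandex_conn_id="yandexcloud_default")

    # Defaults picked up from the connection unless overridden via constructor arguments.
    folder_id = hook.default_folder_id          # extra__yandexcloud__folder_id
    ssh_key = hook.default_public_ssh_key       # extra__yandexcloud__public_ssh_key

    # hook.client wraps the yandexcloud SDK (sdk.wrappers.Dataproc); the operators
    # below call it, e.g. hook.client.create_cluster(...).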
diff --git a/reference/providers/yandex/operators/yandexcloud_dataproc.py b/reference/providers/yandex/operators/yandexcloud_dataproc.py new file mode 100644 index 0000000..3dfc090 --- /dev/null +++ b/reference/providers/yandex/operators/yandexcloud_dataproc.py @@ -0,0 +1,517 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, Iterable, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.yandex.hooks.yandexcloud_dataproc import DataprocHook +from airflow.utils.decorators import apply_defaults + + +class DataprocCreateClusterOperator(BaseOperator): + """Creates Yandex.Cloud Data Proc cluster. + + :param folder_id: ID of the folder in which cluster should be created. + :type folder_id: Optional[str] + :param cluster_name: Cluster name. Must be unique inside the folder. + :type cluster_name: Optional[str] + :param cluster_description: Cluster description. + :type cluster_description: str + :param cluster_image_version: Cluster image version. Use default. + :type cluster_image_version: str + :param ssh_public_keys: List of SSH public keys that will be deployed to created compute instances. + :type ssh_public_keys: Optional[Union[str, Iterable[str]]] + :param subnet_id: ID of the subnetwork. All Data Proc cluster nodes will use one subnetwork. + :type subnet_id: str + :param services: List of services that will be installed to the cluster. Possible options: + HDFS, YARN, MAPREDUCE, HIVE, TEZ, ZOOKEEPER, HBASE, SQOOP, FLUME, SPARK, SPARK, ZEPPELIN, OOZIE + :type services: Iterable[str] + :param s3_bucket: Yandex.Cloud S3 bucket to store cluster logs. + Jobs will not work if the bucket is not specified. + :type s3_bucket: Optional[str] + :param zone: Availability zone to create cluster in. + Currently there are ru-central1-a, ru-central1-b and ru-central1-c. + :type zone: str + :param service_account_id: Service account id for the cluster. + Service account can be created inside the folder. + :type service_account_id: Optional[str] + :param masternode_resource_preset: Resources preset (CPU+RAM configuration) + for the master node of the cluster. + :type masternode_resource_preset: str + :param masternode_disk_size: Masternode storage size in GiB. + :type masternode_disk_size: int + :param masternode_disk_type: Masternode storage type. Possible options: network-ssd, network-hdd. + :type masternode_disk_type: str + :param datanode_resource_preset: Resources preset (CPU+RAM configuration) + for the data nodes of the cluster. + :type datanode_resource_preset: str + :param datanode_disk_size: Datanodes storage size in GiB. + :type datanode_disk_size: int + :param datanode_disk_type: Datanodes storage type. Possible options: network-ssd, network-hdd. 
+ :type datanode_disk_type: str + :param computenode_resource_preset: Resources preset (CPU+RAM configuration) + for the compute nodes of the cluster. + :type computenode_resource_preset: str + :param computenode_disk_size: Computenodes storage size in GiB. + :type computenode_disk_size: int + :param computenode_disk_type: Computenodes storage type. Possible options: network-ssd, network-hdd. + :type computenode_disk_type: str + :param connection_id: ID of the Yandex.Cloud Airflow connection. + :type connection_id: Optional[str] + """ + + # pylint: disable=too-many-instance-attributes + # pylint: disable=too-many-arguments + # pylint: disable=too-many-locals + @apply_defaults + def __init__( + self, + *, + folder_id: Optional[str] = None, + cluster_name: Optional[str] = None, + cluster_description: str = "", + cluster_image_version: str = "1.1", + ssh_public_keys: Optional[Union[str, Iterable[str]]] = None, + subnet_id: Optional[str] = None, + services: Iterable[str] = ("HDFS", "YARN", "MAPREDUCE", "HIVE", "SPARK"), + s3_bucket: Optional[str] = None, + zone: str = "ru-central1-b", + service_account_id: Optional[str] = None, + masternode_resource_preset: str = "s2.small", + masternode_disk_size: int = 15, + masternode_disk_type: str = "network-ssd", + datanode_resource_preset: str = "s2.small", + datanode_disk_size: int = 15, + datanode_disk_type: str = "network-ssd", + datanode_count: int = 2, + computenode_resource_preset: str = "s2.small", + computenode_disk_size: int = 15, + computenode_disk_type: str = "network-ssd", + computenode_count: int = 0, + connection_id: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.folder_id = folder_id + self.connection_id = connection_id + self.cluster_name = cluster_name + self.cluster_description = cluster_description + self.cluster_image_version = cluster_image_version + self.ssh_public_keys = ssh_public_keys + self.subnet_id = subnet_id + self.services = services + self.s3_bucket = s3_bucket + self.zone = zone + self.service_account_id = service_account_id + self.masternode_resource_preset = masternode_resource_preset + self.masternode_disk_size = masternode_disk_size + self.masternode_disk_type = masternode_disk_type + self.datanode_resource_preset = datanode_resource_preset + self.datanode_disk_size = datanode_disk_size + self.datanode_disk_type = datanode_disk_type + self.datanode_count = datanode_count + self.computenode_resource_preset = computenode_resource_preset + self.computenode_disk_size = computenode_disk_size + self.computenode_disk_type = computenode_disk_type + self.computenode_count = computenode_count + self.hook: Optional[DataprocHook] = None + + def execute(self, context) -> None: + self.hook = DataprocHook( + connection_id=self.connection_id, + ) + operation_result = self.hook.client.create_cluster( + folder_id=self.folder_id, + cluster_name=self.cluster_name, + cluster_description=self.cluster_description, + cluster_image_version=self.cluster_image_version, + ssh_public_keys=self.ssh_public_keys, + subnet_id=self.subnet_id, + services=self.services, + s3_bucket=self.s3_bucket, + zone=self.zone, + service_account_id=self.service_account_id, + masternode_resource_preset=self.masternode_resource_preset, + masternode_disk_size=self.masternode_disk_size, + masternode_disk_type=self.masternode_disk_type, + datanode_resource_preset=self.datanode_resource_preset, + datanode_disk_size=self.datanode_disk_size, + datanode_disk_type=self.datanode_disk_type, + datanode_count=self.datanode_count, + 
computenode_resource_preset=self.computenode_resource_preset, + computenode_disk_size=self.computenode_disk_size, + computenode_disk_type=self.computenode_disk_type, + computenode_count=self.computenode_count, + ) + context["task_instance"].xcom_push( + key="cluster_id", value=operation_result.response.id + ) + context["task_instance"].xcom_push( + key="yandexcloud_connection_id", value=self.connection_id + ) + + +class DataprocDeleteClusterOperator(BaseOperator): + """Deletes Yandex.Cloud Data Proc cluster. + + :param connection_id: ID of the Yandex.Cloud Airflow connection. + :type connection_id: Optional[str] + :param cluster_id: ID of the cluster to remove. (templated) + :type cluster_id: Optional[str] + """ + + template_fields = ["cluster_id"] + + @apply_defaults + def __init__( + self, + *, + connection_id: Optional[str] = None, + cluster_id: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.connection_id = connection_id + self.cluster_id = cluster_id + self.hook: Optional[DataprocHook] = None + + def execute(self, context) -> None: + cluster_id = self.cluster_id or context["task_instance"].xcom_pull( + key="cluster_id" + ) + connection_id = self.connection_id or context["task_instance"].xcom_pull( + key="yandexcloud_connection_id" + ) + self.hook = DataprocHook( + connection_id=connection_id, + ) + self.hook.client.delete_cluster(cluster_id) + + +class DataprocCreateHiveJobOperator(BaseOperator): + """Runs Hive job in Data Proc cluster. + + :param query: Hive query. + :type query: Optional[str] + :param query_file_uri: URI of the script that contains Hive queries. Can be placed in HDFS or S3. + :type query_file_uri: Optional[str] + :param properties: A mapping of property names to values, used to configure Hive. + :type properties: Optional[Dist[str, str]] + :param script_variables: Mapping of query variable names to values. + :type script_variables: Optional[Dist[str, str]] + :param continue_on_failure: Whether to continue executing queries if a query fails. + :type continue_on_failure: bool + :param name: Name of the job. Used for labeling. + :type name: str + :param cluster_id: ID of the cluster to run job in. + Will try to take the ID from Dataproc Hook object if ot specified. (templated) + :type cluster_id: Optional[str] + :param connection_id: ID of the Yandex.Cloud Airflow connection. 
+    :type connection_id: Optional[str]
+    """
+
+    template_fields = ["cluster_id"]
+
+    # pylint: disable=too-many-arguments
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        query: Optional[str] = None,
+        query_file_uri: Optional[str] = None,
+        script_variables: Optional[Dict[str, str]] = None,
+        continue_on_failure: bool = False,
+        properties: Optional[Dict[str, str]] = None,
+        name: str = "Hive job",
+        cluster_id: Optional[str] = None,
+        connection_id: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.query = query
+        self.query_file_uri = query_file_uri
+        self.script_variables = script_variables
+        self.continue_on_failure = continue_on_failure
+        self.properties = properties
+        self.name = name
+        self.cluster_id = cluster_id
+        self.connection_id = connection_id
+        self.hook: Optional[DataprocHook] = None
+
+    def execute(self, context) -> None:
+        cluster_id = self.cluster_id or context["task_instance"].xcom_pull(
+            key="cluster_id"
+        )
+        connection_id = self.connection_id or context["task_instance"].xcom_pull(
+            key="yandexcloud_connection_id"
+        )
+        self.hook = DataprocHook(
+            connection_id=connection_id,
+        )
+        self.hook.client.create_hive_job(
+            query=self.query,
+            query_file_uri=self.query_file_uri,
+            script_variables=self.script_variables,
+            continue_on_failure=self.continue_on_failure,
+            properties=self.properties,
+            name=self.name,
+            cluster_id=cluster_id,
+        )
+
+
+class DataprocCreateMapReduceJobOperator(BaseOperator):
+    """Runs Mapreduce job in Data Proc cluster.
+
+    :param main_jar_file_uri: URI of jar file with job.
+        Can be placed in HDFS or S3. Can be specified instead of main_class.
+    :type main_jar_file_uri: Optional[str]
+    :param main_class: Name of the main class of the job. Can be specified instead of main_jar_file_uri.
+    :type main_class: Optional[str]
+    :param file_uris: URIs of files used in the job. Can be placed in HDFS or S3.
+    :type file_uris: Optional[Iterable[str]]
+    :param archive_uris: URIs of archive files used in the job. Can be placed in HDFS or S3.
+    :type archive_uris: Optional[Iterable[str]]
+    :param jar_file_uris: URIs of JAR files used in the job. Can be placed in HDFS or S3.
+    :type jar_file_uris: Optional[Iterable[str]]
+    :param properties: Properties for the job.
+    :type properties: Optional[Dict[str, str]]
+    :param args: Arguments to be passed to the job.
+    :type args: Optional[Iterable[str]]
+    :param name: Name of the job. Used for labeling.
+    :type name: str
+    :param cluster_id: ID of the cluster to run the job in.
+        Will try to take the ID from the Dataproc Hook object if not specified. (templated)
+    :type cluster_id: Optional[str]
+    :param connection_id: ID of the Yandex.Cloud Airflow connection.
+    :type connection_id: Optional[str]
+    """
+
+    template_fields = ["cluster_id"]
+
+    # pylint: disable=too-many-arguments
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        main_class: Optional[str] = None,
+        main_jar_file_uri: Optional[str] = None,
+        jar_file_uris: Optional[Iterable[str]] = None,
+        archive_uris: Optional[Iterable[str]] = None,
+        file_uris: Optional[Iterable[str]] = None,
+        args: Optional[Iterable[str]] = None,
+        properties: Optional[Dict[str, str]] = None,
+        name: str = "Mapreduce job",
+        cluster_id: Optional[str] = None,
+        connection_id: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.main_class = main_class
+        self.main_jar_file_uri = main_jar_file_uri
+        self.jar_file_uris = jar_file_uris
+        self.archive_uris = archive_uris
+        self.file_uris = file_uris
+        self.args = args
+        self.properties = properties
+        self.name = name
+        self.cluster_id = cluster_id
+        self.connection_id = connection_id
+        self.hook: Optional[DataprocHook] = None
+
+    def execute(self, context) -> None:
+        cluster_id = self.cluster_id or context["task_instance"].xcom_pull(
+            key="cluster_id"
+        )
+        connection_id = self.connection_id or context["task_instance"].xcom_pull(
+            key="yandexcloud_connection_id"
+        )
+        self.hook = DataprocHook(
+            connection_id=connection_id,
+        )
+        self.hook.client.create_mapreduce_job(
+            main_class=self.main_class,
+            main_jar_file_uri=self.main_jar_file_uri,
+            jar_file_uris=self.jar_file_uris,
+            archive_uris=self.archive_uris,
+            file_uris=self.file_uris,
+            args=self.args,
+            properties=self.properties,
+            name=self.name,
+            cluster_id=cluster_id,
+        )
+
+
+class DataprocCreateSparkJobOperator(BaseOperator):
+    """Runs Spark job in Data Proc cluster.
+
+    :param main_jar_file_uri: URI of jar file with job. Can be placed in HDFS or S3.
+    :type main_jar_file_uri: Optional[str]
+    :param main_class: Name of the main class of the job.
+    :type main_class: Optional[str]
+    :param file_uris: URIs of files used in the job. Can be placed in HDFS or S3.
+    :type file_uris: Optional[Iterable[str]]
+    :param archive_uris: URIs of archive files used in the job. Can be placed in HDFS or S3.
+    :type archive_uris: Optional[Iterable[str]]
+    :param jar_file_uris: URIs of JAR files used in the job. Can be placed in HDFS or S3.
+    :type jar_file_uris: Optional[Iterable[str]]
+    :param properties: Properties for the job.
+    :type properties: Optional[Dict[str, str]]
+    :param args: Arguments to be passed to the job.
+    :type args: Optional[Iterable[str]]
+    :param name: Name of the job. Used for labeling.
+    :type name: str
+    :param cluster_id: ID of the cluster to run the job in.
+        Will try to take the ID from the Dataproc Hook object if not specified. (templated)
+    :type cluster_id: Optional[str]
+    :param connection_id: ID of the Yandex.Cloud Airflow connection.
+    :type connection_id: Optional[str]
+    """
+
+    template_fields = ["cluster_id"]
+
+    # pylint: disable=too-many-arguments
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        main_class: Optional[str] = None,
+        main_jar_file_uri: Optional[str] = None,
+        jar_file_uris: Optional[Iterable[str]] = None,
+        archive_uris: Optional[Iterable[str]] = None,
+        file_uris: Optional[Iterable[str]] = None,
+        args: Optional[Iterable[str]] = None,
+        properties: Optional[Dict[str, str]] = None,
+        name: str = "Spark job",
+        cluster_id: Optional[str] = None,
+        connection_id: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.main_class = main_class
+        self.main_jar_file_uri = main_jar_file_uri
+        self.jar_file_uris = jar_file_uris
+        self.archive_uris = archive_uris
+        self.file_uris = file_uris
+        self.args = args
+        self.properties = properties
+        self.name = name
+        self.cluster_id = cluster_id
+        self.connection_id = connection_id
+        self.hook: Optional[DataprocHook] = None
+
+    def execute(self, context) -> None:
+        cluster_id = self.cluster_id or context["task_instance"].xcom_pull(
+            key="cluster_id"
+        )
+        connection_id = self.connection_id or context["task_instance"].xcom_pull(
+            key="yandexcloud_connection_id"
+        )
+        self.hook = DataprocHook(
+            connection_id=connection_id,
+        )
+        self.hook.client.create_spark_job(
+            main_class=self.main_class,
+            main_jar_file_uri=self.main_jar_file_uri,
+            jar_file_uris=self.jar_file_uris,
+            archive_uris=self.archive_uris,
+            file_uris=self.file_uris,
+            args=self.args,
+            properties=self.properties,
+            name=self.name,
+            cluster_id=cluster_id,
+        )
+
+
+class DataprocCreatePysparkJobOperator(BaseOperator):
+    """Runs Pyspark job in Data Proc cluster.
+
+    :param main_python_file_uri: URI of python file with job. Can be placed in HDFS or S3.
+    :type main_python_file_uri: Optional[str]
+    :param python_file_uris: URIs of python files used in the job. Can be placed in HDFS or S3.
+    :type python_file_uris: Optional[Iterable[str]]
+    :param file_uris: URIs of files used in the job. Can be placed in HDFS or S3.
+    :type file_uris: Optional[Iterable[str]]
+    :param archive_uris: URIs of archive files used in the job. Can be placed in HDFS or S3.
+    :type archive_uris: Optional[Iterable[str]]
+    :param jar_file_uris: URIs of JAR files used in the job. Can be placed in HDFS or S3.
+    :type jar_file_uris: Optional[Iterable[str]]
+    :param properties: Properties for the job.
+    :type properties: Optional[Dict[str, str]]
+    :param args: Arguments to be passed to the job.
+    :type args: Optional[Iterable[str]]
+    :param name: Name of the job. Used for labeling.
+    :type name: str
+    :param cluster_id: ID of the cluster to run the job in.
+        Will try to take the ID from the Dataproc Hook object if not specified. (templated)
+    :type cluster_id: Optional[str]
+    :param connection_id: ID of the Yandex.Cloud Airflow connection.
+    :type connection_id: Optional[str]
+    """
+
+    template_fields = ["cluster_id"]
+
+    # pylint: disable=too-many-arguments
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        main_python_file_uri: Optional[str] = None,
+        python_file_uris: Optional[Iterable[str]] = None,
+        jar_file_uris: Optional[Iterable[str]] = None,
+        archive_uris: Optional[Iterable[str]] = None,
+        file_uris: Optional[Iterable[str]] = None,
+        args: Optional[Iterable[str]] = None,
+        properties: Optional[Dict[str, str]] = None,
+        name: str = "Pyspark job",
+        cluster_id: Optional[str] = None,
+        connection_id: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.main_python_file_uri = main_python_file_uri
+        self.python_file_uris = python_file_uris
+        self.jar_file_uris = jar_file_uris
+        self.archive_uris = archive_uris
+        self.file_uris = file_uris
+        self.args = args
+        self.properties = properties
+        self.name = name
+        self.cluster_id = cluster_id
+        self.connection_id = connection_id
+        self.hook: Optional[DataprocHook] = None
+
+    def execute(self, context) -> None:
+        cluster_id = self.cluster_id or context["task_instance"].xcom_pull(
+            key="cluster_id"
+        )
+        connection_id = self.connection_id or context["task_instance"].xcom_pull(
+            key="yandexcloud_connection_id"
+        )
+        self.hook = DataprocHook(
+            connection_id=connection_id,
+        )
+        self.hook.client.create_pyspark_job(
+            main_python_file_uri=self.main_python_file_uri,
+            python_file_uris=self.python_file_uris,
+            jar_file_uris=self.jar_file_uris,
+            archive_uris=self.archive_uris,
+            file_uris=self.file_uris,
+            args=self.args,
+            properties=self.properties,
+            name=self.name,
+            cluster_id=cluster_id,
+        )
diff --git a/reference/providers/yandex/provider.yaml b/reference/providers/yandex/provider.yaml
new file mode 100644
index 0000000..aab7503
--- /dev/null
+++ b/reference/providers/yandex/provider.yaml
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-yandex
+name: Yandex
+description: |
+    Yandex including `Yandex.Cloud <https://cloud.yandex.com/>`__
+
+versions:
+  - 1.0.1
+  - 1.0.0
+
+integrations:
+  - integration-name: Yandex.Cloud
+    external-doc-url: https://cloud.yandex.com/
+    logo: /integration-logos/yandex/Yandex-Cloud.png
+    tags: [service]
+
+  - integration-name: Yandex.Cloud Dataproc
+    external-doc-url: https://cloud.yandex.com/dataproc
+    how-to-guide:
+      - /docs/apache-airflow-providers-yandex/operators.rst
+    logo: /integration-logos/yandex/Yandex-Cloud.png
+    tags: [service]
+
+operators:
+  - integration-name: Yandex.Cloud Dataproc
+    python-modules:
+      - airflow.providers.yandex.operators.yandexcloud_dataproc
+
+hooks:
+  - integration-name: Yandex.Cloud
+    python-modules:
+      - airflow.providers.yandex.hooks.yandex
+  - integration-name: Yandex.Cloud Dataproc
+    python-modules:
+      - airflow.providers.yandex.hooks.yandexcloud_dataproc
+
+hook-class-names:
+  - airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook
diff --git a/reference/providers/zendesk/CHANGELOG.rst b/reference/providers/zendesk/CHANGELOG.rst
new file mode 100644
index 0000000..12fea86
--- /dev/null
+++ b/reference/providers/zendesk/CHANGELOG.rst
@@ -0,0 +1,30 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements. See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership. The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied. See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+Changelog
+---------
+
+1.0.1
+.....
+
+Updated documentation and readme files.
+
+1.0.0
+.....
+
+Initial version of the provider.
diff --git a/reference/providers/zendesk/__init__.py b/reference/providers/zendesk/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/reference/providers/zendesk/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/reference/providers/zendesk/hooks/__init__.py b/reference/providers/zendesk/hooks/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/reference/providers/zendesk/hooks/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/reference/providers/zendesk/hooks/zendesk.py b/reference/providers/zendesk/hooks/zendesk.py
new file mode 100644
index 0000000..fbc15ee
--- /dev/null
+++ b/reference/providers/zendesk/hooks/zendesk.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import time
+from typing import Optional
+
+from airflow.hooks.base import BaseHook
+from zdesk import RateLimitError, Zendesk, ZendeskError
+
+
+class ZendeskHook(BaseHook):
+    """A hook to talk to Zendesk"""
+
+    def __init__(self, zendesk_conn_id: str) -> None:
+        super().__init__()
+        self.__zendesk_conn_id = zendesk_conn_id
+        self.__url = None
+
+    def get_conn(self) -> Zendesk:
+        conn = self.get_connection(self.__zendesk_conn_id)
+        self.__url = "https://" + conn.host
+        return Zendesk(
+            zdesk_url=self.__url,
+            zdesk_email=conn.login,
+            zdesk_password=conn.password,
+            zdesk_token=True,
+        )
+
+    def __handle_rate_limit_exception(self, rate_limit_exception: ZendeskError) -> None:
+        """
+        Sleep for the time specified in the exception. If not specified, wait
+        for 60 seconds.
+        """
+        retry_after = int(rate_limit_exception.response.headers.get("Retry-After", 60))
+        self.log.info("Hit Zendesk API rate limit. Pausing for %s seconds", retry_after)
+        time.sleep(retry_after)
+
+    def call(
+        self,
+        path: str,
+        query: Optional[dict] = None,
+        get_all_pages: bool = True,
+        side_loading: bool = False,
+    ) -> dict:
+        """
+        Call Zendesk API and return results
+
+        :param path: The Zendesk API to call
+        :param query: Query parameters
+        :param get_all_pages: Accumulate results over all pages before
+            returning. Due to strict rate limiting, this can often timeout.
+            Waits for recommended period between tries after a timeout.
+        :param side_loading: Retrieve related records as part of a single
+            request. In order to enable side-loading, add an 'include'
+            query parameter containing a comma-separated list of resources
+            to load. For more information on side-loading see
+            https://developer.zendesk.com/rest_api/docs/core/side_loading
+        """
+        query_params = query or {}
+        zendesk = self.get_conn()
+        first_request_successful = False
+
+        while not first_request_successful:
+            try:
+                results = zendesk.call(path, query_params)
+                first_request_successful = True
+            except RateLimitError as rle:
+                self.__handle_rate_limit_exception(rle)
+
+        # Find the key with the results
+        keys = [path.split("/")[-1].split(".json")[0]]
+        next_page = results["next_page"]
+        if side_loading:
+            keys += query_params["include"].split(",")
+        results = {key: results[key] for key in keys}
+
+        # pylint: disable=too-many-nested-blocks
+        if get_all_pages:
+            while next_page is not None:
+                try:
+                    # Need to split because the next page URL has
+                    # `github.zendesk...`
+                    # in it, but the call function needs it removed.
+                    next_url = next_page.split(self.__url)[1]
+                    self.log.info("Calling %s", next_url)
+                    more_res = zendesk.call(next_url)
+                    for key in results:
+                        results[key].extend(more_res[key])
+                    if next_page == more_res["next_page"]:
+                        # Unfortunately zdesk doesn't always throw ZendeskError
+                        # when we are done getting all the data. Sometimes the
+                        # next just refers to the current set of results.
+                        # Hence, need to deal with this special case
+                        break
+                    next_page = more_res["next_page"]
+                except RateLimitError as rle:
+                    self.__handle_rate_limit_exception(rle)
+                except ZendeskError as zde:
+                    if b"Use a start_time older than 5 minutes" in zde.msg:
+                        # We have pretty up to date data
+                        break
+                    raise zde
+
+        return results
diff --git a/reference/providers/zendesk/provider.yaml b/reference/providers/zendesk/provider.yaml
new file mode 100644
index 0000000..09b092c
--- /dev/null
+++ b/reference/providers/zendesk/provider.yaml
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-zendesk
+name: Zendesk
+description: |
+    `Zendesk <https://www.zendesk.com/>`__
+
+versions:
+  - 1.0.1
+  - 1.0.0
+
+integrations:
+  - integration-name: Zendesk
+    external-doc-url: https://www.zendesk.com/
+    logo: /integration-logos/zendesk/Zendesk.png
+    tags: [software]
+
+hooks:
+  - integration-name: Zendesk
+    python-modules:
+      - airflow.providers.zendesk.hooks.zendesk
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..464e3ff
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,29 @@
+beautifulsoup4
+selenium
+azure.storage.blob
+flask_admin
+oauth2client
+apprise
+random_user_agent
+tqdm
+PyYAML
+pylint-airflow
+pandas
+lxml
+undetected-chromedriver
+# tap-google-sheets
+dag-factory
+boa-str
+simple-dag-editor
+apache-airflow[discord]
+apache-airflow[hashicorp]
+apache-airflow[docker]
+pdfkit
+black
+# csvkit
+# httpie
+MechanicalSoup
+prettify
+# dbt
+docker
+petl
diff --git a/scripts/cronicle-backup.sh b/scripts/cronicle-backup.sh
new file mode 100755
index 0000000..b4780d9
--- /dev/null
+++ b/scripts/cronicle-backup.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+DATE_STAMP=`date "+%Y-%m-%d"`
+BACKUP_DIR="/backup/cronicle/data"
+BACKUP_FILE="$BACKUP_DIR/backup-$DATE_STAMP.txt"
+
+mkdir -p $BACKUP_DIR
+/opt/cronicle/bin/control.sh export $BACKUP_FILE --verbose
+find $BACKUP_DIR -mtime +365 -type f -exec rm -v {} \;
diff --git a/scripts/cronicle-init.sh b/scripts/cronicle-init.sh
new file mode 100755
index 0000000..0c09ab0
--- /dev/null
+++ b/scripts/cronicle-init.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+HOME_DIR=/opt/cronicle/
+BIN_DIR=/opt/cronicle/bin/
+DATA_DIR=/opt/cronicle/data/
+LOGS_DIR=/opt/cronicle/logs/
+QUEUE_DIR=/opt/cronicle/queue/
+
+mkdir -p $DATA_DIR $LOGS_DIR $QUEUE_DIR
+
+export CRONICLE_foreground=1
+export CRONICLE_Storage__Filesystem__base_dir=$DATA_DIR
+export CRONICLE_log_dir=$LOGS_DIR
+export CRONICLE_echo=${CRONICLE_echo:-1}
+#export CRONICLE_pid_file=$HOME_DIR/cronicle.pid
+
+if [ ! "$(ls -A $DATA_DIR)" ]; then
+    echo "$(date -I'seconds') INFO $DATA_DIR is empty, running setup ..."
+    ${BIN_DIR}/control.sh setup
+    echo "$(date -I'seconds') INFO done"
+fi
+
+chown nonroot:nonroot -R ${DATA_DIR}
+chown nonroot:nonroot -R ${LOGS_DIR}
+chown nonroot:nonroot -R ${QUEUE_DIR}
+
+rm /data/logs/cronicle/cronicled.pid
+
+# supervisord -c "/etc/supervisord.conf"
+
+/opt/cronicle/bin/control.sh start
diff --git a/scripts/pg-init-scripts/init-user-db.sh b/scripts/pg-init-scripts/init-user-db.sh
new file mode 100644
index 0000000..72a1661
--- /dev/null
+++ b/scripts/pg-init-scripts/init-user-db.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+set -e
+
+psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
+    CREATE USER test WITH PASSWORD 'postgres';
+    CREATE DATABASE test;
+    GRANT ALL PRIVILEGES ON DATABASE test TO test;
+EOSQL
\ No newline at end of file
diff --git a/scripts/post-scrape.sh b/scripts/post-scrape.sh
new file mode 100755
index 0000000..b7df35e
--- /dev/null
+++ b/scripts/post-scrape.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+
+set -e
+
+NOW=$(date +"%Y-%m-%d_%H-%M-%S")
+
+# Loop through arguments and process them
+for arg in "$@"
+do
+    case $arg in
+        -s|--site)
+        SITE="$2"
+        shift # Remove argument name from processing
+        shift # Remove argument value from processing
+        ;;
+        # *)
+        -k|--keyword)
+        KEYWORD="$2"
+        shift # Remove argument name from processing
+        shift # Remove argument value from processing
+        ;;
+    esac
+done
+
+cd /data/scripts;
+
+echo "$PWD"
+
+# python3 -c "import selenium; print(selenium.__version__)"
+
+# python3 /data/scripts/target/scraper.py -k gitlab
+# python3 /data/scripts/linkedin/scraper.py -k gitlab
+# echo "Searching $SITE for $KEYWORD"
+# python3 /data/scripts/gather/"$SITE".py -k "$KEYWORD"
+# echo "List Generated"
+
+echo "Archiving raw JSON"
+sudo cp /data/data/staging/"$SITE"/jobs.json /data/data/archive/"$SITE"/json/whitelist/"$NOW".json
+sudo cp /data/data/staging/"$SITE"/jobs-blacklist.json /data/data/archive/"$SITE"/json/blacklist/"$NOW".json
+echo "Archived /data/data/archive/$SITE/json/$NOW.json"
+echo "Archived /data/data/archive/$SITE/json/blacklist-$NOW.json"
+
+echo "Archiving raw CSV"
+sudo cp /data/data/staging/"$SITE"/jobs.csv /data/data/archive/"$SITE"/csv/whitelist/"$NOW".csv
+sudo cp /data/data/staging/"$SITE"/jobs-blacklist.csv /data/data/archive/"$SITE"/csv/blacklist/"$NOW".csv
+echo "Archived /data/data/archive/$SITE/csv/$NOW.csv"
+
+# echo "Archiving HTML"
+# cp "$SITE"-page.html /data/data/archive/"$SITE"/html/page-"$NOW".html
+# cp "$SITE"-snippet.html /data/data/archive/"$SITE"/html/snippet-"$NOW".html
+# echo "Archived /data/data/archive/$SITE/html/page-$NOW.html"
+# echo "Archived /data/data/archive/$SITE/html/snippet-$NOW.html"
+
+# echo "Archiving screenshots"
+# cp "$SITE"-scroll.png /data/data/archive/"$SITE"/png/scroll-"$NOW".png
+# echo "Archived /data/data/archive/$SITE/png/$NOW.png"
+
+echo "Cleanslating"
+# mv "$SITE"-scroll.png /data/data/archive/"$SITE"/last/"$SITE"-scroll.png
+# mv "$SITE"-page.html /data/data/archive/"$SITE"/last/"$SITE"-page.html
+# mv "$SITE"-snippet.html /data/data/archive/"$SITE"/last/"$SITE"-snippet.html
+sudo mv /data/data/staging/"$SITE"/jobs.json /data/data/archive/"$SITE"/last/jobs.json
+sudo mv /data/data/staging/"$SITE"/jobs.csv /data/data/archive/"$SITE"/last/jobs.csv
+sudo mv /data/data/staging/"$SITE"/jobs-blacklist.json /data/data/archive/"$SITE"/last/jobs-blacklist.json
+sudo mv /data/data/staging/"$SITE"/jobs-blacklist.csv /data/data/archive/"$SITE"/last/jobs-blacklist.csv
+echo "Cleanslated"
+
+# cat target.json | sqlite-utils insert --alter /data/data/datasette/zip.db $NOW -
+# echo "Adding to master table"
+# sqlite-utils insert /data/data/datasette/"$SITE".db "$NOW" "$SITE".json --alter --truncate --pk=id
+# echo "Table added"
+
+# sqlite-utils analyze-tables /data/data/datasette/target.db --save
diff --git a/scripts/scrape.sh b/scripts/scrape.sh
new file mode 100755
index 0000000..af66658
--- /dev/null
+++ b/scripts/scrape.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+
+set -e
+
+NOW=$(date +"%Y-%m-%d_%H-%M-%S")
+
+# Loop through arguments and process them
+for arg in "$@"
+do
+    case $arg in
+        -s|--site)
+        SITE="$2"
+        shift # Remove argument name from processing
+        shift # Remove argument value from processing
+        ;;
+        # *)
+        -k|--keyword)
+        KEYWORD="$2"
+        shift # Remove argument name from processing
+        shift # Remove argument value from processing
+        ;;
+    esac
+done
+
+cd /data/scripts;
+
+echo "$PWD"
+
+python3 -c "import selenium; print(selenium.__version__)"
+
+# python3 /data/scripts/target/scraper.py -k gitlab
+# python3 /data/scripts/linkedin/scraper.py -k gitlab
+echo "Searching $SITE for $KEYWORD"
+python3 /data/scripts/gather/"$SITE".py -k "$KEYWORD"
+echo "List Generated"
+
+echo "Archiving raw JSON"
+cp "$SITE".json /data/data/archive/"$SITE"/json/"$NOW".json
+cp "$SITE"-blacklist.json /data/data/archive/"$SITE"/json/blacklist-"$NOW".json
+echo "Archived /data/data/archive/$SITE/json/$NOW.json"
+
+echo "Archiving raw CSV"
+cp "$SITE".csv /data/data/archive/"$SITE"/csv/"$NOW".csv
+echo "Archived /data/data/archive/$SITE/csv/$NOW.csv"
+
+echo "Archiving HTML"
+cp "$SITE"-page.html /data/data/archive/"$SITE"/html/page-"$NOW".html
+cp "$SITE"-snippet.html /data/data/archive/"$SITE"/html/snippet-"$NOW".html
+echo "Archived /data/data/archive/$SITE/html/page-$NOW.html"
+echo "Archived /data/data/archive/$SITE/html/snippet-$NOW.html"
+
+echo "Archiving screenshots"
+cp "$SITE"-scroll.png /data/data/archive/"$SITE"/png/scroll-"$NOW".png
+echo "Archived /data/data/archive/$SITE/png/$NOW.png"
+
+echo "Cleanslating"
+mv "$SITE"-scroll.png /data/data/archive/"$SITE"/last/"$SITE"-scroll.png
+mv "$SITE"-page.html /data/data/archive/"$SITE"/last/"$SITE"-page.html
+mv "$SITE"-snippet.html /data/data/archive/"$SITE"/last/"$SITE"-snippet.html
+mv "$SITE".csv /data/data/archive/"$SITE"/last/"$SITE".csv
+mv "$SITE".json /data/data/archive/"$SITE"/last/"$SITE".json
+mv "$SITE"-blacklist.json /data/data/archive/"$SITE"/last/"$SITE"-blacklist.json
+echo "Cleanslated"
+
+# cat target.json | sqlite-utils insert --alter /data/data/datasette/zip.db $NOW -
+# echo "Adding to master table"
+# sqlite-utils insert /data/data/datasette/"$SITE".db "$NOW" "$SITE".json --alter --truncate --pk=id
+# echo "Table added"
+
+# sqlite-utils analyze-tables /data/data/datasette/target.db --save
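
Note on the Dataproc operators above: they pass state between tasks through XCom. The create-cluster operator's execute() pushes the new cluster's ID under the "cluster_id" key (and the connection ID under "yandexcloud_connection_id"), and the job and delete operators pull those keys whenever cluster_id or connection_id is not given explicitly. The sketch below is illustrative only and is not part of the commit: every folder, subnet, bucket, URI, and connection ID is a placeholder, and it assumes the create-cluster class (defined above this excerpt) is DataprocCreateClusterOperator in airflow.providers.yandex.operators.yandexcloud_dataproc, as listed in the provider.yaml.

from datetime import datetime

from airflow import DAG
from airflow.providers.yandex.operators.yandexcloud_dataproc import (
    DataprocCreateClusterOperator,
    DataprocCreatePysparkJobOperator,
    DataprocDeleteClusterOperator,
)

with DAG(
    dag_id="example_yandexcloud_dataproc",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    create_cluster = DataprocCreateClusterOperator(
        task_id="create_cluster",
        folder_id="b1g0000000000000000",      # placeholder folder ID
        subnet_id="e9b0000000000000000",      # placeholder subnet ID
        s3_bucket="my-dataproc-logs-bucket",  # placeholder bucket for job output
        connection_id="yandexcloud_default",  # assumed Airflow connection ID
    )

    # No cluster_id is passed here: execute() falls back to the "cluster_id"
    # XCom pushed by create_cluster, along with "yandexcloud_connection_id".
    pyspark_job = DataprocCreatePysparkJobOperator(
        task_id="run_pyspark_job",
        main_python_file_uri="s3a://my-dataproc-jobs/word_count.py",  # placeholder
    )

    delete_cluster = DataprocDeleteClusterOperator(task_id="delete_cluster")

    create_cluster >> pyspark_job >> delete_cluster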