first commit

Branch: master
Author: Don Aldrich, 3 years ago
Commit: 88398a1491
  1. .gitignore (+2)
  2. Dockerfile (+46)
  3. README.md (+0)
  4. __main__.py (+0)
  5. config/airflow.cfg (+1032)
  6. config/airflow_local_settings.py (+11)
  7. config/log_config.py (+5)
  8. config/vault.py (+23)
  9. config/webserver_config.py (+130)
  10. dags/complex.py (+264)
  11. dags/crawler.py (+93)
  12. dags/debug.py (+125)
  13. dags/dev/evade.py (+101)
  14. dags/dev/form.py (+27)
  15. dags/dev/log_config.py (+270)
  16. dags/dev/menu.py (+28)
  17. dags/dev/singer.py (+32)
  18. dags/dev/test.py (+181)
  19. dags/docker.py (+62)
  20. dags/etl.py (+114)
  21. dags/postgres.py (+61)
  22. dags/remote_browser.py (+126)
  23. dags/s3.py (+41)
  24. dags/scraper.py (+154)
  25. dags/selenium.py (+65)
  26. dags/sql/search.sql (+1)
  27. dags/sqlite.py (+122)
  28. dags/tutorial_etl_dag.py (+137)
  29. dags/vault.py (+32)
  30. docker-compose.yml (+186)
  31. plugins/__init__.py (+0)
  32. plugins/config/__init__.py (+0)
  33. plugins/config/logger.py (+49)
  34. plugins/hooks/discord/discord.py (+145)
  35. plugins/hooks/google.py (+108)
  36. plugins/hooks/selenium_hook.py (+95)
  37. plugins/operators/__init__.py (+2)
  38. plugins/operators/discord.py (+95)
  39. plugins/operators/google.py (+254)
  40. plugins/operators/selenium_operator.py (+24)
  41. plugins/operators/singer.py (+45)
  42. plugins/operators/sudo_bash.py (+110)
  43. plugins/tools/GoogleSheetsPlugin.py (+15)
  44. plugins/tools/singer.py (+13)
  45. reference/providers/airbyte/CHANGELOG.rst (+25)
  46. reference/providers/airbyte/__init__.py (+17)
  47. reference/providers/airbyte/example_dags/__init__.py (+16)
  48. reference/providers/airbyte/example_dags/example_airbyte_trigger_job.py (+64)
  49. reference/providers/airbyte/hooks/__init__.py (+17)
  50. reference/providers/airbyte/hooks/airbyte.py (+123)
  51. reference/providers/airbyte/operators/__init__.py (+17)
  52. reference/providers/airbyte/operators/airbyte.py (+89)
  53. reference/providers/airbyte/provider.yaml (+51)
  54. reference/providers/airbyte/sensors/__init__.py (+16)
  55. reference/providers/airbyte/sensors/airbyte.py (+75)
  56. reference/providers/amazon/CHANGELOG.rst (+63)
  57. reference/providers/amazon/__init__.py (+16)
  58. reference/providers/amazon/aws/__init__.py (+16)
  59. reference/providers/amazon/aws/example_dags/__init__.py (+16)
  60. reference/providers/amazon/aws/example_dags/example_datasync_1.py (+71)
  61. reference/providers/amazon/aws/example_dags/example_datasync_2.py (+100)
  62. reference/providers/amazon/aws/example_dags/example_ecs_fargate.py (+82)
  63. reference/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py (+96)
  64. reference/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py (+114)
  65. reference/providers/amazon/aws/example_dags/example_glacier_to_gcs.py (+70)
  66. reference/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py (+141)
  67. reference/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py (+57)
  68. reference/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py (+53)
  69. reference/providers/amazon/aws/example_dags/example_s3_bucket.py (+69)
  70. reference/providers/amazon/aws/example_dags/example_s3_bucket_tagging.py (+73)
  71. reference/providers/amazon/aws/example_dags/example_s3_to_redshift.py (+90)
  72. reference/providers/amazon/aws/hooks/__init__.py (+16)
  73. reference/providers/amazon/aws/hooks/athena.py (+263)
  74. reference/providers/amazon/aws/hooks/aws_dynamodb.py (+29)
  75. reference/providers/amazon/aws/hooks/base_aws.py (+596)
  76. reference/providers/amazon/aws/hooks/batch_client.py (+556)
  77. reference/providers/amazon/aws/hooks/batch_waiters.json (+105)
  78. reference/providers/amazon/aws/hooks/batch_waiters.py (+242)
  79. reference/providers/amazon/aws/hooks/cloud_formation.py (+79)
  80. reference/providers/amazon/aws/hooks/datasync.py (+334)
  81. reference/providers/amazon/aws/hooks/dynamodb.py (+67)
  82. reference/providers/amazon/aws/hooks/ec2.py (+82)
  83. reference/providers/amazon/aws/hooks/elasticache_replication_group.py (+325)
  84. reference/providers/amazon/aws/hooks/emr.py (+101)
  85. reference/providers/amazon/aws/hooks/glacier.py (+80)
  86. reference/providers/amazon/aws/hooks/glue.py (+205)
  87. reference/providers/amazon/aws/hooks/glue_catalog.py (+142)
  88. reference/providers/amazon/aws/hooks/glue_crawler.py (+184)
  89. reference/providers/amazon/aws/hooks/kinesis.py (+50)
  90. reference/providers/amazon/aws/hooks/lambda_function.py (+69)
  91. reference/providers/amazon/aws/hooks/logs.py (+105)
  92. reference/providers/amazon/aws/hooks/redshift.py (+131)
  93. reference/providers/amazon/aws/hooks/s3.py (+973)
  94. reference/providers/amazon/aws/hooks/sagemaker.py (+1039)
  95. reference/providers/amazon/aws/hooks/secrets_manager.py (+71)
  96. reference/providers/amazon/aws/hooks/ses.py (+99)
  97. reference/providers/amazon/aws/hooks/sns.py (+96)
  98. reference/providers/amazon/aws/hooks/sqs.py (+87)
  99. reference/providers/amazon/aws/hooks/step_function.py (+82)
  100. reference/providers/amazon/aws/log/__init__.py (+16)
Some files were not shown because too many files have changed in this diff.

.gitignore
@@ -0,0 +1,2 @@
.DS_Store
*/.DS_Store

Dockerfile
@@ -0,0 +1,46 @@
FROM apache/airflow:2.0.2-python3.8
USER root
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
git \
pylint \
libpq-dev \
python-dev \
# gcc \
&& apt-get remove python-cffi \
&& apt-get autoremove -yqq --purge \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN /usr/local/bin/python -m pip install --upgrade pip
# COPY --from=docker /usr/local/bin/docker /usr/local/bin/docker
RUN curl -sSL https://get.docker.com/ | sh
# COPY requirements.txt ./
# RUN pip install --no-cache-dir -r requirements.txt
# ENV PYTHONPATH "${PYTHONPATH}:/your/custom/path"
ENV AIRFLOW_UID=1000
ENV AIRFLOW_GID=1000
RUN echo airflow ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/airflow
RUN chmod 0440 /etc/sudoers.d/airflow
RUN echo 'airflow:airflow' | chpasswd
RUN groupmod --gid ${AIRFLOW_GID} airflow
RUN groupmod --gid 998 docker
# adduser --quiet "airflow" --uid "${AIRFLOW_UID}" \
# --gid "${AIRFLOW_GID}" \
# --home "${AIRFLOW_USER_HOME_DIR}"
RUN usermod --gid "${AIRFLOW_GID}" --uid "${AIRFLOW_UID}" airflow
RUN usermod -aG docker airflow
RUN chown -R ${AIRFLOW_UID}:${AIRFLOW_GID} /home/airflow
RUN chown -R ${AIRFLOW_UID}:${AIRFLOW_GID} /opt/airflow
WORKDIR /opt/airflow
# RUN usermod -g ${AIRFLOW_GID} airflow -G airflow
USER ${AIRFLOW_UID}
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

config/airflow.cfg
File diff suppressed because it is too large.

config/airflow_local_settings.py
@@ -0,0 +1,11 @@
STATE_COLORS = {
"queued": "darkgray",
"running": "blue",
"success": "cobalt",
"failed": "firebrick",
"up_for_retry": "yellow",
"up_for_reschedule": "turquoise",
"upstream_failed": "orange",
"skipped": "darkorchid",
"scheduled": "tan",
}

config/log_config.py
@@ -0,0 +1,5 @@
from copy import deepcopy
from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
LOGGING_CONFIG = deepcopy(DEFAULT_LOGGING_CONFIG)
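A minimal sketch of how this copied config is typically extended and wired in, assuming the module is importable as config.log_config and referenced from airflow.cfg (the keys below exist in DEFAULT_LOGGING_CONFIG; the specific overrides are only illustrative):

# Adjust the copy rather than mutating Airflow's default config in place.
LOGGING_CONFIG["handlers"]["console"]["formatter"] = "airflow_coloured"
LOGGING_CONFIG["loggers"]["airflow.task"]["level"] = "INFO"
# airflow.cfg:
#   [logging]
#   logging_config_class = config.log_config.LOGGING_CONFIG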

config/vault.py
@@ -0,0 +1,23 @@
from datetime import datetime
from airflow import DAG
from airflow.hooks.base_hook import BaseHook
from airflow.operators.python_operator import PythonOperator
def get_secrets(**kwargs):
conn = BaseHook.get_connection(kwargs["my_conn_id"])
print(
f"Password: {conn.password}, Login: {conn.login}, URI: {conn.get_uri()}, Host: {conn.host}"
)
with DAG(
"example_secrets_dags", start_date=datetime(2020, 1, 1), schedule_interval=None
) as dag:
test_task = PythonOperator(
task_id="test-task",
python_callable=get_secrets,
op_kwargs={"my_conn_id": "smtp_default"},
)
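The connection printed above is resolved through whichever secrets backend Airflow is configured with. A minimal sketch of pointing Airflow at HashiCorp Vault, assuming the apache-airflow-providers-hashicorp package is installed (the URL, path, and token values are placeholders):

import os

# Illustrative only: a dev-mode Vault reachable at http://vault:8200.
os.environ["AIRFLOW__SECRETS__BACKEND"] = (
    "airflow.providers.hashicorp.secrets.vault.VaultBackend"
)
os.environ["AIRFLOW__SECRETS__BACKEND_KWARGS"] = (
    '{"connections_path": "connections", "url": "http://vault:8200", "token": "placeholder-token"}'
)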

config/webserver_config.py
@@ -0,0 +1,130 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Default configuration for the Airflow webserver"""
import os
from flask_appbuilder.security.manager import AUTH_DB
# from flask_appbuilder.security.manager import AUTH_LDAP
# from flask_appbuilder.security.manager import AUTH_OAUTH
# from flask_appbuilder.security.manager import AUTH_OID
# from flask_appbuilder.security.manager import AUTH_REMOTE_USER
basedir = os.path.abspath(os.path.dirname(__file__))
# Flask-WTF flag for CSRF
WTF_CSRF_ENABLED = True
# ----------------------------------------------------
# AUTHENTICATION CONFIG
# ----------------------------------------------------
# For details on how to set up each of the following authentication, see
# http://flask-appbuilder.readthedocs.io/en/latest/security.html#authentication-methods
# for details.
# The authentication type
# AUTH_OID : Is for OpenID
# AUTH_DB : Is for database
# AUTH_LDAP : Is for LDAP
# AUTH_REMOTE_USER : Is for using REMOTE_USER from web server
# AUTH_OAUTH : Is for OAuth
AUTH_TYPE = AUTH_DB
# Uncomment to setup Full admin role name
# AUTH_ROLE_ADMIN = 'Admin'
# Uncomment to setup Public role name, no authentication needed
# AUTH_ROLE_PUBLIC = 'Public'
# Will allow user self registration
# AUTH_USER_REGISTRATION = True
# Recaptcha is automatically enabled when user self-registration is active and the keys are provided
# RECAPTCHA_PRIVATE_KEY = PRIVATE_KEY
# RECAPTCHA_PUBLIC_KEY = PUBLIC_KEY
# Config for Flask-Mail necessary for user self registration
# MAIL_SERVER = 'smtp.gmail.com'
# MAIL_USE_TLS = True
# MAIL_USERNAME = 'yourappemail@gmail.com'
# MAIL_PASSWORD = 'passwordformail'
# MAIL_DEFAULT_SENDER = 'sender@gmail.com'
# The default user self registration role
# AUTH_USER_REGISTRATION_ROLE = "Public"
# When using OAuth Auth, uncomment to setup provider(s) info
# Google OAuth example:
# OAUTH_PROVIDERS = [{
# 'name':'google',
# 'token_key':'access_token',
# 'icon':'fa-google',
# 'remote_app': {
# 'api_base_url':'https://www.googleapis.com/oauth2/v2/',
# 'client_kwargs':{
# 'scope': 'email profile'
# },
# 'access_token_url':'https://accounts.google.com/o/oauth2/token',
# 'authorize_url':'https://accounts.google.com/o/oauth2/auth',
# 'request_token_url': None,
# 'client_id': GOOGLE_KEY,
# 'client_secret': GOOGLE_SECRET_KEY,
# }
# }]
# When using LDAP Auth, setup the ldap server
# AUTH_LDAP_SERVER = "ldap://ldapserver.new"
# When using OpenID Auth, uncomment to setup OpenID providers.
# example for OpenID authentication
# OPENID_PROVIDERS = [
# { 'name': 'Yahoo', 'url': 'https://me.yahoo.com' },
# { 'name': 'AOL', 'url': 'http://openid.aol.com/<username>' },
# { 'name': 'Flickr', 'url': 'http://www.flickr.com/<username>' },
# { 'name': 'MyOpenID', 'url': 'https://www.myopenid.com' }]
# ----------------------------------------------------
# Theme CONFIG
# ----------------------------------------------------
# Flask App Builder comes up with a number of predefined themes
# that you can use for Apache Airflow.
# http://flask-appbuilder.readthedocs.io/en/latest/customizing.html#changing-themes
# Please make sure to remove "navbar_color" configuration from airflow.cfg
# in order to fully utilize the theme. (or use that property in conjunction with theme)
# APP_THEME = "bootstrap-theme.css" # default bootstrap
# APP_THEME = "amelia.css"
# APP_THEME = "cerulean.css"
APP_THEME = "cosmo.css"
# APP_THEME = "cyborg.css"
# APP_THEME = "darkly.css"
# APP_THEME = "flatly.css"
# APP_THEME = "journal.css"
# APP_THEME = "lumen.css"
# APP_THEME = "paper.css"
# APP_THEME = "readable.css"
# APP_THEME = "sandstone.css"
# APP_THEME = "simplex.css"
# APP_THEME = "slate.css"
# APP_THEME = "solar.css"
# APP_THEME = "spacelab.css"
# APP_THEME = "superhero.css"
# APP_THEME = "united.css"
# APP_THEME = "yeti.css"
# FAB_STATIC_FOLDER =

dags/complex.py
@@ -0,0 +1,264 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG that shows the complex DAG structure.
"""
from airflow import models
from airflow.models.baseoperator import chain
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
with models.DAG(
dag_id="example_complex",
schedule_interval=None,
start_date=days_ago(1),
tags=["example", "example2", "example3"],
) as dag:
# Create
create_entry_group = BashOperator(
task_id="create_entry_group", bash_command="echo create_entry_group"
)
create_entry_group_result = BashOperator(
task_id="create_entry_group_result",
bash_command="echo create_entry_group_result",
)
create_entry_group_result2 = BashOperator(
task_id="create_entry_group_result2",
bash_command="echo create_entry_group_result2",
)
create_entry_gcs = BashOperator(
task_id="create_entry_gcs", bash_command="echo create_entry_gcs"
)
create_entry_gcs_result = BashOperator(
task_id="create_entry_gcs_result", bash_command="echo create_entry_gcs_result"
)
create_entry_gcs_result2 = BashOperator(
task_id="create_entry_gcs_result2", bash_command="echo create_entry_gcs_result2"
)
create_tag = BashOperator(task_id="create_tag", bash_command="echo create_tag")
create_tag_result = BashOperator(
task_id="create_tag_result", bash_command="echo create_tag_result"
)
create_tag_result2 = BashOperator(
task_id="create_tag_result2", bash_command="echo create_tag_result2"
)
create_tag_template = BashOperator(
task_id="create_tag_template", bash_command="echo create_tag_template"
)
create_tag_template_result = BashOperator(
task_id="create_tag_template_result",
bash_command="echo create_tag_template_result",
)
create_tag_template_result2 = BashOperator(
task_id="create_tag_template_result2",
bash_command="echo create_tag_template_result2",
)
create_tag_template_field = BashOperator(
task_id="create_tag_template_field",
bash_command="echo create_tag_template_field",
)
create_tag_template_field_result = BashOperator(
task_id="create_tag_template_field_result",
bash_command="echo create_tag_template_field_result",
)
create_tag_template_field_result2 = BashOperator(
task_id="create_tag_template_field_result2",
bash_command="echo create_tag_template_field_result2",
)
# Delete
delete_entry = BashOperator(
task_id="delete_entry", bash_command="echo delete_entry"
)
create_entry_gcs >> delete_entry
delete_entry_group = BashOperator(
task_id="delete_entry_group", bash_command="echo delete_entry_group"
)
create_entry_group >> delete_entry_group
delete_tag = BashOperator(task_id="delete_tag", bash_command="echo delete_tag")
create_tag >> delete_tag
delete_tag_template_field = BashOperator(
task_id="delete_tag_template_field",
bash_command="echo delete_tag_template_field",
)
delete_tag_template = BashOperator(
task_id="delete_tag_template", bash_command="echo delete_tag_template"
)
# Get
get_entry_group = BashOperator(
task_id="get_entry_group", bash_command="echo get_entry_group"
)
get_entry_group_result = BashOperator(
task_id="get_entry_group_result", bash_command="echo get_entry_group_result"
)
get_entry = BashOperator(task_id="get_entry", bash_command="echo get_entry")
get_entry_result = BashOperator(
task_id="get_entry_result", bash_command="echo get_entry_result"
)
get_tag_template = BashOperator(
task_id="get_tag_template", bash_command="echo get_tag_template"
)
get_tag_template_result = BashOperator(
task_id="get_tag_template_result", bash_command="echo get_tag_template_result"
)
# List
list_tags = BashOperator(task_id="list_tags", bash_command="echo list_tags")
list_tags_result = BashOperator(
task_id="list_tags_result", bash_command="echo list_tags_result"
)
# Lookup
lookup_entry = BashOperator(
task_id="lookup_entry", bash_command="echo lookup_entry"
)
lookup_entry_result = BashOperator(
task_id="lookup_entry_result", bash_command="echo lookup_entry_result"
)
# Rename
rename_tag_template_field = BashOperator(
task_id="rename_tag_template_field",
bash_command="echo rename_tag_template_field",
)
# Search
search_catalog = PythonOperator(
task_id="search_catalog", python_callable=lambda: print("search_catalog")
)
search_catalog_result = BashOperator(
task_id="search_catalog_result", bash_command="echo search_catalog_result"
)
# Update
update_entry = BashOperator(
task_id="update_entry", bash_command="echo update_entry"
)
update_tag = BashOperator(task_id="update_tag", bash_command="echo update_tag")
update_tag_template = BashOperator(
task_id="update_tag_template", bash_command="echo update_tag_template"
)
update_tag_template_field = BashOperator(
task_id="update_tag_template_field",
bash_command="echo update_tag_template_field",
)
# Create
create_tasks = [
create_entry_group,
create_entry_gcs,
create_tag_template,
create_tag_template_field,
create_tag,
]
chain(*create_tasks)
create_entry_group >> delete_entry_group
create_entry_group >> create_entry_group_result
create_entry_group >> create_entry_group_result2
create_entry_gcs >> delete_entry
create_entry_gcs >> create_entry_gcs_result
create_entry_gcs >> create_entry_gcs_result2
create_tag_template >> delete_tag_template_field
create_tag_template >> create_tag_template_result
create_tag_template >> create_tag_template_result2
create_tag_template_field >> delete_tag_template_field
create_tag_template_field >> create_tag_template_field_result
create_tag_template_field >> create_tag_template_field_result2
create_tag >> delete_tag
create_tag >> create_tag_result
create_tag >> create_tag_result2
# Delete
delete_tasks = [
delete_tag,
delete_tag_template_field,
delete_tag_template,
delete_entry_group,
delete_entry,
]
chain(*delete_tasks)
# Get
create_tag_template >> get_tag_template >> delete_tag_template
get_tag_template >> get_tag_template_result
create_entry_gcs >> get_entry >> delete_entry
get_entry >> get_entry_result
create_entry_group >> get_entry_group >> delete_entry_group
get_entry_group >> get_entry_group_result
# List
create_tag >> list_tags >> delete_tag
list_tags >> list_tags_result
# Lookup
create_entry_gcs >> lookup_entry >> delete_entry
lookup_entry >> lookup_entry_result
# Rename
create_tag_template_field >> rename_tag_template_field >> delete_tag_template_field
# Search
chain(create_tasks, search_catalog, delete_tasks)
search_catalog >> search_catalog_result
# Update
create_entry_gcs >> update_entry >> delete_entry
create_tag >> update_tag >> delete_tag
create_tag_template >> update_tag_template >> delete_tag_template
create_tag_template_field >> update_tag_template_field >> rename_tag_template_field

dags/crawler.py
@@ -0,0 +1,93 @@
"""
Code that goes along with the Airflow located at:
http://airflow.readthedocs.org/en/latest/tutorial.html
"""
from datetime import datetime, timedelta
from airflow import DAG
# Operators; we need this to operate!
from airflow.operators.bash import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago
default_args = {
"owner": "donaldrich",
"depends_on_past": False,
"start_date": datetime(2016, 7, 13),
"email": ["email@gmail.com"],
"email_on_failure": False,
"email_on_retry": False,
"retries": 1,
"retry_delay": timedelta(minutes=15),
}
script_path = "/data/scripts"
data_path = "/data/data/archive"
dag = DAG(
"zip_docker",
default_args=default_args,
description="A simple tutorial DAG",
schedule_interval=None,
start_date=days_ago(2),
tags=["zip", "docker"],
)
with dag:
start = DummyOperator(task_id="start", dag=dag)
pull1 = BashOperator(
task_id="pull_chrome",
bash_command="sudo docker pull selenoid/chrome:latest",
# retries=3,
# dag=dag
)
pull2 = BashOperator(
task_id="pull_recorder",
bash_command="sudo docker pull selenoid/video-recorder:latest-release",
# retries=3,
# dag=dag
)
scrape = BashOperator(
task_id="scrape_listings",
bash_command="sh /data/scripts/scrape.sh -s target -k docker",
# retries=3,
# dag=dag
)
cleanup = BashOperator(
task_id="cleanup",
bash_command="sh /data/scripts/post-scrape.sh -s target -k devops",
# retries=3,
# dag=dag
)
end = DummyOperator(task_id="end", dag=dag)
start >> [pull1, pull2] >> scrape >> cleanup >> end
# scrape = BashOperator(
# task_id="scrape_listings",
# bash_command="python3 " + script_path + '/gather/zip.py -k "devops"',
# # retries=3,
# # dag=dag
# )
# cleanup = BashOperator(
# task_id="cleanup",
# bash_command=script_path + "/post-scrape.sh -s zip -k devops",
# # retries=3,
# # dag=dag
# )
# init >> [pre1,pre2]
# init >> pre2
# pre2 >> scrape
# scrape >> cleanup

dags/debug.py
@@ -0,0 +1,125 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
### Tutorial Documentation
Documentation that goes along with the Airflow tutorial located
[here](https://airflow.apache.org/tutorial.html)
"""
# [START tutorial]
# [START import_module]
from datetime import timedelta
from textwrap import dedent
# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG
# Operators; we need this to operate!
from airflow.operators.bash import BashOperator
from airflow.utils.dates import days_ago
# [END import_module]
# [START default_args]
# These args will get passed on to each operator
# You can override them on a per-task basis during operator initialization
default_args = {
"owner": "airflow",
"depends_on_past": False,
"email": ["airflow@example.com"],
"email_on_failure": False,
"email_on_retry": False,
"retries": 1,
"retry_delay": timedelta(minutes=5),
# 'queue': 'bash_queue',
# 'pool': 'backfill',
# 'priority_weight': 10,
# 'end_date': datetime(2016, 1, 1),
# 'wait_for_downstream': False,
# 'dag': dag,
# 'sla': timedelta(hours=2),
# 'execution_timeout': timedelta(seconds=300),
# 'on_failure_callback': some_function,
# 'on_success_callback': some_other_function,
# 'on_retry_callback': another_function,
# 'sla_miss_callback': yet_another_function,
# 'trigger_rule': 'all_success'
}
# [END default_args]
# [START instantiate_dag]
with DAG(
"debug",
default_args=default_args,
description="A simple tutorial DAG",
schedule_interval=None,
start_date=days_ago(2),
tags=["example"],
) as dag:
# [END instantiate_dag]
cleanup = BashOperator(
task_id="cleanup",
bash_command="sudo sh /data/scripts/post-scrape.sh -s target -k devops",
)
# t1, t2 and t3 are examples of tasks created by instantiating operators
# [START basic_task]
# t1 = BashOperator(
# task_id="print_date",
# bash_command="date",
# )
# t2 = BashOperator(
# task_id="sleep",
# depends_on_past=False,
# bash_command="sleep 5",
# retries=3,
# )
# # [END basic_task]
# # [START documentation]
# dag.doc_md = __doc__
cleanup.doc_md = dedent("""
#### Task Documentation
You can document your task using the attributes `doc_md` (markdown),
`doc` (plain text), `doc_rst`, `doc_json`, `doc_yaml` which gets
rendered in the UI's Task Instance Details page.
![img](http://montcs.bloomu.edu/~bobmon/Semesters/2012-01/491/import%20soul.png)
""")
# [END documentation]
# [START jinja_template]
templated_command = dedent("""
{% for i in range(5) %}
echo "{{ ds }}"
echo "{{ macros.ds_add(ds, 7)}}"
echo "{{ params.my_param }}"
{% endfor %}
""")
t3 = BashOperator(
task_id="templated",
depends_on_past=False,
bash_command=templated_command,
params={"my_param": "Parameter I passed in"},
)
# [END jinja_template]
cleanup
# [END tutorial]

dags/dev/evade.py
@@ -0,0 +1,101 @@
import random
from time import sleep
from selenium import webdriver
from selenium.webdriver import ActionChains
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
# CONFIGURE DRIVER
def configure_driver(settings, user_agent):
options = webdriver.ChromeOptions()
# Options that help pass the browser off as a regular, non-automated client.
options.add_experimental_option("excludeSwitches", ["enable-automation"])
options.add_experimental_option("useAutomationExtension", False)
options.add_experimental_option(
"prefs",
{
"download.default_directory": settings.files_path,
"download.prompt_for_download": False,
"download.directory_upgrade": True,
"safebrowsing_for_trusted_sources_enabled": False,
"safebrowsing.enabled": False,
},
)
options.add_argument("--incognito")
options.add_argument("start-maximized")
options.add_argument(f"user-agent={user_agent}")
options.add_argument("disable-blink-features")
options.add_argument("disable-blink-features=AutomationControlled")
driver = webdriver.Chrome(
options=options, executable_path=settings.chromedriver_path
)
# #Overwrite the webdriver property
# driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument",
# {
# "source": """
# Object.defineProperty(navigator, 'webdriver', {
# get: () => undefined
# })
# """
# })
driver.execute_cdp_cmd("Network.enable", {})
# Overwrite the User-Agent header
driver.execute_cdp_cmd(
"Network.setExtraHTTPHeaders", {"headers": {"User-Agent": user_agent}}
)
driver.command_executor._commands["send_command"] = (
"POST",
"/session/$sessionId/chromium/send_command",
)
return driver
# SOLVE THE reCAPTCHA
def solve_wait_recaptcha(driver):
###### Move to reCAPTCHA Iframe
WebDriverWait(driver, 5).until(
EC.frame_to_be_available_and_switch_to_it(
(
By.CSS_SELECTOR,
"iframe[src^='https://www.google.com/recaptcha/api2/anchor?']",
)
)
)
check_selector = "span.recaptcha-checkbox.goog-inline-block.recaptcha-checkbox-unchecked.rc-anchor-checkbox"
captcha_check = driver.find_element_by_css_selector(check_selector)
###### Random delay before hover & click the checkbox
sleep(random.uniform(3, 6))
ActionChains(driver).move_to_element(captcha_check).perform()
###### Hover it
sleep(random.uniform(0.5, 1))
hov = ActionChains(driver).move_to_element(captcha_check).perform()
###### Random delay before click the checkbox
sleep(random.uniform(0.5, 1))
driver.execute_script("arguments[0].click()", captcha_check)
###### Wait for recaptcha to be in solved state
elem = None
while elem is None:
try:
elem = driver.find_element_by_class_name("recaptcha-checkbox-checked")
except:
pass
sleep(5)
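A hypothetical driver for the two helpers above, assuming a settings object that exposes files_path and chromedriver_path and a page embedding the reCAPTCHA v2 checkbox (the demo URL is just an example target):

from types import SimpleNamespace

settings = SimpleNamespace(
    files_path="/tmp/downloads",
    chromedriver_path="/usr/local/bin/chromedriver",
)
user_agent = (
    "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
    "(KHTML, like Gecko) Chrome/90.0.4430.93 Safari/537.36"
)
driver = configure_driver(settings, user_agent)
driver.get("https://www.google.com/recaptcha/api2/demo")
solve_wait_recaptcha(driver)
driver.quit()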

dags/dev/form.py
@@ -0,0 +1,27 @@
import mechanicalsoup
browser = mechanicalsoup.StatefulBrowser()
browser.open("http://httpbin.org/")
print(browser.url)
browser.follow_link("forms")
print(browser.url)
print(browser.page)
browser.select_form('form[action="/post"]')
browser["custname"] = "Me"
browser["custtel"] = "00 00 0001"
browser["custemail"] = "nobody@example.com"
browser["size"] = "medium"
browser["topping"] = "onion"
browser["topping"] = ("bacon", "cheese")
browser["comments"] = "This pizza looks really good :-)"
# Uncomment to launch a real web browser on the current page.
# browser.launch_browser()
# Uncomment to display a summary of the filled-in form
browser.form.print_summary()
response = browser.submit_selected()
print(response.text)

dags/dev/log_config.py
@@ -0,0 +1,270 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Airflow logging settings"""
import os
from pathlib import Path
from typing import Any, Dict, Union
from urllib.parse import urlparse
from airflow.configuration import conf
from airflow.exceptions import AirflowException
# TODO: Logging format and level should be configured
# in this file instead of from airflow.cfg. Currently
# there are other log format and level configurations in
# settings.py and cli.py. Please see AIRFLOW-1455.
LOG_LEVEL: str = conf.get('logging', 'LOGGING_LEVEL').upper()
# Flask appbuilder's info level log is very verbose,
# so it's set to 'WARN' by default.
FAB_LOG_LEVEL: str = conf.get('logging', 'FAB_LOGGING_LEVEL').upper()
LOG_FORMAT: str = conf.get('logging', 'LOG_FORMAT')
COLORED_LOG_FORMAT: str = conf.get('logging', 'COLORED_LOG_FORMAT')
COLORED_LOG: bool = conf.getboolean('logging', 'COLORED_CONSOLE_LOG')
COLORED_FORMATTER_CLASS: str = conf.get('logging', 'COLORED_FORMATTER_CLASS')
BASE_LOG_FOLDER: str = conf.get('logging', 'BASE_LOG_FOLDER')
PROCESSOR_LOG_FOLDER: str = conf.get('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY')
DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get('logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION')
FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_FILENAME_TEMPLATE')
PROCESSOR_FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_PROCESSOR_FILENAME_TEMPLATE')
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'airflow': {'format': LOG_FORMAT},
'airflow_coloured': {
'format': COLORED_LOG_FORMAT if COLORED_LOG else LOG_FORMAT,
'class': COLORED_FORMATTER_CLASS if COLORED_LOG else 'logging.Formatter',
},
},
'handlers': {
'console': {
'class': 'airflow.utils.log.logging_mixin.RedirectStdHandler',
'formatter': 'airflow_coloured',
'stream': 'sys.stdout',
},
'task': {
'class': 'airflow.utils.log.file_task_handler.FileTaskHandler',
'formatter': 'airflow',
'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
'filename_template': FILENAME_TEMPLATE,
},
'processor': {
'class': 'airflow.utils.log.file_processor_handler.FileProcessorHandler',
'formatter': 'airflow',
'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER),
'filename_template': PROCESSOR_FILENAME_TEMPLATE,
},
},
'loggers': {
'airflow.processor': {
'handlers': ['processor'],
'level': LOG_LEVEL,
'propagate': False,
},
'airflow.task': {
'handlers': ['task'],
'level': LOG_LEVEL,
'propagate': False,
},
'flask_appbuilder': {
'handlers': ['console'],
'level': FAB_LOG_LEVEL,
'propagate': True,
},
},
'root': {
'handlers': ['console'],
'level': LOG_LEVEL,
},
}
EXTRA_LOGGER_NAMES: str = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None)
if EXTRA_LOGGER_NAMES:
new_loggers = {
logger_name.strip(): {
'handlers': ['console'],
'level': LOG_LEVEL,
'propagate': True,
}
for logger_name in EXTRA_LOGGER_NAMES.split(",")
}
DEFAULT_LOGGING_CONFIG['loggers'].update(new_loggers)
DEFAULT_DAG_PARSING_LOGGING_CONFIG: Dict[str, Dict[str, Dict[str, Any]]] = {
'handlers': {
'processor_manager': {
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'airflow',
'filename': DAG_PROCESSOR_MANAGER_LOG_LOCATION,
'mode': 'a',
'maxBytes': 104857600, # 100MB
'backupCount': 5,
}
},
'loggers': {
'airflow.processor_manager': {
'handlers': ['processor_manager'],
'level': LOG_LEVEL,
'propagate': False,
}
},
}
# Only update the handlers and loggers when CONFIG_PROCESSOR_MANAGER_LOGGER is set.
# This is to avoid exceptions when initializing RotatingFileHandler multiple times
# in multiple processes.
if os.environ.get('CONFIG_PROCESSOR_MANAGER_LOGGER') == 'True':
DEFAULT_LOGGING_CONFIG['handlers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'])
DEFAULT_LOGGING_CONFIG['loggers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['loggers'])
# Manually create log directory for processor_manager handler as RotatingFileHandler
# will only create file but not the directory.
processor_manager_handler_config: Dict[str, Any] = DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'][
'processor_manager'
]
directory: str = os.path.dirname(processor_manager_handler_config['filename'])
Path(directory).mkdir(parents=True, exist_ok=True, mode=0o755)
##################
# Remote logging #
##################
REMOTE_LOGGING: bool = conf.getboolean('logging', 'remote_logging')
if REMOTE_LOGGING:
ELASTICSEARCH_HOST: str = conf.get('elasticsearch', 'HOST')
# Storage bucket URL for remote logging
# S3 buckets should start with "s3://"
# Cloudwatch log groups should start with "cloudwatch://"
# GCS buckets should start with "gs://"
# WASB buckets should start with "wasb"
# just to help Airflow select correct handler
REMOTE_BASE_LOG_FOLDER: str = conf.get('logging', 'REMOTE_BASE_LOG_FOLDER')
if REMOTE_BASE_LOG_FOLDER.startswith('s3://'):
S3_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
'task': {
'class': 'airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler',
'formatter': 'airflow',
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
's3_log_folder': REMOTE_BASE_LOG_FOLDER,
'filename_template': FILENAME_TEMPLATE,
},
}
DEFAULT_LOGGING_CONFIG['handlers'].update(S3_REMOTE_HANDLERS)
elif REMOTE_BASE_LOG_FOLDER.startswith('cloudwatch://'):
CLOUDWATCH_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
'task': {
'class': 'airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler',
'formatter': 'airflow',
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
'log_group_arn': urlparse(REMOTE_BASE_LOG_FOLDER).netloc,
'filename_template': FILENAME_TEMPLATE,
},
}
DEFAULT_LOGGING_CONFIG['handlers'].update(CLOUDWATCH_REMOTE_HANDLERS)
elif REMOTE_BASE_LOG_FOLDER.startswith('gs://'):
key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None)
GCS_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
'task': {
'class': 'airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler',
'formatter': 'airflow',
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
'gcs_log_folder': REMOTE_BASE_LOG_FOLDER,
'filename_template': FILENAME_TEMPLATE,
'gcp_key_path': key_path,
},
}
DEFAULT_LOGGING_CONFIG['handlers'].update(GCS_REMOTE_HANDLERS)
elif REMOTE_BASE_LOG_FOLDER.startswith('wasb'):
WASB_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = {
'task': {
'class': 'airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler',
'formatter': 'airflow',
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
'wasb_log_folder': REMOTE_BASE_LOG_FOLDER,
'wasb_container': 'airflow-logs',
'filename_template': FILENAME_TEMPLATE,
'delete_local_copy': False,
},
}
DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS)
elif REMOTE_BASE_LOG_FOLDER.startswith('stackdriver://'):
key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None)
# stackdriver:///airflow-tasks => airflow-tasks
log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:]
STACKDRIVER_REMOTE_HANDLERS = {
'task': {
'class': 'airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler',
'formatter': 'airflow',
'name': log_name,
'gcp_key_path': key_path,
}
}
DEFAULT_LOGGING_CONFIG['handlers'].update(STACKDRIVER_REMOTE_HANDLERS)
elif ELASTICSEARCH_HOST:
ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get('elasticsearch', 'LOG_ID_TEMPLATE')
ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get('elasticsearch', 'END_OF_LOG_MARK')
ELASTICSEARCH_FRONTEND: str = conf.get('elasticsearch', 'frontend')
ELASTICSEARCH_WRITE_STDOUT: bool = conf.getboolean('elasticsearch', 'WRITE_STDOUT')
ELASTICSEARCH_JSON_FORMAT: bool = conf.getboolean('elasticsearch', 'JSON_FORMAT')
ELASTICSEARCH_JSON_FIELDS: str = conf.get('elasticsearch', 'JSON_FIELDS')
ELASTIC_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = {
'task': {
'class': 'airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler',
'formatter': 'airflow',
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
'log_id_template': ELASTICSEARCH_LOG_ID_TEMPLATE,
'filename_template': FILENAME_TEMPLATE,
'end_of_log_mark': ELASTICSEARCH_END_OF_LOG_MARK,
'host': ELASTICSEARCH_HOST,
'frontend': ELASTICSEARCH_FRONTEND,
'write_stdout': ELASTICSEARCH_WRITE_STDOUT,
'json_format': ELASTICSEARCH_JSON_FORMAT,
'json_fields': ELASTICSEARCH_JSON_FIELDS,
},
}
DEFAULT_LOGGING_CONFIG['handlers'].update(ELASTIC_REMOTE_HANDLERS)
else:
raise AirflowException(
"Incorrect remote log configuration. Please check the configuration of option 'host' in "
"section 'elasticsearch' if you are using Elasticsearch. In the other case, "
"'remote_base_log_folder' option in the 'logging' section."
)

dags/dev/menu.py
@@ -0,0 +1,28 @@
from airflow.plugins_manager import AirflowPlugin
from flask_admin.base import MenuLink
github = MenuLink(
category="Astronomer",
name="Airflow Instance Github Repo",
url="https://github.com/astronomerio/astronomer-dags",
)
astronomer_home = MenuLink(
category="Astronomer", name="Astronomer Home", url="https://www.astronomer.io/"
)
aiflow_plugins = MenuLink(
category="Astronomer",
name="Airflow Plugins",
url="https://github.com/airflow-plugins",
)
# Defining the plugin class
class AirflowTestPlugin(AirflowPlugin):
name = "AstronomerMenuLinks"
operators = []
flask_blueprints = []
hooks = []
executors = []
admin_views = []
menu_links = [github, astronomer_home, aiflow_plugins]

dags/dev/singer.py
@@ -0,0 +1,32 @@
"""
Singer
This example shows how to use Singer within Airflow using a custom operator:
- SingerOperator
https://github.com/airflow-plugins/singer_plugin/blob/master/operators/singer_operator.py#L5
A complete list of Taps and Targets can be found in the Singer.io Github org:
https://github.com/singer-io
"""
from datetime import datetime
from airflow import DAG
from airflow.operators import SingerOperator
default_args = {
"start_date": datetime(2018, 2, 22),
"retries": 0,
"email": [],
"email_on_failure": True,
"email_on_retry": False,
}
dag = DAG(
"__singer__fixerio_to_csv", schedule_interval="@hourly", default_args=default_args
)
with dag:
singer = SingerOperator(task_id="singer", tap="fixerio", target="csv")
singer
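For reference, the handoff the custom operator wraps is essentially the standard Singer shell pipe; a hedged BashOperator equivalent is sketched below (the tap/target binaries and the config path are assumptions, not part of this repo):

from airflow.operators.bash import BashOperator

singer_pipe = BashOperator(
    task_id="singer_pipe",
    bash_command="tap-fixerio --config /data/singer/fixerio-config.json | target-csv",
    dag=dag,
)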

dags/dev/test.py
@@ -0,0 +1,181 @@
"""
Headless Site Navigation and File Download (Using Selenium) to S3
This example demonstrates using Selenium (via Firefox/GeckoDriver) to:
1) Log into a website w/ credentials stored in connection labeled 'selenium_conn_id'
2) Download a file (initiated on login)
3) Transform the CSV into JSON formatting
4) Append the current data to each record
5) Load the corresponding file into S3
To use this DAG, you will need to have the following installed:
[XVFB](https://www.x.org/archive/X11R7.6/doc/man/man1/Xvfb.1.xhtml)
[GeckoDriver](https://github.com/mozilla/geckodriver/releases/download)
selenium==3.11.0
xvfbwrapper==0.2.9
"""
import csv
import datetime
import json
import logging
import os
import time
from datetime import datetime, timedelta
import boa
import requests
from airflow import DAG
from airflow.models import Connection
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator, PythonVirtualenvOperator
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.utils.dates import days_ago
from airflow.utils.db import provide_session
from bs4 import BeautifulSoup
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
from selenium.webdriver.common.keys import Keys
from tools.network import (
browser_config,
browser_feature,
driver_object,
ip_status,
proxy_ip,
selenoid_status,
uc_test,
user_agent,
vpn_settings,
)
from config.webdriver import browser_capabilities, browser_options
default_args = {
"start_date": days_ago(2),
"email": [],
"email_on_failure": True,
"email_on_retry": False,
# 'retries': 2,
"retry_delay": timedelta(minutes=5),
"catchup": False,
}
# def hello_world_py():
# # selenium_conn_id = kwargs.get('templates_dict', None).get('selenium_conn_id', None)
# # filename = kwargs.get('templates_dict', None).get('filename', None)
# # s3_conn_id = kwargs.get('templates_dict', None).get('s3_conn_id', None)
# # s3_bucket = kwargs.get('templates_dict', None).get('s3_bucket', None)
# # s3_key = kwargs.get('templates_dict', None).get('s3_key', None)
# # date = kwargs.get('templates_dict', None).get('date', None)
# # module_name = kwargs.get('templates_dict', None).get('module', None)
# module = "anon_browser_test"
# chrome_options = browser_options()
# capabilities = browser_capabilities(module)
# logging.info('Assembling driver')
# driver = webdriver.Remote(
# command_executor="http://192.168.1.101:4444/wd/hub",
# options=chrome_options,
# desired_capabilities=capabilities,
# )
# logging.info('proxy IP')
# proxy_ip()
# logging.info('driver')
# vpn_settings()
# logging.info('driver')
# selenoid_status()
# logging.info('driver')
# ip_status(driver)
# logging.info('driver')
# browser_config(driver)
# logging.info('driver')
# user_agent(driver)
# logging.info('driver')
# driver_object(driver)
# logging.info('driver')
# browser_feature(driver)
# logging.info('driver')
# driver.quit()
# logging.info('driver')
# uc_test()
# # print("Finished")
# return 'Whatever you return gets printed in the logs'
# dag = DAG(
# 'anon_browser_test',
# schedule_interval='@daily',
# default_args=default_args,
# catchup=False
# )
# dummy_operator = DummyOperator(task_id="dummy_task", retries=3, dag=dag)
# selenium = PythonOperator(
# task_id='anon_browser_test',
# python_callable=hello_world_py,
# templates_dict={"module": "anon_browser_test"},
# dag=dag
# # "s3_bucket": S3_BUCKET,
# # "s3_key": S3_KEY,
# # "date": date}
# # provide_context=True
# )
# t1 = DockerOperator(
# # api_version='1.19',
# # docker_url='tcp://localhost:2375', # Set your docker URL
# command='/bin/sleep 30',
# image='selenoid/chrome:latest',
# # network_mode='bridge',
# task_id='chrome',
# dag=dag,
# )
# t2 = DockerOperator(
# # api_version='1.19',
# # docker_url='tcp://localhost:2375', # Set your docker URL
# command='/bin/sleep 30',
# image='selenoid/video-recorder:latest-release',
# # network_mode='bridge',
# task_id='video_recorder',
# dag=dag,
# )
# [START howto_operator_python_venv]
def callable_virtualenv():
"""
Example function that will be performed in a virtual environment.
Importing inside the function (rather than at module level) ensures that the
library is not imported before it has been installed in the virtualenv.
"""
from time import sleep
from colorama import Back, Fore, Style
print(Fore.RED + "some red text")
print(Back.GREEN + "and with a green background")
print(Style.DIM + "and in dim text")
print(Style.RESET_ALL)
for _ in range(10):
print(Style.DIM + "Please wait...", flush=True)
sleep(10)
print("Finished")
virtualenv_task = PythonVirtualenvOperator(
task_id="virtualenv_python",
python_callable=callable_virtualenv,
requirements=["colorama==0.4.0"],
system_site_packages=False,
dag=dag,
)
# [END howto_operator_python_venv]
# selenium >> dummy_operator
# dummy_operator >> virtualenv_task
# t1 >> selenium
# t2 >> selenium

dags/docker.py
@@ -0,0 +1,62 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import timedelta
from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.utils.dates import days_ago
default_args = {
"owner": "airflow",
"depends_on_past": False,
"email": ["airflow@example.com"],
"email_on_failure": False,
"email_on_retry": False,
"retries": 1,
"retry_delay": timedelta(minutes=5),
}
dag = DAG(
"docker_sample",
default_args=default_args,
schedule_interval=timedelta(minutes=10),
start_date=days_ago(2),
)
t1 = BashOperator(task_id="print_date", bash_command="date", dag=dag)
t2 = BashOperator(task_id="sleep", bash_command="sleep 5", retries=3, dag=dag)
t3 = DockerOperator(
api_version="1.19",
docker_url="tcp://localhost:2375", # Set your docker URL
command="/bin/sleep 30",
image="centos:latest",
network_mode="bridge",
task_id="docker_op_tester",
dag=dag,
)
t4 = BashOperator(task_id="print_hello", bash_command='echo "hello world!!!"', dag=dag)
t1 >> t2
t1 >> t3
t3 >> t4

dags/etl.py
@@ -0,0 +1,114 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring
# [START tutorial]
# [START import_module]
import json
from airflow.decorators import dag, task
from airflow.utils.dates import days_ago
# [END import_module]
# [START default_args]
# These args will get passed on to each operator
# You can override them on a per-task basis during operator initialization
default_args = {
"owner": "airflow",
}
# [END default_args]
# [START instantiate_dag]
@dag(
default_args=default_args,
schedule_interval=None,
start_date=days_ago(2),
tags=["example"],
)
def tutorial_taskflow_api_etl():
"""
### TaskFlow API Tutorial Documentation
This is a simple ETL data pipeline example which demonstrates the use of
the TaskFlow API using three simple tasks for Extract, Transform, and Load.
Documentation that goes along with the Airflow TaskFlow API tutorial is
located
[here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html)
"""
# [END instantiate_dag]
# [START extract]
@task()
def extract():
"""
#### Extract task
A simple Extract task to get data ready for the rest of the data
pipeline. In this case, getting data is simulated by reading from a
hardcoded JSON string.
"""
data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
order_data_dict = json.loads(data_string)
return order_data_dict
# [END extract]
# [START transform]
@task(multiple_outputs=True)
def transform(order_data_dict: dict):
"""
#### Transform task
A simple Transform task which takes in the collection of order data and
computes the total order value.
"""
total_order_value = 0
for value in order_data_dict.values():
total_order_value += value
return {"total_order_value": total_order_value}
# [END transform]
# [START load]
@task()
def load(total_order_value: float):
"""
#### Load task
A simple Load task which takes in the result of the Transform task and
instead of saving it to end user review, just prints it out.
"""
print(f"Total order value is: {total_order_value:.2f}")
# [END load]
# [START main_flow]
order_data = extract()
order_summary = transform(order_data)
load(order_summary["total_order_value"])
# [END main_flow]
# [START dag_invocation]
tutorial_etl_dag = tutorial_taskflow_api_etl()
# [END dag_invocation]
# [END tutorial]

dags/postgres.py
@@ -0,0 +1,61 @@
from os import getenv
from sqlalchemy import VARCHAR, create_engine
from sqlalchemy.engine.url import URL
def _psql_insert_copy(table, conn, keys, data_iter):
"""
Execute SQL statement inserting data
Parameters
----------
table : pandas.io.sql.SQLTable
conn : sqlalchemy.engine.Engine or sqlalchemy.engine.Connection
keys : list of str
Column names
data_iter : Iterable that iterates the values to be inserted
"""
# Alternative to_sql() *method* for DBs that support COPY FROM
import csv
from io import StringIO
# gets a DBAPI connection that can provide a cursor
dbapi_conn = conn.connection
with dbapi_conn.cursor() as cur:
s_buf = StringIO()
writer = csv.writer(s_buf)
writer.writerows(data_iter)
s_buf.seek(0)
columns = ', '.join('"{}"'.format(k) for k in keys)
if table.schema:
table_name = '{}.{}'.format(table.schema, table.name)
else:
table_name = table.name
sql = 'COPY {} ({}) FROM STDIN WITH CSV'.format(
table_name, columns)
cur.copy_expert(sql=sql, file=s_buf)
def load_df_into_db(data_frame, schema: str, table: str) -> None:
engine = create_engine(URL(
username=getenv('DBT_POSTGRES_USER'),
password=getenv('DBT_POSTGRES_PASSWORD'),
host=getenv('DBT_POSTGRES_HOST'),
port=getenv('DBT_POSTGRES_PORT'),
database=getenv('DBT_POSTGRES_DB'),
drivername='postgres'
))
with engine.connect() as cursor:
cursor.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}')
data_frame.to_sql(
schema=schema,
name=table,
dtype=VARCHAR,
if_exists='append',
con=engine,
method=_psql_insert_copy,
index=False
)
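A hypothetical call of the loader above, assuming pandas is installed and the DBT_POSTGRES_* environment variables are set; the frame contents are placeholder data:

import pandas as pd

listings = pd.DataFrame(
    {"listing_id": ["a1", "b2"], "title": ["DevOps Engineer", "Site Reliability Engineer"]}
)
load_df_into_db(listings, schema="scraped", table="listings")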

dags/remote_browser.py
@@ -0,0 +1,126 @@
"""
Code that goes along with the Airflow located at:
http://airflow.readthedocs.org/en/latest/tutorial.html
"""
from datetime import datetime, timedelta
from textwrap import dedent
from airflow import DAG
# Operators; we need this to operate!
from airflow.operators.bash import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator, PythonVirtualenvOperator
from airflow.utils.dates import days_ago
from airflow.utils.task_group import TaskGroup
from utils.network import selenoid_status, vpn_settings
from airflow.providers.postgres.operators.postgres import PostgresOperator
from utils.notify import PushoverOperator
from utils.postgres import sqlLoad
default_args = {
"owner": "donaldrich",
"depends_on_past": False,
"start_date": datetime(2016, 7, 13),
"email": ["email@gmail.com"],
"email_on_failure": False,
"email_on_retry": False,
"retries": 1,
"retry_delay": timedelta(minutes=15),
# "on_failure_callback": some_function,
# "on_success_callback": some_other_function,
# "on_retry_callback": another_function,
}
# [START default_args]
# default_args = {
# 'owner': 'airflow',
# 'depends_on_past': False,
# 'email': ['airflow@example.com'],
# 'email_on_failure': False,
# 'email_on_retry': False,
# 'retries': 1,
# 'retry_delay': timedelta(minutes=5),
# 'queue': 'bash_queue',
# 'pool': 'backfill',
# 'priority_weight': 10,
# 'end_date': datetime(2016, 1, 1),
# 'wait_for_downstream': False,
# 'dag': dag,
# 'sla': timedelta(hours=2),
# 'execution_timeout': timedelta(seconds=300),
# 'on_failure_callback': some_function,
# 'on_success_callback': some_other_function,
# 'on_retry_callback': another_function,
# 'sla_miss_callback': yet_another_function,
# 'trigger_rule': 'all_success'
# }
# [END default_args]
dag = DAG(
"recon",
default_args=default_args,
description="A simple tutorial DAG",
schedule_interval=None,
start_date=days_ago(2),
tags=["target"],
)
with dag:
# start = DummyOperator(task_id="start", dag=dag)
with TaskGroup("pull_images") as start2:
pull1 = BashOperator(
task_id="chrome_latest",
bash_command="sudo docker pull selenoid/chrome:latest",
)
pull2 = BashOperator(
task_id="video_recorder",
bash_command=
"sudo docker pull selenoid/video-recorder:latest-release",
)
t1 = PythonOperator(
task_id="selenoid_status",
python_callable=selenoid_status,
)
[pull1, pull2] >> t1
start2.doc_md = dedent("""
#### Task Documentation
You can document your task using the attributes `doc_md` (markdown),
`doc` (plain text), `doc_rst`, `doc_json`, `doc_yaml` which gets
rendered in the UI's Task Instance Details page.
![img](http://montcs.bloomu.edu/~bobmon/Semesters/2012-01/491/import%20soul.png)
""")
scrape = BashOperator(
task_id="scrape_listings",
bash_command="python3 /data/scripts/scrapers/zip-scrape.py -k devops",
)
sqlLoad = PythonOperator(task_id="sql_load", python_callable=sqlLoad)
# get_birth_date = PostgresOperator(
# task_id="get_birth_date",
# postgres_conn_id="postgres_default",
# sql="sql/birth_date.sql",
# params={"begin_date": "2020-01-01", "end_date": "2020-12-31"},
# )
cleanup = BashOperator(
task_id="cleanup",
bash_command="sudo sh /data/scripts/post-scrape.sh -s zip -k devops",
)
success_notify = PushoverOperator(
task_id="finished",
title="Airflow Complete",
message="We did it!",
)
end = DummyOperator(task_id="end", dag=dag)
start2 >> scrape >> sqlLoad >> cleanup >> success_notify >> end

dags/s3.py
@@ -0,0 +1,41 @@
from datetime import datetime, timedelta
from airflow import DAG
from airflow.operators.python_operator import PythonOperator
# from airflow.hooks.S3_hook import S3Hook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
DEFAULT_ARGS = {
"owner": "Airflow",
"depends_on_past": False,
"start_date": datetime(2020, 1, 13),
"email": ["airflow@example.com"],
"email_on_failure": False,
"email_on_retry": False,
"retries": 1,
"retry_delay": timedelta(minutes=5),
}
dag = DAG("create_date_dimension", default_args=DEFAULT_ARGS, schedule_interval="@once")
# Processing function: generate a file, then upload it to MinIO via S3Hook
def write_text_file(ds, **kwargs):
    with open("/data/data/temp.txt", "w") as fp:
        # File generation/processing step: write the execution date, for example
        fp.write(ds)
    # Upload the generated file to MinIO
    # s3 = S3Hook('local_minio')
    s3 = S3Hook(aws_conn_id="minio")
    s3.load_file("/data/data/temp.txt", key="my-test-file.txt", bucket_name="airflow")
# Create a task to call your processing function
t1 = PythonOperator(
task_id="generate_and_upload_to_s3",
provide_context=True,
python_callable=write_text_file,
dag=dag,
)
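One way the "minio" connection referenced above could be supplied is through an AIRFLOW_CONN_* environment variable; a sketch with placeholder credentials and endpoint follows (the URL-encoded host extra is what older AwsBaseHook versions read as a custom endpoint, so treat this as an assumption to verify against your Airflow version):

import os

# Placeholder credentials and endpoint for a local MinIO instance.
os.environ["AIRFLOW_CONN_MINIO"] = (
    "aws://minioadmin:minioadmin@/?host=http%3A%2F%2Fminio%3A9000"
)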

dags/scraper.py
@@ -0,0 +1,154 @@
"""
Headless Site Navigation and File Download (Using Selenium) to S3
This example demonstrates using Selenium (via Firefox/GeckoDriver) to:
1) Log into a website w/ credentials stored in connection labeled 'selenium_conn_id'
2) Download a file (initiated on login)
3) Transform the CSV into JSON formatting
4) Append the current data to each record
5) Load the corresponding file into S3
To use this DAG, you will need to have the following installed:
[XVFB](https://www.x.org/archive/X11R7.6/doc/man/man1/Xvfb.1.xhtml)
[GeckoDriver](https://github.com/mozilla/geckodriver/releases/download)
selenium==3.11.0
xvfbwrapper==0.2.9
"""
# import boa
import csv
import json
import logging
import os
import time
from datetime import datetime, timedelta
from airflow import DAG
from airflow.hooks import S3Hook
from airflow.models import Connection
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
from airflow.utils.db import provide_session
from selenium import webdriver
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.firefox.options import Options
from xvfbwrapper import Xvfb
S3_CONN_ID = ""
S3_BUCKET = ""
S3_KEY = ""
date = "{{ ds }}"
default_args = {
"start_date": datetime(2018, 2, 10, 0, 0),
"email": [],
"email_on_failure": True,
"email_on_retry": False,
"retries": 2,
"retry_delay": timedelta(minutes=5),
"catchup": False,
}
dag = DAG(
"selenium_extraction_to_s3",
schedule_interval="@daily",
default_args=default_args,
catchup=False,
)
def imap_py(**kwargs):
    templates_dict = kwargs.get("templates_dict", {})
    selenium_conn_id = templates_dict.get("selenium_conn_id")
    filename = templates_dict.get("filename")
    s3_conn_id = templates_dict.get("s3_conn_id")
    s3_bucket = templates_dict.get("s3_bucket")
    s3_key = templates_dict.get("s3_key")
    date = templates_dict.get("date")
@provide_session
def get_conn(conn_id, session=None):
conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
return conn
url = get_conn(selenium_conn_id).host
email = get_conn(selenium_conn_id).user
pwd = get_conn(selenium_conn_id).password
vdisplay = Xvfb()
vdisplay.start()
caps = webdriver.DesiredCapabilities.FIREFOX
caps["marionette"] = True
profile = webdriver.FirefoxProfile()
profile.set_preference("browser.download.manager.showWhenStarting", False)
profile.set_preference("browser.helperApps.neverAsk.saveToDisk", "text/csv")
logging.info("Profile set...")
options = Options()
options.set_headless(headless=True)
logging.info("Options set...")
logging.info("Initializing Driver...")
driver = webdriver.Firefox(
firefox_profile=profile, firefox_options=options, capabilities=caps
)
logging.info("Driver Intialized...")
driver.get(url)
logging.info("Authenticating...")
elem = driver.find_element_by_id("email")
elem.send_keys(email)
elem = driver.find_element_by_id("password")
elem.send_keys(pwd)
elem.send_keys(Keys.RETURN)
logging.info("Successfully authenticated.")
sleep_time = 15
logging.info("Downloading File....Sleeping for {} Seconds.".format(str(sleep_time)))
time.sleep(sleep_time)
driver.close()
vdisplay.stop()
dest_s3 = S3Hook(s3_conn_id=s3_conn_id)
os.chdir("/root/Downloads")
    output_json = "file.json"
    with open(filename, "r") as csvfile, open(output_json, "w") as jsonfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            # row = dict((boa.constrict(k), v) for k, v in row.items())
            row["run_date"] = date
            json.dump(row, jsonfile)
            jsonfile.write("\n")
dest_s3.load_file(
filename=output_json, key=s3_key, bucket_name=s3_bucket, replace=True
)
dest_s3.connection.close()
with dag:
kick_off_dag = DummyOperator(task_id="kick_off_dag")
selenium = PythonOperator(
task_id="selenium_retrieval_to_s3",
python_callable=imap_py,
        templates_dict={
            "selenium_conn_id": SELENIUM_CONN_ID,
            "filename": FILENAME,
            "s3_conn_id": S3_CONN_ID,
            "s3_bucket": S3_BUCKET,
            "s3_key": S3_KEY,
            "date": date,
        },
provide_context=True,
)
kick_off_dag >> selenium

@ -0,0 +1,65 @@
from datetime import timedelta
from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.utils.dates import days_ago
default_args = {
"owner": "airflow",
"depends_on_past": False,
"email": ["airflow@example.com"],
"email_on_failure": False,
"email_on_retry": False,
"retries": 1,
"retry_delay": timedelta(minutes=5),
}
dag = DAG(
"selenoid_setup",
default_args=default_args,
schedule_interval="@daily",
start_date=days_ago(2),
)
with dag:
kick_off_dag = DummyOperator(task_id="kick_off_dag")
t1 = DockerOperator(
# api_version='1.19',
# docker_url='tcp://localhost:2375', # Set your docker URL
command="/bin/sleep 30",
image="selenoid/chrome:latest",
network_mode="bridge",
task_id="selenoid1",
# dag=dag,
)
t2 = DockerOperator(
# api_version='1.19',
# docker_url='tcp://localhost:2375', # Set your docker URL
command="/bin/sleep 30",
image="selenoid/video-recorder:latest-release",
network_mode="bridge",
task_id="selenoid2",
# dag=dag,
)
scrape = BashOperator(
task_id="pull_selenoid_video-recorder",
bash_command="docker pull selenoid/video-recorder:latest-release",
# retries=3,
# dag=dag
)
scrape2 = BashOperator(
task_id="pull_selenoid_chrome",
bash_command="docker pull selenoid/chrome:latest",
# retries=3,
# dag=dag
)
    kick_off_dag >> [scrape, scrape2] >> t1 >> t2

@ -0,0 +1 @@
SELECT * FROM pet WHERE birth_date BETWEEN SYMMETRIC '{{ params.begin_date }}' AND '{{ params.end_date }}';

@ -0,0 +1,122 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example DAG for the use of the SqliteOperator.
In this example, we create two tasks that execute in sequence.
The first task runs an SQL command, defined in the SQLite operator, which,
when triggered, is executed on the connected SQLite database.
The second task is similar but instead calls the SQL command from an external file.
"""
import apprise
from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.providers.sqlite.operators.sqlite import SqliteOperator
from airflow.utils.dates import days_ago
default_args = {"owner": "airflow"}
dag = DAG(
dag_id="example_sqlite",
default_args=default_args,
schedule_interval="@hourly",
start_date=days_ago(2),
tags=["example"],
)
with dag:
# [START howto_operator_sqlite]
# Example of creating a task that calls a common CREATE TABLE sql command.
# def notify():
# # Create an Apprise instance
# apobj = apprise.Apprise()
# # Add all of the notification services by their server url.
# # A sample email notification:
# # apobj.add('mailto://myuserid:mypass@gmail.com')
# # A sample pushbullet notification
# apobj.add('pover://aejghiy6af1bshe8mbdksmkzeon3ip@umjiu36dxwwaj8pnfx3n6y2xbm3ssx')
# # Then notify these services any time you desire. The below would
# # notify all of the services loaded into our Apprise object.
# apobj.notify(
# body='what a great notification service!',
# title='my notification title',
# )
# return apobj
# apprise = PythonOperator(
# task_id="apprise",
# python_callable=notify,
# dag=dag
# # "s3_bucket": S3_BUCKET,
# # "s3_key": S3_KEY,
# # "date": date}
# # provide_context=True
# )
# t2 = DockerOperator(
# task_id='docker_command',
# image='selenoid/chrome:latest',
# api_version='auto',
# auto_remove=False,
# command="/bin/sleep 30",
# docker_url="unix://var/run/docker.sock",
# network_mode="bridge"
# )
start = DummyOperator(task_id="start")
docker = BashOperator(
task_id="pull_selenoid_2",
bash_command="sudo docker pull selenoid/chrome:latest",
# retries=3,
# dag=dag
)
pre2 = BashOperator(
task_id="apprise",
bash_command="apprise -vv -t 'my title' -b 'my notification body' pover://umjiu36dxwwaj8pnfx3n6y2xbm3ssx@aejghiy6af1bshe8mbdksmkzeon3ip",
# retries=3,
# dag=dag
)
# t3 = DockerOperator(
# api_version="1.19",
# docker_url="tcp://localhost:2375", # Set your docker URL
# command="/bin/sleep 30",
# image="centos:latest",
# network_mode="bridge",
# task_id="docker_op_tester",
# # dag=dag,
# )
end = DummyOperator(task_id="end")
start >> docker >> pre2 >> end
# [END howto_operator_sqlite_external_file]
# create_table_sqlite_task >> external_create_table_sqlite_task
#

@ -0,0 +1,137 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring
"""
### ETL DAG Tutorial Documentation
This ETL DAG is compatible with Airflow 1.10.x (specifically tested with 1.10.12) and is referenced
as part of the documentation that goes along with the Airflow Functional DAG tutorial located
[here](https://airflow.apache.org/tutorial_decorated_flows.html)
"""
# [START tutorial]
# [START import_module]
import json
from textwrap import dedent
# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG
# Operators; we need this to operate!
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
# [END import_module]
# [START default_args]
# These args will get passed on to each operator
# You can override them on a per-task basis during operator initialization
default_args = {
"owner": "donaldrich",
"email": ["email@gmail.com"],
}
# [END default_args]
# [START instantiate_dag]
with DAG(
"tutorial_etl_dag",
default_args=default_args,
description="ETL DAG tutorial",
schedule_interval=None,
start_date=days_ago(2),
tags=["example"],
) as dag:
# [END instantiate_dag]
# [START documentation]
dag.doc_md = __doc__
# [END documentation]
# [START extract_function]
def extract(**kwargs):
ti = kwargs["ti"]
data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
ti.xcom_push("order_data", data_string)
# [END extract_function]
# [START transform_function]
def transform(**kwargs):
ti = kwargs["ti"]
extract_data_string = ti.xcom_pull(task_ids="extract",
key="order_data")
order_data = json.loads(extract_data_string)
total_order_value = 0
for value in order_data.values():
total_order_value += value
total_value = {"total_order_value": total_order_value}
total_value_json_string = json.dumps(total_value)
ti.xcom_push("total_order_value", total_value_json_string)
# [END transform_function]
# [START load_function]
def load(**kwargs):
ti = kwargs["ti"]
total_value_string = ti.xcom_pull(task_ids="transform",
key="total_order_value")
total_order_value = json.loads(total_value_string)
print(total_order_value)
# [END load_function]
# [START main_flow]
extract_task = PythonOperator(
task_id="extract",
python_callable=extract,
)
extract_task.doc_md = dedent("""\
#### Extract task
A simple Extract task to get data ready for the rest of the data pipeline.
In this case, getting data is simulated by reading from a hardcoded JSON string.
This data is then put into xcom, so that it can be processed by the next task.
""")
transform_task = PythonOperator(
task_id="transform",
python_callable=transform,
)
transform_task.doc_md = dedent("""\
#### Transform task
A simple Transform task which takes in the collection of order data from xcom
and computes the total order value.
This computed value is then put into xcom, so that it can be processed by the next task.
""")
load_task = PythonOperator(
task_id="load",
python_callable=load,
)
load_task.doc_md = dedent("""\
#### Load task
    A simple Load task which takes in the result of the Transform task by reading it
    from xcom and, instead of saving it for end-user review, simply prints it out.
""")
extract_task >> transform_task >> load_task
# [END main_flow]
# [END tutorial]

@ -0,0 +1,32 @@
import os
from datetime import datetime
from airflow import DAG
from airflow.hooks.base_hook import BaseHook
from airflow.operators.python_operator import PythonOperator
os.environ["AIRFLOW__SECRETS__BACKEND"] = (
    "airflow.providers.hashicorp.secrets.vault.VaultBackend"
)
os.environ["AIRFLOW__SECRETS__BACKEND_KWARGS"] = (
    '{"connections_path": "myapp", "mount_point": "secret", '
    '"auth_type": "token", "token": "token", "url": "http://vault:8200"}'
)
def get_secrets(**kwargs):
    conn = BaseHook.get_connection(kwargs["my_conn_id"])
    print("Password:", conn.password)
    print("Login:", conn.login)
    print("URI:", conn.get_uri())
    print("Host:", conn.host)
with DAG(
"vault_example", start_date=datetime(2020, 1, 1), schedule_interval=None
) as dag:
test_task = PythonOperator(
task_id="test-task",
python_callable=get_secrets,
op_kwargs={"my_conn_id": "smtp_default"},
)
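# With the backend kwargs above, connections are looked up in Vault under
# "<mount_point>/<connections_path>/<conn_id>", so "smtp_default" is expected at
# secret/myapp/smtp_default. The helper below is a minimal, illustrative sketch of
# seeding that secret with hvac; it is not executed by the DAG, and the URI and
# credentials are assumptions for demonstration only.
def seed_vault_connection():
    import hvac

    client = hvac.Client(url="http://vault:8200", token="token")
    client.secrets.kv.v2.create_or_update_secret(
        mount_point="secret",
        path="myapp/smtp_default",
        secret={"conn_uri": "smtps://user:password@smtp.example.com:465"},
    )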

@ -0,0 +1,186 @@
version: "3.7"
x-airflow-common: &airflow-common
image: donaldrich/airflow:latest
# build: .
environment: &airflow-common-env
AIRFLOW__CORE__STORE_SERIALIZED_DAGS: "True"
AIRFLOW__CORE__STORE_DAG_CODE: "True"
AIRFLOW__CORE__EXECUTOR: "CeleryExecutor"
AIRFLOW__CORE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow@postgres-dev:5432/airflow
AIRFLOW__CELERY__RESULT_BACKEND: db+postgresql://airflow:airflow@postgres-dev:5432/airflow
AIRFLOW__CELERY__BROKER_URL: redis://:@redis:6379/0
AIRFLOW__CORE__FERNET_KEY: ""
AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION: "false"
    AIRFLOW__CORE__LOAD_EXAMPLES: "false"
# AIRFLOW__CORE__PARALLELISM: 4
# AIRFLOW__CORE__DAG_CONCURRENCY: 4
# AIRFLOW__CORE__MAX_ACTIVE_RUNS_PER_DAG: 4
AIRFLOW_UID: "1000"
AIRFLOW_GID: "0"
_AIRFLOW_DB_UPGRADE: "true"
_AIRFLOW_WWW_USER_CREATE: "true"
_AIRFLOW_WWW_USER_USERNAME: "{{ user }}"
_AIRFLOW_WWW_USER_PASSWORD: "{{ password }}"
PYTHONPATH: "/data:$$PYTHONPATH"
user: "1000:0"
volumes:
- "./airflow:/opt/airflow"
- "./data:/data"
- "/var/run/docker.sock:/var/run/docker.sock"
# - "/usr/bin/docker:/bin/docker:ro"
networks:
- proxy
- backend
services:
airflow-init:
<<: *airflow-common
container_name: airflow-init
environment:
<<: *airflow-common-env
# depends_on:
# - airflow-db
command: bash -c "airflow db init && airflow db upgrade && airflow users create --role Admin --username {{ user }} --email {{ email }} --firstname Don --lastname Aldrich --password {{ password }}"
airflow-webserver:
<<: *airflow-common
container_name: airflow-webserver
hostname: airflow-webserver
    command: webserver
# command: >
# bash -c 'if [[ -z "$$AIRFLOW__API__AUTH_BACKEND" ]] && [[ $$(pip show -f apache-airflow | grep basic_auth.py) ]];
# then export AIRFLOW__API__AUTH_BACKEND=airflow.api.auth.backend.basic_auth ;
# else export AIRFLOW__API__AUTH_BACKEND=airflow.api.auth.backend.default ; fi &&
# { airflow create_user "$$@" || airflow users create "$$@" ; } &&
# { airflow sync_perm || airflow sync-perm ;} &&
# airflow webserver' -- -r Admin -u admin -e admin@example.com -f admin -l user -p admin
healthcheck:
test: ["CMD", "curl", "--fail", "http://localhost:8080/health"]
interval: 10s
timeout: 10s
retries: 5
restart: always
privileged: true
depends_on:
- airflow-scheduler
environment:
<<: *airflow-common-env
labels:
traefik.http.services.airflow.loadbalancer.server.port: "8080"
traefik.enable: "true"
traefik.http.routers.airflow.entrypoints: "https"
traefik.http.routers.airflow.tls.certResolver: "cloudflare"
traefik.http.routers.airflow.rule: "Host(`airflow.{{ domain }}.com`)"
traefik.http.routers.airflow.middlewares: "ip-whitelist@file"
traefik.http.routers.airflow.service: "airflow"
airflow-scheduler:
<<: *airflow-common
command: scheduler
container_name: airflow-scheduler
hostname: airflow-scheduler
restart: always
depends_on:
- airflow-init
environment:
<<: *airflow-common-env
# # entrypoint: sh -c '/app/scripts/wait-for postgres:5432 -- airflow db init && airflow scheduler'
airflow-worker:
<<: *airflow-common
command: celery worker
restart: always
container_name: airflow-worker
hostname: airflow-worker
airflow-queue:
<<: *airflow-common
command: celery flower
container_name: airflow-queue
hostname: airflow-queue
ports:
- 5555:5555
healthcheck:
test: ["CMD", "curl", "--fail", "http://localhost:5555/"]
interval: 10s
timeout: 10s
retries: 5
restart: always
dbt:
image: fishtownanalytics/dbt:0.19.1
container_name: dbt
volumes:
- "/home/{{ user }}/projects/jobfunnel:/data"
- "/home/{{ user }}/projects/jobfunnel/dbt:/usr/app"
# - "dbt-db:/var/lib/postgresql/data"
ports:
- "8081"
networks:
- proxy
- backend
command: docs serve --project-dir /data/transform
working_dir: "/data/transform"
environment:
DBT_SCHEMA: dbt
DBT_RAW_DATA_SCHEMA: dbt_raw_data
DBT_PROFILES_DIR: "/data/transform/profile"
# DBT_PROJECT_DIR: "/data/transform"
# DBT_PROFILES_DIR: "/data"
# DBT_POSTGRES_PASSWORD: dbt
# DBT_POSTGRES_USER : dbt
# DBT_POSTGRES_DB : dbt
# DBT_DBT_SCHEMA: dbt
# DBT_DBT_RAW_DATA_SCHEMA: dbt_raw_data
# DBT_POSTGRES_HOST: dbt-db
meltano:
image: meltano/meltano:latest
container_name: meltano
volumes:
- "/home/{{ user }}/projects/jobfunnel:/project"
ports:
- "5000:5000"
networks:
- proxy
- backend
# environment:
# MELTANO_UI_SERVER_NAME: "etl.{{ domain }}.com"
# MELTANO_DATABASE_URI: "postgresql://meltano:meltano@meltano-db:5432/meltano"
# MELTANO_WEBAPP_POSTGRES_URL: localhost
# MELTANO_WEBAPP_POSTGRES_DB=meltano
# MELTANO_WEBAPP_POSTGRES_USER=meltano
# MELTANO_WEBAPP_POSTGRES_PASSWORD=meltano
# MELTANO_WEBAPP_POSTGRES_PORT=5501
# MELTANO_WEBAPP_LOG_PATH="/tmp/meltano.log"
# MELTANO_WEBAPP_API_URL="http://localhost:5000"
# LOG_PATH="/tmp/meltano.log"
# API_URL="http://localhost:5000"
# MELTANO_MODEL_DIR="./model"
# MELTANO_TRANSFORM_DIR="./transform"
# MELTANO_UI_SESSION_COOKIE_DOMAIN: "etl.{{ domain }}.com"
# MELTANO_UI_SESSION_COOKIE_SECURE: "true"
labels:
traefik.http.services.meltano.loadbalancer.server.port: "5000"
traefik.enable: "true"
traefik.http.routers.meltano.entrypoints: "https"
traefik.http.routers.meltano.tls.certResolver: "cloudflare"
traefik.http.routers.meltano.rule: "Host(`etl.{{ domain }}.com`)"
traefik.http.routers.meltano.middlewares: "secured@file,ip-whitelist@file"
traefik.http.routers.meltano.service: "meltano"
volumes:
tmp_airflow:
driver: local
airflow-db:
driver: local
dbt-db:
driver: local
networks:
proxy:
external: true
backend:
external: true

@ -0,0 +1,49 @@
import csv
import logging
import os
import platform
import random
import re
import time
# from webdriver_manager.chrome import ChromeDriverManager
from datetime import datetime
from urllib.request import urlopen
# import pyautogui
import pandas as pd
import yaml
from bs4 import BeautifulSoup
# import mouseinfo
from selenium import webdriver
from selenium.common.exceptions import NoSuchElementException, TimeoutException
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
log = logging.getLogger(__name__)


def setupLogger():
    dt = datetime.strftime(datetime.now(), "%m_%d_%y %H_%M_%S ")
    if not os.path.isdir("./logs"):
        os.mkdir("./logs")
    logging.basicConfig(
        filename=("./logs/" + str(dt) + "applyJobs.log"),
        filemode="w",
        format="%(asctime)s::%(name)s::%(levelname)s::%(message)s",
        datefmt="%d-%b-%y %H:%M:%S",
    )
    log.setLevel(logging.DEBUG)
    c_handler = logging.StreamHandler()
    c_handler.setLevel(logging.DEBUG)
    c_format = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(message)s", "%H:%M:%S"
    )
    c_handler.setFormatter(c_format)
    log.addHandler(c_handler)

@ -0,0 +1,145 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import json
import re
from typing import Any, Dict, Optional
from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
class DiscordWebhookHook(HttpHook):
"""
This hook allows you to post messages to Discord using incoming webhooks.
Takes a Discord connection ID with a default relative webhook endpoint. The
default endpoint can be overridden using the webhook_endpoint parameter
(https://discordapp.com/developers/docs/resources/webhook).
Each Discord webhook can be pre-configured to use a specific username and
avatar_url. You can override these defaults in this hook.
:param http_conn_id: Http connection ID with host as "https://discord.com/api/" and
default webhook endpoint in the extra field in the form of
{"webhook_endpoint": "webhooks/{webhook.id}/{webhook.token}"}
:type http_conn_id: str
:param webhook_endpoint: Discord webhook endpoint in the form of
"webhooks/{webhook.id}/{webhook.token}"
:type webhook_endpoint: str
:param message: The message you want to send to your Discord channel
(max 2000 characters)
:type message: str
:param username: Override the default username of the webhook
:type username: str
:param avatar_url: Override the default avatar of the webhook
:type avatar_url: str
:param tts: Is a text-to-speech message
:type tts: bool
:param proxy: Proxy to use to make the Discord webhook call
:type proxy: str
"""
def __init__(
self,
http_conn_id: Optional[str] = None,
webhook_endpoint: Optional[str] = None,
message: str = "",
username: Optional[str] = None,
avatar_url: Optional[str] = None,
tts: bool = False,
proxy: Optional[str] = None,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.http_conn_id: Any = http_conn_id
self.webhook_endpoint = self._get_webhook_endpoint(
http_conn_id, webhook_endpoint
)
self.message = message
self.username = username
self.avatar_url = avatar_url
self.tts = tts
self.proxy = proxy
def _get_webhook_endpoint(
self, http_conn_id: Optional[str], webhook_endpoint: Optional[str]
) -> str:
"""
Given a Discord http_conn_id, return the default webhook endpoint or override if a
webhook_endpoint is manually supplied.
:param http_conn_id: The provided connection ID
:param webhook_endpoint: The manually provided webhook endpoint
:return: Webhook endpoint (str) to use
"""
if webhook_endpoint:
endpoint = webhook_endpoint
elif http_conn_id:
conn = self.get_connection(http_conn_id)
extra = conn.extra_dejson
endpoint = extra.get("webhook_endpoint", "")
else:
raise AirflowException(
"Cannot get webhook endpoint: No valid Discord webhook endpoint or http_conn_id supplied."
)
# make sure endpoint matches the expected Discord webhook format
if not re.match("^webhooks/[0-9]+/[a-zA-Z0-9_-]+$", endpoint):
raise AirflowException(
'Expected Discord webhook endpoint in the form of "webhooks/{webhook.id}/{webhook.token}".'
)
return endpoint
def _build_discord_payload(self) -> str:
"""
Construct the Discord JSON payload. All relevant parameters are combined here
to a valid Discord JSON payload.
:return: Discord payload (str) to send
"""
payload: Dict[str, Any] = {}
if self.username:
payload["username"] = self.username
if self.avatar_url:
payload["avatar_url"] = self.avatar_url
payload["tts"] = self.tts
if len(self.message) <= 2000:
payload["content"] = self.message
else:
raise AirflowException(
"Discord message length must be 2000 or fewer characters."
)
return json.dumps(payload)
def execute(self) -> None:
"""Execute the Discord webhook call"""
proxies = {}
if self.proxy:
# we only need https proxy for Discord
proxies = {"https": self.proxy}
discord_payload = self._build_discord_payload()
self.run(
endpoint=self.webhook_endpoint,
data=discord_payload,
headers={"Content-type": "application/json"},
extra_options={"proxies": proxies},
)
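# A minimal, illustrative sketch (not part of the hook itself) of driving the hook
# directly, for example from a PythonOperator callable. The connection id
# "discord_default" and the message below are assumptions for demonstration; the
# connection's extra field is expected to carry the webhook_endpoint as documented
# in the class docstring.
def example_send_discord_message() -> None:
    hook = DiscordWebhookHook(
        http_conn_id="discord_default",
        message="Airflow DAG finished successfully",
        username="airflow-bot",
    )
    hook.execute()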

@ -0,0 +1,108 @@
"""
There are two ways to authenticate the Google Analytics Hook.
If you have already obtained an OAUTH token, place it in the password field
of the relevant connection.
If you don't have an OAUTH token, you may authenticate by passing a
'client_secrets' object to the extras section of the relevant connection. This
object will expect the following fields and use them to generate an OAUTH token
on execution.
"type": "service_account",
"project_id": "example-project-id",
"private_key_id": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
"private_key": "-----BEGIN PRIVATE KEY-----\nXXXXX\n-----END PRIVATE KEY-----\n",
"client_email": "google-analytics@{PROJECT_ID}.iam.gserviceaccount.com",
"client_id": "XXXXXXXXXXXXXXXXXXXXXX",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://accounts.google.com/o/oauth2/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "{CERT_URL}"
More details can be found here:
https://developers.google.com/api-client-library/python/guide/aaa_client_secrets
"""
import time
from airflow.hooks.base_hook import BaseHook
from apiclient.discovery import build
from oauth2client.client import AccessTokenCredentials
from oauth2client.service_account import ServiceAccountCredentials
class GoogleHook(BaseHook):
    def __init__(self, google_conn_id="google_default"):
        self.google_analytics_conn_id = google_conn_id
        self.connection = self.get_connection(google_conn_id)
        self.client_secrets = self.connection.extra_dejson.get("client_secrets", None)

    def get_service_object(self, api_name, api_version, scopes=None):
        if self.connection.password:
            credentials = AccessTokenCredentials(
                self.connection.password, "Airflow/1.0"
            )
        elif self.client_secrets:
            credentials = ServiceAccountCredentials.from_json_keyfile_dict(
                self.client_secrets, scopes
            )
        else:
            raise ValueError(
                "No OAUTH token or client_secrets supplied for connection '{}'.".format(
                    self.google_analytics_conn_id
                )
            )
        return build(api_name, api_version, credentials=credentials)
def get_analytics_report(
self,
view_id,
since,
until,
sampling_level,
dimensions,
metrics,
page_size,
include_empty_rows,
):
analytics = self.get_service_object(
"analyticsreporting",
"v4",
["https://www.googleapis.com/auth/analytics.readonly"],
)
reportRequest = {
"viewId": view_id,
"dateRanges": [{"startDate": since, "endDate": until}],
"samplingLevel": sampling_level or "LARGE",
"dimensions": dimensions,
"metrics": metrics,
"pageSize": page_size or 1000,
"includeEmptyRows": include_empty_rows or False,
}
response = (
analytics.reports()
.batchGet(body={"reportRequests": [reportRequest]})
.execute()
)
if response.get("reports"):
report = response["reports"][0]
rows = report.get("data", {}).get("rows", [])
while report.get("nextPageToken"):
time.sleep(1)
reportRequest.update({"pageToken": report["nextPageToken"]})
response = (
analytics.reports()
.batchGet(body={"reportRequests": [reportRequest]})
.execute()
)
report = response["reports"][0]
rows.extend(report.get("data", {}).get("rows", []))
if report["data"]:
report["data"]["rows"] = rows
return report
else:
return {}
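# A minimal sketch of calling the reporting helper above; the connection id, view id
# and date range are illustrative assumptions. Arguments the caller does not care
# about are passed as None so the fallbacks inside get_analytics_report apply.
def example_fetch_sessions_report():
    hook = GoogleHook(google_conn_id="google_default")
    return hook.get_analytics_report(
        view_id="12345678",
        since="2021-01-01",
        until="2021-01-31",
        sampling_level=None,
        dimensions=[{"name": "ga:date"}],
        metrics=[{"expression": "ga:sessions"}],
        page_size=None,
        include_empty_rows=None,
    )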

@ -0,0 +1,95 @@
import logging
import logging as log
import os
import time
import docker
from airflow.hooks.base_hook import BaseHook
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
class SeleniumHook(BaseHook):
"""
Creates a Selenium Docker container on the host and controls the
browser by sending commands to the remote server.
"""
def __init__(self):
logging.info("initialised hook")
pass
def create_container(self):
"""
Creates the selenium docker container
"""
logging.info("creating_container")
cwd = os.getcwd()
self.local_downloads = os.path.join(cwd, "downloads")
self.sel_downloads = "/home/seluser/downloads"
volumes = [
"{}:{}".format(self.local_downloads, self.sel_downloads),
"/dev/shm:/dev/shm",
]
client = docker.from_env()
container = client.containers.run(
"selenium/standalone-chrome",
volumes=volumes,
network="container_bridge",
detach=True,
)
self.container = container
cli = docker.APIClient()
self.container_ip = cli.inspect_container(container.id)["NetworkSettings"][
"Networks"
]["container_bridge"]["IPAddress"]
def create_driver(self):
"""
creates and configure the remote Selenium webdriver.
"""
logging.info("creating driver")
options = Options()
options.add_argument("--headless")
options.add_argument("--window-size=1920x1080")
        chrome_driver = "http://{}:4444/wd/hub".format(self.container_ip)
        # chrome_driver = 'http://127.0.0.1:4444/wd/hub'  # local
        # Wait for the remote webdriver to become available; give up after a
        # bounded number of attempts rather than polling forever.
        for _ in range(30):
            try:
                driver = webdriver.Remote(
                    command_executor=chrome_driver,
                    desired_capabilities=DesiredCapabilities.CHROME,
                    options=options,
                )
                print("remote ready")
                break
            except Exception:
                print("remote not ready, sleeping for ten seconds.")
                time.sleep(10)
        else:
            raise RuntimeError("Selenium remote did not become ready in time.")
# Enable downloads in headless chrome.
driver.command_executor._commands["send_command"] = (
"POST",
"/session/$sessionId/chromium/send_command",
)
params = {
"cmd": "Page.setDownloadBehavior",
"params": {"behavior": "allow", "downloadPath": self.sel_downloads},
}
driver.execute("send_command", params)
self.driver = driver
def remove_container(self):
"""
This removes the Selenium container.
"""
self.container.remove(force=True)
print("Removed container: {}".format(self.container.id))
def run_script(self, script, args):
"""
This is a wrapper around the python script which sends commands to
the docker container. The first variable of the script must be the web driver.
"""
script(self.driver, *args)
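# A minimal, illustrative sketch of how the hook is intended to be driven: the script
# passed to run_script receives the remote webdriver as its first argument, followed
# by any extra args. The URL below is an assumption for demonstration only.
def example_scrape_title(driver, url):
    driver.get(url)
    return driver.title


def example_selenium_session():
    hook = SeleniumHook()
    hook.create_container()
    hook.create_driver()
    try:
        hook.run_script(example_scrape_title, ["https://example.com"])
    finally:
        # Always clean up the throwaway container, even if the script fails.
        hook.remove_container()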

@ -0,0 +1,2 @@
from operators.singer import SingerOperator
from operators.sudo_bash import SudoBashOperator

@ -0,0 +1,95 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from typing import Dict, Optional
from airflow.exceptions import AirflowException
from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook
from airflow.providers.http.operators.http import SimpleHttpOperator
from airflow.utils.decorators import apply_defaults
class DiscordWebhookOperator(SimpleHttpOperator):
"""
This operator allows you to post messages to Discord using incoming webhooks.
Takes a Discord connection ID with a default relative webhook endpoint. The
default endpoint can be overridden using the webhook_endpoint parameter
(https://discordapp.com/developers/docs/resources/webhook).
Each Discord webhook can be pre-configured to use a specific username and
avatar_url. You can override these defaults in this operator.
:param http_conn_id: Http connection ID with host as "https://discord.com/api/" and
default webhook endpoint in the extra field in the form of
{"webhook_endpoint": "webhooks/{webhook.id}/{webhook.token}"}
:type http_conn_id: str
:param webhook_endpoint: Discord webhook endpoint in the form of
"webhooks/{webhook.id}/{webhook.token}"
:type webhook_endpoint: str
:param message: The message you want to send to your Discord channel
(max 2000 characters). (templated)
:type message: str
:param username: Override the default username of the webhook. (templated)
:type username: str
:param avatar_url: Override the default avatar of the webhook
:type avatar_url: str
:param tts: Is a text-to-speech message
:type tts: bool
:param proxy: Proxy to use to make the Discord webhook call
:type proxy: str
"""
template_fields = ["username", "message"]
@apply_defaults
def __init__(
self,
*,
http_conn_id: Optional[str] = None,
webhook_endpoint: Optional[str] = None,
message: str = "",
username: Optional[str] = None,
avatar_url: Optional[str] = None,
tts: bool = False,
proxy: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(endpoint=webhook_endpoint, **kwargs)
if not http_conn_id:
raise AirflowException("No valid Discord http_conn_id supplied.")
self.http_conn_id = http_conn_id
self.webhook_endpoint = webhook_endpoint
self.message = message
self.username = username
self.avatar_url = avatar_url
self.tts = tts
self.proxy = proxy
self.hook: Optional[DiscordWebhookHook] = None
def execute(self, context: Dict) -> None:
"""Call the DiscordWebhookHook to post message"""
self.hook = DiscordWebhookHook(
self.http_conn_id,
self.webhook_endpoint,
self.message,
self.username,
self.avatar_url,
self.tts,
self.proxy,
)
self.hook.execute()
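# A minimal, illustrative sketch of wiring the operator above into a DAG; the DAG id,
# schedule, connection id and message are assumptions for demonstration only.
def example_discord_notification_dag():
    from airflow import DAG
    from airflow.utils.dates import days_ago

    with DAG(
        "discord_notify_example", start_date=days_ago(1), schedule_interval=None
    ) as dag:
        DiscordWebhookOperator(
            task_id="notify_discord",
            http_conn_id="discord_default",
            message="Pipeline finished",
        )
    return dag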

@ -0,0 +1,254 @@
import gzip
import json
import logging
import os
import time
import boa
from airflow.hooks.base_hook import BaseHook
from airflow.hooks.S3_hook import S3Hook
from airflow.models import BaseOperator, Variable
from hooks.google import GoogleHook
from six import BytesIO
class GoogleSheetsToS3Operator(BaseOperator):
"""
Google Sheets To S3 Operator
:param google_conn_id: The Google connection id.
:type google_conn_id: string
:param sheet_id: The id for associated report.
:type sheet_id: string
    :param sheet_names: The names of the relevant sheets in the report.
    :type sheet_names: string/array
    :param range: The range of cells containing the relevant data.
                  This must be the same for all sheets if multiple
                  are being pulled together.
                  Example: Sheet1!A2:E80
:type range: string
:param include_schema: If set to true, infer the schema of the data and
output to S3 as a separate file
:type include_schema: boolean
:param s3_conn_id: The s3 connection id.
:type s3_conn_id: string
:param s3_key: The S3 key to be used to store the
retrieved data.
:type s3_key: string
"""
template_fields = ("s3_key",)
def __init__(
self,
google_conn_id,
sheet_id,
s3_conn_id,
s3_key,
compression_bound,
include_schema=False,
sheet_names=[],
range=None,
output_format="json",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.google_conn_id = google_conn_id
self.sheet_id = sheet_id
self.sheet_names = sheet_names
self.s3_conn_id = s3_conn_id
self.s3_key = s3_key
self.include_schema = include_schema
self.range = range
self.output_format = output_format.lower()
self.compression_bound = compression_bound
        if self.output_format not in ("json",):
            raise Exception("Acceptable output formats are: json.")
if self.sheet_names and not isinstance(self.sheet_names, (str, list, tuple)):
raise Exception("Please specify the sheet names as a string or list.")
def execute(self, context):
g_conn = GoogleHook(self.google_conn_id)
if isinstance(self.sheet_names, str) and "," in self.sheet_names:
sheet_names = self.sheet_names.split(",")
else:
sheet_names = self.sheet_names
sheets_object = g_conn.get_service_object("sheets", "v4")
logging.info("Retrieved Sheets Object")
response = (
sheets_object.spreadsheets()
.get(spreadsheetId=self.sheet_id, includeGridData=True)
.execute()
)
title = response.get("properties").get("title")
sheets = response.get("sheets")
final_output = dict()
total_sheets = []
for sheet in sheets:
name = sheet.get("properties").get("title")
name = boa.constrict(name)
total_sheets.append(name)
if self.sheet_names:
if name not in sheet_names:
logging.info(
"{} is not found in available sheet names.".format(name)
)
continue
table_name = name
data = sheet.get("data")[0].get("rowData")
output = []
for row in data:
row_data = []
values = row.get("values")
for value in values:
ev = value.get("effectiveValue")
if ev is None:
row_data.append(None)
else:
for v in ev.values():
row_data.append(v)
output.append(row_data)
if self.output_format == "json":
headers = output.pop(0)
output = [dict(zip(headers, row)) for row in output]
final_output[table_name] = output
s3 = S3Hook(self.s3_conn_id)
for sheet in final_output:
output_data = final_output.get(sheet)
file_name, file_extension = os.path.splitext(self.s3_key)
output_name = "".join([file_name, "_", sheet, file_extension])
if self.include_schema is True:
schema_name = "".join(
[file_name, "_", sheet, "_schema", file_extension]
)
self.output_manager(
s3, output_name, output_data, context, sheet, schema_name
)
dag_id = context["ti"].dag_id
var_key = "_".join([dag_id, self.sheet_id])
Variable.set(key=var_key, value=json.dumps(total_sheets))
time.sleep(10)
return boa.constrict(title)
def output_manager(
self, s3, output_name, output_data, context, sheet_name, schema_name=None
):
self.s3_bucket = BaseHook.get_connection(self.s3_conn_id).host
if self.output_format == "json":
output = "\n".join(
[
json.dumps({boa.constrict(str(k)): v for k, v in record.items()})
for record in output_data
]
)
enc_output = str.encode(output, "utf-8")
# if file is more than bound then apply gzip compression
if len(enc_output) / 1024 / 1024 >= self.compression_bound:
logging.info(
"File is more than {}MB, gzip compression will be applied".format(
self.compression_bound
)
)
output = gzip.compress(enc_output, compresslevel=5)
self.xcom_push(
context,
key="is_compressed_{}".format(sheet_name),
value="compressed",
)
self.load_bytes(
s3,
bytes_data=output,
key=output_name,
bucket_name=self.s3_bucket,
replace=True,
)
else:
logging.info(
"File is less than {}MB, compression will not be applied".format(
self.compression_bound
)
)
self.xcom_push(
context,
key="is_compressed_{}".format(sheet_name),
value="non-compressed",
)
s3.load_string(
string_data=output,
key=output_name,
bucket_name=self.s3_bucket,
replace=True,
)
if self.include_schema is True:
output_keys = output_data[0].keys()
schema = [
{"name": boa.constrict(a), "type": "varchar(512)"}
for a in output_keys
if a is not None
]
schema = {"columns": schema}
s3.load_string(
string_data=json.dumps(schema),
key=schema_name,
bucket_name=self.s3_bucket,
replace=True,
)
        logging.info('Successfully wrote "{}" to S3.'.format(output_name))
# TODO -- Add support for csv output
# elif self.output_format == 'csv':
# with NamedTemporaryFile("w") as f:
# writer = csv.writer(f)
# writer.writerows(output_data)
# s3.load_file(
# filename=f.name,
# key=output_name,
# bucket_name=self.s3_bucket,
# replace=True
# )
#
# if self.include_schema is True:
# pass
# TODO: remove when airflow version is upgraded to 1.10
def load_bytes(self, s3, bytes_data, key, bucket_name=None, replace=False):
if not bucket_name:
(bucket_name, key) = s3.parse_s3_url(key)
if not replace and s3.check_for_key(key, bucket_name):
raise ValueError("The key {key} already exists.".format(key=key))
filelike_buffer = BytesIO(bytes_data)
client = s3.get_conn()
client.upload_fileobj(filelike_buffer, bucket_name, key, ExtraArgs={})
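# A minimal, illustrative sketch of instantiating the operator above; the connection
# ids, sheet id and S3 key are assumptions for demonstration only, and the task would
# normally be created inside a DAG context.
def example_sheets_to_s3_task():
    return GoogleSheetsToS3Operator(
        task_id="sheets_to_s3",
        google_conn_id="google_default",
        sheet_id="1aBcDeFgHiJkLmNoPqRsTuVwXyZ",
        s3_conn_id="s3_default",
        s3_key="sheets/output.json",
        compression_bound=10,  # gzip threshold in MB, per output_manager above
        include_schema=True,
    )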

@ -0,0 +1,24 @@
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from hooks.selenium_hook import SeleniumHook
class SeleniumOperator(BaseOperator):
"""
Selenium Operator
"""
template_fields = ["script_args"]
@apply_defaults
def __init__(self, script, script_args, *args, **kwargs):
super().__init__(*args, **kwargs)
self.script = script
self.script_args = script_args
def execute(self, context):
hook = SeleniumHook()
hook.create_container()
hook.create_driver()
hook.run_script(self.script, self.script_args)
hook.remove_container()
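# A minimal, illustrative sketch of using the operator: the callable passed as `script`
# must accept the remote webdriver as its first argument (see SeleniumHook.run_script).
# The task id and URL below are assumptions for demonstration only.
def example_fetch_page(driver, url):
    driver.get(url)


def example_selenium_task():
    return SeleniumOperator(
        task_id="selenium_example",
        script=example_fetch_page,
        script_args=["https://example.com"],
    )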

@ -0,0 +1,45 @@
from airflow.operators.bash_operator import BashOperator
from airflow.utils.decorators import apply_defaults
class SingerOperator(BashOperator):
"""
Singer Plugin
:param tap: The relevant Singer Tap.
:type tap: string
:param target: The relevant Singer Target
:type target: string
:param tap_config: The config path for the Singer Tap.
:type tap_config: string
:param target_config: The config path for the Singer Target.
:type target_config: string
"""
@apply_defaults
def __init__(
self, tap, target, tap_config=None, target_config=None, *args, **kwargs
):
self.tap = "tap-{}".format(tap)
self.target = "target-{}".format(target)
self.tap_config = tap_config
self.target_config = target_config
if self.tap_config:
if self.target_config:
self.bash_command = "{} -c {} | {} -c {}".format(
self.tap, self.tap_config, self.target, self.target_config
)
else:
self.bash_command = "{} -c {} | {}".format(
self.tap, self.tap_config, self.target
)
elif self.target_config:
self.bash_command = "{} | {} -c {}".format(
self.tap, self.target, self.target_config
)
else:
self.bash_command = "{} | {}".format(self.tap, self.target)
super().__init__(bash_command=self.bash_command, *args, **kwargs)
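# A minimal, illustrative sketch: given a tap name, a target name and config paths,
# the operator above composes a "tap-... -c ... | target-... -c ..." pipeline. The
# tap, target and paths below are assumptions for demonstration only.
def example_singer_task():
    return SingerOperator(
        task_id="csv_to_postgres",
        tap="csv",
        target="postgres",
        tap_config="/data/singer/tap_csv_config.json",
        target_config="/data/singer/target_postgres_config.json",
    )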

@ -0,0 +1,110 @@
import getpass
import logging
import os
from builtins import bytes
from subprocess import PIPE, STDOUT, Popen
from tempfile import NamedTemporaryFile, gettempdir
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.utils.file import TemporaryDirectory
class SudoBashOperator(BaseOperator):
"""
Execute a Bash script, command or set of commands but sudo's as another user to execute them.
:param bash_command: The command, set of commands or reference to a
bash script (must be '.sh') to be executed.
:type bash_command: string
:param user: The user to run the command as. The Airflow worker user
must have permission to sudo as that user
:type user: string
:param env: If env is not None, it must be a mapping that defines the
environment variables for the new process; these are used instead
of inheriting the current process environment, which is the default
behavior.
:type env: dict
    :param output_encoding: Output encoding of the bash command.
    :type output_encoding: str
"""
template_fields = ("bash_command", "user", "env", "output_encoding")
template_ext = (
".sh",
".bash",
)
ui_color = "#f0ede4"
@apply_defaults
def __init__(
self,
bash_command,
user,
xcom_push=False,
env=None,
output_encoding="utf-8",
*args,
**kwargs,
):
super(SudoBashOperator, self).__init__(*args, **kwargs)
self.bash_command = bash_command
self.user = user
self.env = env
self.xcom_push_flag = xcom_push
self.output_encoding = output_encoding
def execute(self, context):
"""
Execute the bash command in a temporary directory which will be cleaned afterwards.
"""
logging.info("tmp dir root location: \n" + gettempdir())
with TemporaryDirectory(prefix="airflowtmp") as tmp_dir:
            os.chmod(tmp_dir, 0o777)
# Ensure the sudo user has perms to their current working directory for making tempfiles
# This is not really a security flaw because the only thing in that dir is the
# temp script, owned by the airflow user and any temp files made by the sudo user
# and all of those will be created with the owning user's umask
# If a process needs finer control over the tempfiles it creates, that process can chmod
# them as they are created.
with NamedTemporaryFile(dir=tmp_dir, prefix=self.task_id) as f:
if self.user == getpass.getuser(): # don't try to sudo as yourself
f.write(bytes(self.bash_command, "utf_8"))
else:
sudo_cmd = "sudo -u {} sh -c '{}'".format(
self.user, self.bash_command
)
f.write(bytes(sudo_cmd, "utf_8"))
f.flush()
logging.info("Temporary script location: {0}".format(f.name))
logging.info("Running command: {}".format(self.bash_command))
self.sp = Popen(
["bash", f.name],
stdout=PIPE,
stderr=STDOUT,
cwd=tmp_dir,
env=self.env,
)
logging.info("Output:")
line = ""
for line in iter(self.sp.stdout.readline, b""):
line = line.decode(self.output_encoding).strip()
logging.info(line)
self.sp.wait()
logging.info(
"Command exited with return code {0}".format(self.sp.returncode)
)
if self.sp.returncode:
raise AirflowException("Bash command failed")
if self.xcom_push_flag:
return line
def on_kill(self):
        logging.warning("Sending SIGTERM signal to bash subprocess")
self.sp.terminate()
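# A minimal, illustrative sketch of running a command as another user with the
# operator above; the task id, user and command are assumptions for demonstration,
# and the Airflow worker user must be able to sudo as that user.
def example_sudo_task():
    return SudoBashOperator(
        task_id="chown_downloads",
        user="scraper",
        bash_command="chown -R scraper:scraper /data/downloads",
    )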

@ -0,0 +1,15 @@
from airflow.plugins_manager import AirflowPlugin
from hooks.google import GoogleHook
from operators.google import GoogleSheetsToS3Operator
class google_sheets_plugin(AirflowPlugin):
name = "GoogleSheetsPlugin"
operators = [GoogleSheetsToS3Operator]
# Leave in for explicitness
hooks = [GoogleHook]
executors = []
macros = []
admin_views = []
flask_blueprints = []
menu_links = []

@ -0,0 +1,13 @@
from airflow.plugins_manager import AirflowPlugin
from operators.singer import SingerOperator
class SingerPlugin(AirflowPlugin):
name = "singer_plugin"
hooks = []
operators = [SingerOperator]
executors = []
macros = []
admin_views = []
flask_blueprints = []
menu_links = []

@ -0,0 +1,25 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
Changelog
---------
1.0.0
.....
Initial version of the provider.

@ -0,0 +1,17 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@ -0,0 +1,64 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example DAG demonstrating the usage of the AirbyteTriggerSyncOperator."""
from datetime import timedelta
from airflow import DAG
from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator
from airflow.providers.airbyte.sensors.airbyte import AirbyteJobSensor
from airflow.utils.dates import days_ago
args = {
"owner": "airflow",
}
with DAG(
dag_id="example_airbyte_operator",
default_args=args,
schedule_interval=None,
start_date=days_ago(1),
dagrun_timeout=timedelta(minutes=60),
tags=["example"],
) as dag:
# [START howto_operator_airbyte_synchronous]
sync_source_destination = AirbyteTriggerSyncOperator(
task_id="airbyte_sync_source_dest_example",
airbyte_conn_id="airbyte_default",
connection_id="15bc3800-82e4-48c3-a32d-620661273f28",
)
# [END howto_operator_airbyte_synchronous]
# [START howto_operator_airbyte_asynchronous]
async_source_destination = AirbyteTriggerSyncOperator(
task_id="airbyte_async_source_dest_example",
airbyte_conn_id="airbyte_default",
connection_id="15bc3800-82e4-48c3-a32d-620661273f28",
asynchronous=True,
)
airbyte_sensor = AirbyteJobSensor(
task_id="airbyte_sensor_source_dest_example",
airbyte_job_id="{{task_instance.xcom_pull(task_ids='airbyte_async_source_dest_example')}}",
airbyte_conn_id="airbyte_default",
)
# [END howto_operator_airbyte_asynchronous]
async_source_destination >> airbyte_sensor

@ -0,0 +1,17 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@ -0,0 +1,123 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import time
from typing import Any, Optional
from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
class AirbyteHook(HttpHook):
"""
Hook for Airbyte API
:param airbyte_conn_id: Required. The name of the Airflow connection to get
connection information for Airbyte.
:type airbyte_conn_id: str
:param api_version: Optional. Airbyte API version.
:type api_version: str
"""
RUNNING = "running"
SUCCEEDED = "succeeded"
CANCELLED = "cancelled"
PENDING = "pending"
FAILED = "failed"
ERROR = "error"
def __init__(
self,
airbyte_conn_id: str = "airbyte_default",
api_version: Optional[str] = "v1",
) -> None:
super().__init__(http_conn_id=airbyte_conn_id)
self.api_version: str = api_version
def wait_for_job(
self,
job_id: str,
wait_seconds: Optional[float] = 3,
timeout: Optional[float] = 3600,
) -> None:
"""
Helper method which polls a job to check if it finishes.
:param job_id: Required. Id of the Airbyte job
:type job_id: str
:param wait_seconds: Optional. Number of seconds between checks.
:type wait_seconds: float
        :param timeout: Optional. How many seconds to wait for the job to be ready.
Used only if ``asynchronous`` is False.
:type timeout: float
"""
state = None
start = time.monotonic()
while True:
if timeout and start + timeout < time.monotonic():
raise AirflowException(
f"Timeout: Airbyte job {job_id} is not ready after {timeout}s"
)
time.sleep(wait_seconds)
try:
job = self.get_job(job_id=job_id)
state = job.json()["job"]["status"]
except AirflowException as err:
self.log.info(
"Retrying. Airbyte API returned server error when waiting for job: %s",
err,
)
continue
if state in (self.RUNNING, self.PENDING):
continue
if state == self.SUCCEEDED:
break
if state == self.ERROR:
raise AirflowException(f"Job failed:\n{job}")
elif state == self.CANCELLED:
raise AirflowException(f"Job was cancelled:\n{job}")
else:
raise Exception(
f"Encountered unexpected state `{state}` for job_id `{job_id}`"
)
def submit_sync_connection(self, connection_id: str) -> Any:
"""
        Submits a job to an Airbyte server.
        :param connection_id: Required. The ConnectionId of the Airbyte Connection.
        :type connection_id: str
"""
return self.run(
endpoint=f"api/{self.api_version}/connections/sync",
json={"connectionId": connection_id},
headers={"accept": "application/json"},
)
def get_job(self, job_id: int) -> Any:
"""
Gets the resource representation for a job in Airbyte.
:param job_id: Required. Id of the Airbyte job
:type job_id: int
"""
return self.run(
endpoint=f"api/{self.api_version}/jobs/get",
json={"id": job_id},
headers={"accept": "application/json"},
)

@ -0,0 +1,17 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@ -0,0 +1,89 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional
from airflow.models import BaseOperator
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
from airflow.utils.decorators import apply_defaults
class AirbyteTriggerSyncOperator(BaseOperator):
"""
    This operator allows you to submit a job to an Airbyte server to run an integration
    process between your source and destination.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:AirbyteTriggerSyncOperator`
:param airbyte_conn_id: Required. The name of the Airflow connection to get connection
information for Airbyte.
:type airbyte_conn_id: str
:param connection_id: Required. The Airbyte ConnectionId UUID between a source and destination.
:type connection_id: str
:param asynchronous: Optional. Flag to get job_id after submitting the job to the Airbyte API.
This is useful for submitting long running jobs and
waiting on them asynchronously using the AirbyteJobSensor.
:type asynchronous: bool
:param api_version: Optional. Airbyte API version.
:type api_version: str
:param wait_seconds: Optional. Number of seconds between checks. Only used when ``asynchronous`` is False.
:type wait_seconds: float
:param timeout: Optional. The amount of time, in seconds, to wait for the request to complete.
Only used when ``asynchronous`` is False.
:type timeout: float
"""
template_fields = ("connection_id",)
@apply_defaults
def __init__(
self,
connection_id: str,
airbyte_conn_id: str = "airbyte_default",
asynchronous: Optional[bool] = False,
api_version: Optional[str] = "v1",
wait_seconds: Optional[float] = 3,
timeout: Optional[float] = 3600,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.airbyte_conn_id = airbyte_conn_id
self.connection_id = connection_id
self.timeout = timeout
self.api_version = api_version
self.wait_seconds = wait_seconds
self.asynchronous = asynchronous
def execute(self, context) -> str:
"""Create an Airbyte job and, unless ``asynchronous`` is set, wait for it to finish"""
hook = AirbyteHook(
airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version
)
job_object = hook.submit_sync_connection(connection_id=self.connection_id)
job_id = job_object.json()["job"]["id"]
self.log.info("Job %s was submitted to Airbyte Server", job_id)
if not self.asynchronous:
self.log.info("Waiting for job %s to complete", job_id)
hook.wait_for_job(
job_id=job_id, wait_seconds=self.wait_seconds, timeout=self.timeout
)
self.log.info("Job %s completed successfully", job_id)
return job_id
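
A minimal usage sketch for the operator above, assuming an ``airbyte_default`` Airflow connection pointing at the Airbyte API; the ConnectionId UUID below is a placeholder:

from airflow import DAG
from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_airbyte_sync",
    start_date=days_ago(1),
    schedule_interval=None,
) as dag:
    # Blocks until the Airbyte job finishes (asynchronous defaults to False).
    sync_source_to_destination = AirbyteTriggerSyncOperator(
        task_id="airbyte_sync",
        airbyte_conn_id="airbyte_default",
        connection_id="00000000-0000-0000-0000-000000000000",  # placeholder ConnectionId
        wait_seconds=3,
        timeout=3600,
    )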

@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
---
package-name: apache-airflow-providers-airbyte
name: Airbyte
description: |
`Airbyte <https://airbyte.io/>`__
versions:
- 1.0.0
integrations:
- integration-name: Airbyte
external-doc-url: https://www.airbyte.io/
logo: /integration-logos/airbyte/Airbyte.png
how-to-guide:
- /docs/apache-airflow-providers-airbyte/operators/airbyte.rst
tags: [service]
operators:
- integration-name: Airbyte
python-modules:
- airflow.providers.airbyte.operators.airbyte
hooks:
- integration-name: Airbyte
python-modules:
- airflow.providers.airbyte.hooks.airbyte
sensors:
- integration-name: Airbyte
python-modules:
- airflow.providers.airbyte.sensors.airbyte
hook-class-names:
- airflow.providers.airbyte.hooks.airbyte.AirbyteHook

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@ -0,0 +1,75 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains a Airbyte Job sensor."""
from typing import Optional
from airflow.exceptions import AirflowException
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
from airflow.sensors.base import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
class AirbyteJobSensor(BaseSensorOperator):
"""
Check for the state of a previously submitted Airbyte job.
:param airbyte_job_id: Required. Id of the Airbyte job
:type airbyte_job_id: str
:param airbyte_conn_id: Required. The name of the Airflow connection to get
connection information for Airbyte.
:type airbyte_conn_id: str
:param api_version: Optional. Airbyte API version.
:type api_version: str
"""
template_fields = ("airbyte_job_id",)
ui_color = "#6C51FD"
@apply_defaults
def __init__(
self,
*,
airbyte_job_id: str,
airbyte_conn_id: str = "airbyte_default",
api_version: Optional[str] = "v1",
**kwargs,
) -> None:
super().__init__(**kwargs)
self.airbyte_conn_id = airbyte_conn_id
self.airbyte_job_id = airbyte_job_id
self.api_version = api_version
def poke(self, context: dict) -> bool:
hook = AirbyteHook(
airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version
)
job = hook.get_job(job_id=self.airbyte_job_id)
status = job.json()["job"]["status"]
if status == hook.FAILED:
raise AirflowException(f"Job failed: \n{job}")
elif status == hook.CANCELLED:
raise AirflowException(f"Job was cancelled: \n{job}")
elif status == hook.SUCCEEDED:
self.log.info("Job %s completed successfully.", self.airbyte_job_id)
return True
elif status == hook.ERROR:
self.log.info("Job %s attempt has failed.", self.airbyte_job_id)
self.log.info("Waiting for job %s to complete.", self.airbyte_job_id)
return False
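
A hedged sketch of the asynchronous pattern mentioned in the operator docstring: submit the sync with ``asynchronous=True`` and let this sensor wait for the job, pulling the job id from XCom. The ConnectionId and the ``airbyte_default`` connection are placeholders:

from airflow import DAG
from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator
from airflow.providers.airbyte.sensors.airbyte import AirbyteJobSensor
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_airbyte_async_sync",
    start_date=days_ago(1),
    schedule_interval=None,
) as dag:
    # Returns the Airbyte job id immediately instead of blocking until completion.
    trigger_sync = AirbyteTriggerSyncOperator(
        task_id="trigger_airbyte_sync",
        airbyte_conn_id="airbyte_default",
        connection_id="00000000-0000-0000-0000-000000000000",  # placeholder ConnectionId
        asynchronous=True,
    )
    # Poke the Airbyte API until the job submitted above succeeds, fails or is cancelled.
    wait_for_sync = AirbyteJobSensor(
        task_id="wait_for_airbyte_sync",
        airbyte_conn_id="airbyte_default",
        airbyte_job_id="{{ task_instance.xcom_pull(task_ids='trigger_airbyte_sync') }}",
    )
    trigger_sync >> wait_for_sync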

@ -0,0 +1,63 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
Changelog
---------
1.2.0
.....
Features
~~~~~~~~
* ``Avoid using threads in S3 remote logging upload (#14414)``
* ``Allow AWS Operator RedshiftToS3Transfer To Run a Custom Query (#14177)``
* ``includes the STS token if STS credentials are used (#11227)``
1.1.0
.....
Features
~~~~~~~~
* ``Adding support to put extra arguments for Glue Job. (#14027)``
* ``Add aws ses email backend for use with EmailOperator. (#13986)``
* ``Add bucket_name to template fileds in S3 operators (#13973)``
* ``Add ExasolToS3Operator (#13847)``
* ``AWS Glue Crawler Integration (#13072)``
* ``Add acl_policy to S3CopyObjectOperator (#13773)``
* ``AllowDiskUse parameter and docs in MongotoS3Operator (#12033)``
* ``Add S3ToFTPOperator (#11747)``
* ``add xcom push for ECSOperator (#12096)``
* ``[AIRFLOW-3723] Add Gzip capability to mongo_to_S3 operator (#13187)``
* ``Add S3KeySizeSensor (#13049)``
* ``Add 'mongo_collection' to template_fields in MongoToS3Operator (#13361)``
* ``Allow Tags on AWS Batch Job Submission (#13396)``
Bug fixes
~~~~~~~~~
* ``Fix bug in GCSToS3Operator (#13718)``
* ``Fix S3KeysUnchangedSensor so that template_fields work (#13490)``
1.0.0
.....
Initial version of the provider.

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example dag for using `AWSDataSyncOperator` in a straightforward manner.
This DAG gets an AWS TaskArn for a specified source and destination, and then attempts to execute it.
It assumes a single task is returned and does not do error checking (e.g. if multiple tasks were found).
This DAG relies on the following environment variables:
* SOURCE_LOCATION_URI - Source location URI, usually on premises SMB or NFS
* DESTINATION_LOCATION_URI - Destination location URI, usually S3
"""
from os import getenv
from airflow import models
from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator
from airflow.utils.dates import days_ago
# [START howto_operator_datasync_1_args_1]
TASK_ARN = getenv("TASK_ARN", "my_aws_datasync_task_arn")
# [END howto_operator_datasync_1_args_1]
# [START howto_operator_datasync_1_args_2]
SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/")
DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix")
# [END howto_operator_datasync_1_args_2]
with models.DAG(
"example_datasync_1_1",
schedule_interval=None, # Override to match your needs
start_date=days_ago(1),
tags=["example"],
) as dag:
# [START howto_operator_datasync_1_1]
datasync_task_1 = AWSDataSyncOperator(
aws_conn_id="aws_default", task_id="datasync_task_1", task_arn=TASK_ARN
)
# [END howto_operator_datasync_1_1]
with models.DAG(
"example_datasync_1_2",
start_date=days_ago(1),
schedule_interval=None, # Override to match your needs
) as dag:
# [START howto_operator_datasync_1_2]
datasync_task_2 = AWSDataSyncOperator(
aws_conn_id="aws_default",
task_id="datasync_task_2",
source_location_uri=SOURCE_LOCATION_URI,
destination_location_uri=DESTINATION_LOCATION_URI,
)
# [END howto_operator_datasync_1_2]

@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example dag for using `AWSDataSyncOperator` in a more complex manner.
- Try to get a TaskArn. If one exists, update it.
- If no tasks exist, try to create a new DataSync Task.
- If source and destination locations don't exist for the new task, create them first
- If many tasks exist, raise an Exception
- After getting or creating a DataSync Task, run it
This DAG relies on the following environment variables:
* SOURCE_LOCATION_URI - Source location URI, usually on premises SMB or NFS
* DESTINATION_LOCATION_URI - Destination location URI, usually S3
* CREATE_TASK_KWARGS - Passed to boto3.create_task(**kwargs)
* CREATE_SOURCE_LOCATION_KWARGS - Passed to boto3.create_location(**kwargs)
* CREATE_DESTINATION_LOCATION_KWARGS - Passed to boto3.create_location(**kwargs)
* UPDATE_TASK_KWARGS - Passed to boto3.update_task(**kwargs)
"""
import json
import re
from os import getenv
from airflow import models
from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator
from airflow.utils.dates import days_ago
# [START howto_operator_datasync_2_args]
SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/")
DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix")
default_create_task_kwargs = '{"Name": "Created by Airflow"}'
CREATE_TASK_KWARGS = json.loads(
getenv("CREATE_TASK_KWARGS", default_create_task_kwargs)
)
default_create_source_location_kwargs = "{}"
CREATE_SOURCE_LOCATION_KWARGS = json.loads(
getenv("CREATE_SOURCE_LOCATION_KWARGS", default_create_source_location_kwargs)
)
bucket_access_role_arn = (
"arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"
)
default_destination_location_kwargs = """\
{"S3BucketArn": "arn:aws:s3:::mybucket",
"S3Config": {"BucketAccessRoleArn":
"arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"}
}"""
CREATE_DESTINATION_LOCATION_KWARGS = json.loads(
getenv(
"CREATE_DESTINATION_LOCATION_KWARGS",
re.sub(r"[\s+]", "", default_destination_location_kwargs),
)
)
default_update_task_kwargs = '{"Name": "Updated by Airflow"}'
UPDATE_TASK_KWARGS = json.loads(
getenv("UPDATE_TASK_KWARGS", default_update_task_kwargs)
)
# [END howto_operator_datasync_2_args]
with models.DAG(
"example_datasync_2",
schedule_interval=None, # Override to match your needs
start_date=days_ago(1),
tags=["example"],
) as dag:
# [START howto_operator_datasync_2]
datasync_task = AWSDataSyncOperator(
aws_conn_id="aws_default",
task_id="datasync_task",
source_location_uri=SOURCE_LOCATION_URI,
destination_location_uri=DESTINATION_LOCATION_URI,
create_task_kwargs=CREATE_TASK_KWARGS,
create_source_location_kwargs=CREATE_SOURCE_LOCATION_KWARGS,
create_destination_location_kwargs=CREATE_DESTINATION_LOCATION_KWARGS,
update_task_kwargs=UPDATE_TASK_KWARGS,
delete_task_after_execution=True,
)
# [END howto_operator_datasync_2]

@ -0,0 +1,82 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example dag for ECSOperator.
The task "hello_world" runs `hello-world` task in `c` cluster.
It overrides the command in the `hello-world-container` container.
"""
import datetime
import os
from airflow import DAG
from airflow.providers.amazon.aws.operators.ecs import ECSOperator
DEFAULT_ARGS = {
"owner": "airflow",
"depends_on_past": False,
"email": ["airflow@example.com"],
"email_on_failure": False,
"email_on_retry": False,
}
dag = DAG(
dag_id="ecs_fargate_dag",
default_args=DEFAULT_ARGS,
default_view="graph",
schedule_interval=None,
start_date=datetime.datetime(2020, 1, 1),
tags=["example"],
)
# generate dag documentation
dag.doc_md = __doc__
# [START howto_operator_ecs]
hello_world = ECSOperator(
task_id="hello_world",
dag=dag,
aws_conn_id="aws_ecs",
cluster="c",
task_definition="hello-world",
launch_type="FARGATE",
overrides={
"containerOverrides": [
{
"name": "hello-world-container",
"command": ["echo", "hello", "world"],
},
],
},
network_configuration={
"awsvpcConfiguration": {
"securityGroups": [os.environ.get("SECURITY_GROUP_ID", "sg-123abc")],
"subnets": [os.environ.get("SUBNET_ID", "subnet-123456ab")],
},
},
tags={
"Customer": "X",
"Project": "Y",
"Application": "Z",
"Version": "0.0.1",
"Environment": "Development",
},
awslogs_group="/ecs/hello-world",
awslogs_stream_prefix="prefix_b/hello-world-container", # prefix with container name
)
# [END howto_operator_ecs]

@ -0,0 +1,96 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example dag for an AWS EMR pipeline with automatic steps.
"""
from datetime import timedelta
from airflow import DAG
from airflow.providers.amazon.aws.operators.emr_create_job_flow import (
EmrCreateJobFlowOperator,
)
from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor
from airflow.utils.dates import days_ago
DEFAULT_ARGS = {
"owner": "airflow",
"depends_on_past": False,
"email": ["airflow@example.com"],
"email_on_failure": False,
"email_on_retry": False,
}
# [START howto_operator_emr_automatic_steps_config]
SPARK_STEPS = [
{
"Name": "calculate_pi",
"ActionOnFailure": "CONTINUE",
"HadoopJarStep": {
"Jar": "command-runner.jar",
"Args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"],
},
}
]
JOB_FLOW_OVERRIDES = {
"Name": "PiCalc",
"ReleaseLabel": "emr-5.29.0",
"Instances": {
"InstanceGroups": [
{
"Name": "Master node",
"Market": "SPOT",
"InstanceRole": "MASTER",
"InstanceType": "m1.medium",
"InstanceCount": 1,
}
],
"KeepJobFlowAliveWhenNoSteps": False,
"TerminationProtected": False,
},
"Steps": SPARK_STEPS,
"JobFlowRole": "EMR_EC2_DefaultRole",
"ServiceRole": "EMR_DefaultRole",
}
# [END howto_operator_emr_automatic_steps_config]
with DAG(
dag_id="emr_job_flow_automatic_steps_dag",
default_args=DEFAULT_ARGS,
dagrun_timeout=timedelta(hours=2),
start_date=days_ago(2),
schedule_interval="0 3 * * *",
tags=["example"],
) as dag:
# [START howto_operator_emr_automatic_steps_tasks]
job_flow_creator = EmrCreateJobFlowOperator(
task_id="create_job_flow",
job_flow_overrides=JOB_FLOW_OVERRIDES,
aws_conn_id="aws_default",
emr_conn_id="emr_default",
)
job_sensor = EmrJobFlowSensor(
task_id="check_job_flow",
job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
aws_conn_id="aws_default",
)
job_flow_creator >> job_sensor
# [END howto_operator_emr_automatic_steps_tasks]

@ -0,0 +1,114 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example dag for an AWS EMR pipeline.
It starts by creating a cluster, adds steps/operations, checks the steps, and finally terminates
the cluster when they are finished.
"""
from datetime import timedelta
from airflow import DAG
from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator
from airflow.providers.amazon.aws.operators.emr_create_job_flow import (
EmrCreateJobFlowOperator,
)
from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import (
EmrTerminateJobFlowOperator,
)
from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor
from airflow.utils.dates import days_ago
DEFAULT_ARGS = {
"owner": "airflow",
"depends_on_past": False,
"email": ["airflow@example.com"],
"email_on_failure": False,
"email_on_retry": False,
}
SPARK_STEPS = [
{
"Name": "calculate_pi",
"ActionOnFailure": "CONTINUE",
"HadoopJarStep": {
"Jar": "command-runner.jar",
"Args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"],
},
}
]
JOB_FLOW_OVERRIDES = {
"Name": "PiCalc",
"ReleaseLabel": "emr-5.29.0",
"Instances": {
"InstanceGroups": [
{
"Name": "Master node",
"Market": "SPOT",
"InstanceRole": "MASTER",
"InstanceType": "m1.medium",
"InstanceCount": 1,
}
],
"KeepJobFlowAliveWhenNoSteps": True,
"TerminationProtected": False,
},
"JobFlowRole": "EMR_EC2_DefaultRole",
"ServiceRole": "EMR_DefaultRole",
}
with DAG(
dag_id="emr_job_flow_manual_steps_dag",
default_args=DEFAULT_ARGS,
dagrun_timeout=timedelta(hours=2),
start_date=days_ago(2),
schedule_interval="0 3 * * *",
tags=["example"],
) as dag:
# [START howto_operator_emr_manual_steps_tasks]
cluster_creator = EmrCreateJobFlowOperator(
task_id="create_job_flow",
job_flow_overrides=JOB_FLOW_OVERRIDES,
aws_conn_id="aws_default",
emr_conn_id="emr_default",
)
step_adder = EmrAddStepsOperator(
task_id="add_steps",
job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
aws_conn_id="aws_default",
steps=SPARK_STEPS,
)
step_checker = EmrStepSensor(
task_id="watch_step",
job_flow_id="{{ task_instance.xcom_pull('create_job_flow', key='return_value') }}",
step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}",
aws_conn_id="aws_default",
)
cluster_remover = EmrTerminateJobFlowOperator(
task_id="remove_cluster",
job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
aws_conn_id="aws_default",
)
cluster_creator >> step_adder >> step_checker >> cluster_remover
# [END howto_operator_emr_manual_steps_tasks]

@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from airflow import models
from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator
from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor
from airflow.providers.amazon.aws.transfers.glacier_to_gcs import GlacierToGCSOperator
from airflow.utils.dates import days_ago
VAULT_NAME = "airflow"
BUCKET_NAME = os.environ.get("GLACIER_GCS_BUCKET_NAME", "gs://glacier_bucket")
OBJECT_NAME = os.environ.get("GLACIER_OBJECT", "example-text.txt")
with models.DAG(
"example_glacier_to_gcs",
schedule_interval=None,
start_date=days_ago(1), # Override to match your needs
) as dag:
# [START howto_glacier_create_job_operator]
create_glacier_job = GlacierCreateJobOperator(
task_id="create_glacier_job",
aws_conn_id="aws_default",
vault_name=VAULT_NAME,
)
JOB_ID = '{{ task_instance.xcom_pull("create_glacier_job")["jobId"] }}'
# [END howto_glacier_create_job_operator]
# [START howto_glacier_job_operation_sensor]
wait_for_operation_complete = GlacierJobOperationSensor(
aws_conn_id="aws_default",
vault_name=VAULT_NAME,
job_id=JOB_ID,
task_id="wait_for_operation_complete",
)
# [END howto_glacier_job_operation_sensor]
# [START howto_glacier_transfer_data_to_gcs]
transfer_archive_to_gcs = GlacierToGCSOperator(
task_id="transfer_archive_to_gcs",
aws_conn_id="aws_default",
gcp_conn_id="google_cloud_default",
vault_name=VAULT_NAME,
bucket_name=BUCKET_NAME,
object_name=OBJECT_NAME,
gzip=False,
# Override to match your needs
# If the chunk size is bigger than the actual file size,
# the whole file will be downloaded
chunk_size=1024,
delegate_to=None,
google_impersonation_chain=None,
)
# [END howto_glacier_transfer_data_to_gcs]
create_glacier_job >> wait_for_operation_complete >> transfer_archive_to_gcs

@ -0,0 +1,141 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is a more advanced example dag for using `GoogleApiToS3Transfer` which uses xcom to pass data between
tasks to retrieve specific information about YouTube videos:
First it searches for up to 50 videos (due to pagination) in a given time range
(YOUTUBE_VIDEO_PUBLISHED_AFTER, YOUTUBE_VIDEO_PUBLISHED_BEFORE) on a YouTube channel (YOUTUBE_CHANNEL_ID),
saves the response in S3, and passes the YouTube video IDs to the next request, which then gets the
information (YOUTUBE_VIDEO_FIELDS) for the requested videos and saves it in S3 (S3_DESTINATION_KEY).
Further information:
YOUTUBE_VIDEO_PUBLISHED_AFTER and YOUTUBE_VIDEO_PUBLISHED_BEFORE need to be formatted as
"YYYY-MM-DDThh:mm:ss.sZ". See https://developers.google.com/youtube/v3/docs/search/list for more information.
YOUTUBE_VIDEO_PARTS depends on the fields you pass via YOUTUBE_VIDEO_FIELDS. See
https://developers.google.com/youtube/v3/docs/videos/list#parameters for more information.
YOUTUBE_CONN_ID is optional for public videos. It is only needed for authentication when the YouTube
channel you want to retrieve contains private videos.
"""
from os import getenv
from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import BranchPythonOperator
from airflow.providers.amazon.aws.transfers.google_api_to_s3 import (
GoogleApiToS3Operator,
)
from airflow.utils.dates import days_ago
# [START howto_operator_google_api_to_s3_transfer_advanced_env_variables]
YOUTUBE_CONN_ID = getenv("YOUTUBE_CONN_ID", "google_cloud_default")
YOUTUBE_CHANNEL_ID = getenv(
"YOUTUBE_CHANNEL_ID", "UCSXwxpWZQ7XZ1WL3wqevChA"
) # "Apache Airflow"
YOUTUBE_VIDEO_PUBLISHED_AFTER = getenv(
"YOUTUBE_VIDEO_PUBLISHED_AFTER", "2019-09-25T00:00:00Z"
)
YOUTUBE_VIDEO_PUBLISHED_BEFORE = getenv(
"YOUTUBE_VIDEO_PUBLISHED_BEFORE", "2019-10-18T00:00:00Z"
)
S3_DESTINATION_KEY = getenv("S3_DESTINATION_KEY", "s3://bucket/key.json")
YOUTUBE_VIDEO_PARTS = getenv("YOUTUBE_VIDEO_PARTS", "snippet")
YOUTUBE_VIDEO_FIELDS = getenv(
"YOUTUBE_VIDEO_FIELDS", "items(id,snippet(description,publishedAt,tags,title))"
)
# [END howto_operator_google_api_to_s3_transfer_advanced_env_variables]
# pylint: disable=unused-argument
# [START howto_operator_google_api_to_s3_transfer_advanced_task_1_2]
def _check_and_transform_video_ids(xcom_key, task_ids, task_instance, **kwargs):
video_ids_response = task_instance.xcom_pull(task_ids=task_ids, key=xcom_key)
video_ids = [item["id"]["videoId"] for item in video_ids_response["items"]]
if video_ids:
task_instance.xcom_push(key="video_ids", value={"id": ",".join(video_ids)})
return "video_data_to_s3"
return "no_video_ids"
# [END howto_operator_google_api_to_s3_transfer_advanced_task_1_2]
# pylint: enable=unused-argument
s3_directory, s3_file = S3_DESTINATION_KEY.rsplit("/", 1)
s3_file_name, _ = s3_file.rsplit(".", 1)
with DAG(
dag_id="example_google_api_to_s3_transfer_advanced",
schedule_interval=None,
start_date=days_ago(1),
tags=["example"],
) as dag:
# [START howto_operator_google_api_to_s3_transfer_advanced_task_1]
task_video_ids_to_s3 = GoogleApiToS3Operator(
gcp_conn_id=YOUTUBE_CONN_ID,
google_api_service_name="youtube",
google_api_service_version="v3",
google_api_endpoint_path="youtube.search.list",
google_api_endpoint_params={
"part": "snippet",
"channelId": YOUTUBE_CHANNEL_ID,
"maxResults": 50,
"publishedAfter": YOUTUBE_VIDEO_PUBLISHED_AFTER,
"publishedBefore": YOUTUBE_VIDEO_PUBLISHED_BEFORE,
"type": "video",
"fields": "items/id/videoId",
},
google_api_response_via_xcom="video_ids_response",
s3_destination_key=f"{s3_directory}/youtube_search_{s3_file_name}.json",
task_id="video_ids_to_s3",
)
# [END howto_operator_google_api_to_s3_transfer_advanced_task_1]
# [START howto_operator_google_api_to_s3_transfer_advanced_task_1_1]
task_check_and_transform_video_ids = BranchPythonOperator(
python_callable=_check_and_transform_video_ids,
op_args=[
task_video_ids_to_s3.google_api_response_via_xcom,
task_video_ids_to_s3.task_id,
],
task_id="check_and_transform_video_ids",
)
# [END howto_operator_google_api_to_s3_transfer_advanced_task_1_1]
# [START howto_operator_google_api_to_s3_transfer_advanced_task_2]
task_video_data_to_s3 = GoogleApiToS3Operator(
gcp_conn_id=YOUTUBE_CONN_ID,
google_api_service_name="youtube",
google_api_service_version="v3",
google_api_endpoint_path="youtube.videos.list",
google_api_endpoint_params={
"part": YOUTUBE_VIDEO_PARTS,
"maxResults": 50,
"fields": YOUTUBE_VIDEO_FIELDS,
},
google_api_endpoint_params_via_xcom="video_ids",
s3_destination_key=f"{s3_directory}/youtube_videos_{s3_file_name}.json",
task_id="video_data_to_s3",
)
# [END howto_operator_google_api_to_s3_transfer_advanced_task_2]
# [START howto_operator_google_api_to_s3_transfer_advanced_task_2_1]
task_no_video_ids = DummyOperator(task_id="no_video_ids")
# [END howto_operator_google_api_to_s3_transfer_advanced_task_2_1]
task_video_ids_to_s3 >> task_check_and_transform_video_ids >> [
task_video_data_to_s3,
task_no_video_ids,
]

@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is a basic example dag for using `GoogleApiToS3Transfer` to retrieve Google Sheets data.
You need to set the GOOGLE_SHEET_ID and GOOGLE_SHEET_RANGE environment variables to request the data.
"""
from os import getenv
from airflow import DAG
from airflow.providers.amazon.aws.transfers.google_api_to_s3 import (
GoogleApiToS3Operator,
)
from airflow.utils.dates import days_ago
# [START howto_operator_google_api_to_s3_transfer_basic_env_variables]
GOOGLE_SHEET_ID = getenv("GOOGLE_SHEET_ID")
GOOGLE_SHEET_RANGE = getenv("GOOGLE_SHEET_RANGE")
S3_DESTINATION_KEY = getenv("S3_DESTINATION_KEY", "s3://bucket/key.json")
# [END howto_operator_google_api_to_s3_transfer_basic_env_variables]
with DAG(
dag_id="example_google_api_to_s3_transfer_basic",
schedule_interval=None,
start_date=days_ago(1),
tags=["example"],
) as dag:
# [START howto_operator_google_api_to_s3_transfer_basic_task_1]
task_google_sheets_values_to_s3 = GoogleApiToS3Operator(
google_api_service_name="sheets",
google_api_service_version="v4",
google_api_endpoint_path="sheets.spreadsheets.values.get",
google_api_endpoint_params={
"spreadsheetId": GOOGLE_SHEET_ID,
"range": GOOGLE_SHEET_RANGE,
},
s3_destination_key=S3_DESTINATION_KEY,
task_id="google_sheets_values_to_s3",
dag=dag,
)
# [END howto_operator_google_api_to_s3_transfer_basic_task_1]

@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example dag for using `ImapAttachmentToS3Operator` to transfer an email attachment via the
IMAP protocol from a mail server to an S3 bucket.
"""
from os import getenv
from airflow import DAG
from airflow.providers.amazon.aws.transfers.imap_attachment_to_s3 import (
ImapAttachmentToS3Operator,
)
from airflow.utils.dates import days_ago
# [START howto_operator_imap_attachment_to_s3_env_variables]
IMAP_ATTACHMENT_NAME = getenv("IMAP_ATTACHMENT_NAME", "test.txt")
IMAP_MAIL_FOLDER = getenv("IMAP_MAIL_FOLDER", "INBOX")
IMAP_MAIL_FILTER = getenv("IMAP_MAIL_FILTER", "All")
S3_DESTINATION_KEY = getenv("S3_DESTINATION_KEY", "s3://bucket/key.json")
# [END howto_operator_imap_attachment_to_s3_env_variables]
with DAG(
dag_id="example_imap_attachment_to_s3",
start_date=days_ago(1),
schedule_interval=None,
tags=["example"],
) as dag:
# [START howto_operator_imap_attachment_to_s3_task_1]
task_transfer_imap_attachment_to_s3 = ImapAttachmentToS3Operator(
imap_attachment_name=IMAP_ATTACHMENT_NAME,
s3_key=S3_DESTINATION_KEY,
imap_mail_folder=IMAP_MAIL_FOLDER,
imap_mail_filter=IMAP_MAIL_FILTER,
task_id="transfer_imap_attachment_to_s3",
dag=dag,
)
# [END howto_operator_imap_attachment_to_s3_task_1]

@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from airflow.models.dag import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.s3_bucket import (
S3CreateBucketOperator,
S3DeleteBucketOperator,
)
from airflow.utils.dates import days_ago
BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-airflow-12345")
def upload_keys():
"""This is a python callback to add keys into the s3 bucket"""
# add keys to bucket
s3_hook = S3Hook()
for i in range(0, 3):
s3_hook.load_string(
string_data="input",
key=f"path/data{i}",
bucket_name=BUCKET_NAME,
)
with DAG(
dag_id="s3_bucket_dag",
schedule_interval=None,
start_date=days_ago(2),
max_active_runs=1,
tags=["example"],
) as dag:
# [START howto_operator_s3_bucket]
create_bucket = S3CreateBucketOperator(
task_id="s3_bucket_dag_create",
bucket_name=BUCKET_NAME,
region_name="us-east-1",
)
add_keys_to_bucket = PythonOperator(
task_id="s3_bucket_dag_add_keys_to_bucket", python_callable=upload_keys
)
delete_bucket = S3DeleteBucketOperator(
task_id="s3_bucket_dag_delete",
bucket_name=BUCKET_NAME,
force_delete=True,
)
# [END howto_operator_s3_bucket]
create_bucket >> add_keys_to_bucket >> delete_bucket

@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.s3_bucket import (
S3CreateBucketOperator,
S3DeleteBucketOperator,
)
from airflow.providers.amazon.aws.operators.s3_bucket_tagging import (
S3DeleteBucketTaggingOperator,
S3GetBucketTaggingOperator,
S3PutBucketTaggingOperator,
)
from airflow.utils.dates import days_ago
BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-s3-bucket-tagging")
TAG_KEY = os.environ.get("TAG_KEY", "test-s3-bucket-tagging-key")
TAG_VALUE = os.environ.get("TAG_VALUE", "test-s3-bucket-tagging-value")
with DAG(
dag_id="s3_bucket_tagging_dag",
schedule_interval=None,
start_date=days_ago(2),
max_active_runs=1,
tags=["example"],
) as dag:
create_bucket = S3CreateBucketOperator(
task_id="s3_bucket_tagging_dag_create",
bucket_name=BUCKET_NAME,
region_name="us-east-1",
)
delete_bucket = S3DeleteBucketOperator(
task_id="s3_bucket_tagging_dag_delete",
bucket_name=BUCKET_NAME,
force_delete=True,
)
# [START howto_operator_s3_bucket_tagging]
get_tagging = S3GetBucketTaggingOperator(
task_id="s3_bucket_tagging_dag_get_tagging", bucket_name=BUCKET_NAME
)
put_tagging = S3PutBucketTaggingOperator(
task_id="s3_bucket_tagging_dag_put_tagging",
bucket_name=BUCKET_NAME,
key=TAG_KEY,
value=TAG_VALUE,
)
delete_tagging = S3DeleteBucketTaggingOperator(
task_id="s3_bucket_tagging_dag_delete_tagging", bucket_name=BUCKET_NAME
)
# [END howto_operator_s3_bucket_tagging]
create_bucket >> put_tagging >> get_tagging >> delete_tagging >> delete_bucket

@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example dag for using `S3ToRedshiftOperator` to copy an S3 key into a Redshift table.
"""
from os import getenv
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
from airflow.providers.postgres.operators.postgres import PostgresOperator
from airflow.utils.dates import days_ago
# [START howto_operator_s3_to_redshift_env_variables]
S3_BUCKET = getenv("S3_BUCKET", "test-bucket")
S3_KEY = getenv("S3_KEY", "key")
REDSHIFT_TABLE = getenv("REDSHIFT_TABLE", "test_table")
# [END howto_operator_s3_to_redshift_env_variables]
def _add_sample_data_to_s3():
s3_hook = S3Hook()
s3_hook.load_string(
"0,Airflow", f"{S3_KEY}/{REDSHIFT_TABLE}", S3_BUCKET, replace=True
)
def _remove_sample_data_from_s3():
s3_hook = S3Hook()
if s3_hook.check_for_key(f"{S3_KEY}/{REDSHIFT_TABLE}", S3_BUCKET):
s3_hook.delete_objects(S3_BUCKET, f"{S3_KEY}/{REDSHIFT_TABLE}")
with DAG(
dag_id="example_s3_to_redshift",
start_date=days_ago(1),
schedule_interval=None,
tags=["example"],
) as dag:
setup__task_add_sample_data_to_s3 = PythonOperator(
python_callable=_add_sample_data_to_s3, task_id="setup__add_sample_data_to_s3"
)
setup__task_create_table = PostgresOperator(
sql=f"CREATE TABLE IF NOT EXISTS {REDSHIFT_TABLE}(Id int, Name varchar)",
postgres_conn_id="redshift_default",
task_id="setup__create_table",
)
# [START howto_operator_s3_to_redshift_task_1]
task_transfer_s3_to_redshift = S3ToRedshiftOperator(
s3_bucket=S3_BUCKET,
s3_key=S3_KEY,
schema="PUBLIC",
table=REDSHIFT_TABLE,
copy_options=["csv"],
task_id="transfer_s3_to_redshift",
)
# [END howto_operator_s3_to_redshift_task_1]
teardown__task_drop_table = PostgresOperator(
sql=f"DROP TABLE IF EXISTS {REDSHIFT_TABLE}",
postgres_conn_id="redshift_default",
task_id="teardown__drop_table",
)
teardown__task_remove_sample_data_from_s3 = PythonOperator(
python_callable=_remove_sample_data_from_s3,
task_id="teardown__remove_sample_data_from_s3",
)
[
setup__task_add_sample_data_to_s3,
setup__task_create_table,
] >> task_transfer_s3_to_redshift >> [
teardown__task_drop_table,
teardown__task_remove_sample_data_from_s3,
]

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@ -0,0 +1,263 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains AWS Athena hook"""
from time import sleep
from typing import Any, Dict, Optional
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from botocore.paginate import PageIterator
class AWSAthenaHook(AwsBaseHook):
"""
Interact with AWS Athena to run queries, poll their status and return query results
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
:param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena
:type sleep_time: int
"""
INTERMEDIATE_STATES = (
"QUEUED",
"RUNNING",
)
FAILURE_STATES = (
"FAILED",
"CANCELLED",
)
SUCCESS_STATES = ("SUCCEEDED",)
def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None:
super().__init__(client_type="athena", *args, **kwargs) # type: ignore
self.sleep_time = sleep_time
def run_query(
self,
query: str,
query_context: Dict[str, str],
result_configuration: Dict[str, Any],
client_request_token: Optional[str] = None,
workgroup: str = "primary",
) -> str:
"""
Run Presto query on athena with provided config and return submitted query_execution_id
:param query: Presto query to run
:type query: str
:param query_context: Context in which query need to be run
:type query_context: dict
:param result_configuration: Dict with path to store results in and config related to encryption
:type result_configuration: dict
:param client_request_token: Unique token created by user to avoid multiple executions of same query
:type client_request_token: str
:param workgroup: Athena workgroup name, when not specified, will be 'primary'
:type workgroup: str
:return: str
"""
params = {
"QueryString": query,
"QueryExecutionContext": query_context,
"ResultConfiguration": result_configuration,
"WorkGroup": workgroup,
}
if client_request_token:
params["ClientRequestToken"] = client_request_token
response = self.get_conn().start_query_execution(**params)
query_execution_id = response["QueryExecutionId"]
return query_execution_id
def check_query_status(self, query_execution_id: str) -> Optional[str]:
"""
Fetch the status of a submitted Athena query. Returns None or one of the valid query states.
:param query_execution_id: Id of submitted athena query
:type query_execution_id: str
:return: str
"""
response = self.get_conn().get_query_execution(
QueryExecutionId=query_execution_id
)
state = None
try:
state = response["QueryExecution"]["Status"]["State"]
except Exception as ex: # pylint: disable=broad-except
self.log.error("Exception while getting query state %s", ex)
finally:
# The error is being absorbed here and is being handled by the caller.
# The error is being absorbed to implement retries.
return state # pylint: disable=lost-exception
def get_state_change_reason(self, query_execution_id: str) -> Optional[str]:
"""
Fetch the reason for a state change (e.g. error message). Returns None or reason string.
:param query_execution_id: Id of submitted athena query
:type query_execution_id: str
:return: str
"""
response = self.get_conn().get_query_execution(
QueryExecutionId=query_execution_id
)
reason = None
try:
reason = response["QueryExecution"]["Status"]["StateChangeReason"]
except Exception as ex: # pylint: disable=broad-except
self.log.error("Exception while getting query state change reason: %s", ex)
finally:
# The error is being absorbed here and is being handled by the caller.
# The error is being absorbed to implement retries.
return reason # pylint: disable=lost-exception
def get_query_results(
self,
query_execution_id: str,
next_token_id: Optional[str] = None,
max_results: int = 1000,
) -> Optional[dict]:
"""
Fetch the results of a submitted Athena query. Returns None if the query is in an intermediate,
failed or cancelled state, otherwise a dict of the query output
:param query_execution_id: Id of submitted athena query
:type query_execution_id: str
:param next_token_id: The token that specifies where to start pagination.
:type next_token_id: str
:param max_results: The maximum number of results (rows) to return in this request.
:type max_results: int
:return: dict
"""
query_state = self.check_query_status(query_execution_id)
if query_state is None:
self.log.error("Invalid Query state")
return None
elif (
query_state in self.INTERMEDIATE_STATES
or query_state in self.FAILURE_STATES
):
self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
return None
result_params = {
"QueryExecutionId": query_execution_id,
"MaxResults": max_results,
}
if next_token_id:
result_params["NextToken"] = next_token_id
return self.get_conn().get_query_results(**result_params)
def get_query_results_paginator(
self,
query_execution_id: str,
max_items: Optional[int] = None,
page_size: Optional[int] = None,
starting_token: Optional[str] = None,
) -> Optional[PageIterator]:
"""
Fetch the results of a submitted Athena query. Returns None if the query is in an intermediate,
failed or cancelled state, otherwise a paginator to iterate through pages of results. If you
wish to get all results at once, call build_full_result() on the returned PageIterator
:param query_execution_id: Id of submitted athena query
:type query_execution_id: str
:param max_items: The total number of items to return.
:type max_items: int
:param page_size: The size of each page.
:type page_size: int
:param starting_token: A token to specify where to start paginating.
:type starting_token: str
:return: PageIterator
"""
query_state = self.check_query_status(query_execution_id)
if query_state is None:
self.log.error("Invalid Query state (null)")
return None
if (
query_state in self.INTERMEDIATE_STATES
or query_state in self.FAILURE_STATES
):
self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
return None
result_params = {
"QueryExecutionId": query_execution_id,
"PaginationConfig": {
"MaxItems": max_items,
"PageSize": page_size,
"StartingToken": starting_token,
},
}
paginator = self.get_conn().get_paginator("get_query_results")
return paginator.paginate(**result_params)
def poll_query_status(
self, query_execution_id: str, max_tries: Optional[int] = None
) -> Optional[str]:
"""
Poll the status of a submitted Athena query until it reaches a final state.
Returns one of the final states
:param query_execution_id: Id of submitted athena query
:type query_execution_id: str
:param max_tries: Number of times to poll for query state before function exits
:type max_tries: int
:return: str
"""
try_number = 1
final_query_state = (
None # Query state when query reaches final state or max_tries reached
)
while True:
query_state = self.check_query_status(query_execution_id)
if query_state is None:
self.log.info(
"Trial %s: Invalid query state. Retrying again", try_number
)
elif query_state in self.INTERMEDIATE_STATES:
self.log.info(
"Trial %s: Query is still in an intermediate state - %s",
try_number,
query_state,
)
else:
self.log.info(
"Trial %s: Query execution completed. Final state is %s}",
try_number,
query_state,
)
final_query_state = query_state
break
if max_tries and try_number >= max_tries: # Break loop if max_tries reached
final_query_state = query_state
break
try_number += 1
sleep(self.sleep_time)
return final_query_state
def stop_query(self, query_execution_id: str) -> Dict:
"""
Cancel the submitted athena query
:param query_execution_id: Id of submitted athena query
:type query_execution_id: str
:return: dict
"""
return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)
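
The hook above can also be driven directly, e.g. from a PythonOperator callable. A minimal sketch, assuming the hook is importable from its upstream location (airflow.providers.amazon.aws.hooks.athena), an existing ``aws_default`` connection, and placeholder database and output-location values:

from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook

def run_athena_query():
    hook = AWSAthenaHook(aws_conn_id="aws_default", sleep_time=10)
    # Submit the query and remember its execution id.
    query_execution_id = hook.run_query(
        query="SELECT 1",
        query_context={"Database": "my_database"},
        result_configuration={"OutputLocation": "s3://my-bucket/athena-results/"},
    )
    # Block until the query reaches a final state, giving up after 20 polls.
    final_state = hook.poll_query_status(query_execution_id, max_tries=20)
    if final_state not in AWSAthenaHook.SUCCESS_STATES:
        raise ValueError(f"Athena query ended in state {final_state}")
    # build_full_result() collects every page returned by the paginator.
    paginator = hook.get_query_results_paginator(query_execution_id)
    return paginator.build_full_result() if paginator else None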

@ -0,0 +1,29 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`."""
import warnings
# pylint: disable=unused-import
from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook # noqa
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.",
DeprecationWarning,
stacklevel=2,
)

@ -0,0 +1,596 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains Base AWS Hook.
.. seealso::
For more information on how to use this hook, take a look at the guide:
:ref:`howto/connection:AWSHook`
"""
import configparser
import datetime
import logging
from typing import Any, Dict, Optional, Tuple, Union
import boto3
import botocore
import botocore.session
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
try:
from functools import cached_property
except ImportError:
from cached_property import cached_property
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection
from airflow.utils.log.logging_mixin import LoggingMixin
from dateutil.tz import tzlocal
class _SessionFactory(LoggingMixin):
def __init__(
self, conn: Connection, region_name: Optional[str], config: Config
) -> None:
super().__init__()
self.conn = conn
self.region_name = region_name
self.config = config
self.extra_config = self.conn.extra_dejson
def create_session(self) -> boto3.session.Session:
"""Create AWS session."""
session_kwargs = {}
if "session_kwargs" in self.extra_config:
self.log.info(
"Retrieving session_kwargs from Connection.extra_config['session_kwargs']: %s",
self.extra_config["session_kwargs"],
)
session_kwargs = self.extra_config["session_kwargs"]
session = self._create_basic_session(session_kwargs=session_kwargs)
role_arn = self._read_role_arn_from_extra_config()
# If role_arn was specified then STS + assume_role
if role_arn is None:
return session
return self._impersonate_to_role(
role_arn=role_arn, session=session, session_kwargs=session_kwargs
)
def _create_basic_session(
self, session_kwargs: Dict[str, Any]
) -> boto3.session.Session:
(
aws_access_key_id,
aws_secret_access_key,
) = self._read_credentials_from_connection()
aws_session_token = self.extra_config.get("aws_session_token")
region_name = self.region_name
if self.region_name is None and "region_name" in self.extra_config:
self.log.info(
"Retrieving region_name from Connection.extra_config['region_name']"
)
region_name = self.extra_config["region_name"]
self.log.info(
"Creating session with aws_access_key_id=%s region_name=%s",
aws_access_key_id,
region_name,
)
return boto3.session.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=region_name,
aws_session_token=aws_session_token,
**session_kwargs,
)
def _impersonate_to_role(
self,
role_arn: str,
session: boto3.session.Session,
session_kwargs: Dict[str, Any],
) -> boto3.session.Session:
assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
assume_role_method = self.extra_config.get("assume_role_method")
self.log.info("assume_role_method=%s", assume_role_method)
if not assume_role_method or assume_role_method == "assume_role":
sts_client = session.client("sts", config=self.config)
sts_response = self._assume_role(
sts_client=sts_client,
role_arn=role_arn,
assume_role_kwargs=assume_role_kwargs,
)
elif assume_role_method == "assume_role_with_saml":
sts_client = session.client("sts", config=self.config)
sts_response = self._assume_role_with_saml(
sts_client=sts_client,
role_arn=role_arn,
assume_role_kwargs=assume_role_kwargs,
)
elif assume_role_method == "assume_role_with_web_identity":
botocore_session = self._assume_role_with_web_identity(
role_arn=role_arn,
assume_role_kwargs=assume_role_kwargs,
base_session=session._session, # pylint: disable=protected-access
)
return boto3.session.Session(
region_name=session.region_name,
botocore_session=botocore_session,
**session_kwargs,
)
else:
raise NotImplementedError(
f"assume_role_method={assume_role_method} in Connection {self.conn.conn_id} Extra."
'Currently "assume_role" or "assume_role_with_saml" are supported.'
'(Exclude this setting will default to "assume_role").'
)
# Use credentials retrieved from STS
credentials = sts_response["Credentials"]
aws_access_key_id = credentials["AccessKeyId"]
aws_secret_access_key = credentials["SecretAccessKey"]
aws_session_token = credentials["SessionToken"]
self.log.info(
"Creating session with aws_access_key_id=%s region_name=%s",
aws_access_key_id,
session.region_name,
)
return boto3.session.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=session.region_name,
aws_session_token=aws_session_token,
**session_kwargs,
)
def _read_role_arn_from_extra_config(self) -> Optional[str]:
aws_account_id = self.extra_config.get("aws_account_id")
aws_iam_role = self.extra_config.get("aws_iam_role")
role_arn = self.extra_config.get("role_arn")
if role_arn is None and aws_account_id is not None and aws_iam_role is not None:
self.log.info("Constructing role_arn from aws_account_id and aws_iam_role")
role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}"
self.log.info("role_arn is %s", role_arn)
return role_arn
def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]:
aws_access_key_id = None
aws_secret_access_key = None
if self.conn.login:
aws_access_key_id = self.conn.login
aws_secret_access_key = self.conn.password
self.log.info("Credentials retrieved from login")
elif (
"aws_access_key_id" in self.extra_config
and "aws_secret_access_key" in self.extra_config
):
aws_access_key_id = self.extra_config["aws_access_key_id"]
aws_secret_access_key = self.extra_config["aws_secret_access_key"]
self.log.info("Credentials retrieved from extra_config")
elif "s3_config_file" in self.extra_config:
aws_access_key_id, aws_secret_access_key = _parse_s3_config(
self.extra_config["s3_config_file"],
self.extra_config.get("s3_config_format"),
self.extra_config.get("profile"),
)
self.log.info("Credentials retrieved from extra_config['s3_config_file']")
else:
self.log.info("No credentials retrieved from Connection")
return aws_access_key_id, aws_secret_access_key
def _assume_role(
self,
sts_client: boto3.client,
role_arn: str,
assume_role_kwargs: Dict[str, Any],
) -> Dict:
if "external_id" in self.extra_config: # Backwards compatibility
assume_role_kwargs["ExternalId"] = self.extra_config.get("external_id")
role_session_name = f"Airflow_{self.conn.conn_id}"
self.log.info(
"Doing sts_client.assume_role to role_arn=%s (role_session_name=%s)",
role_arn,
role_session_name,
)
return sts_client.assume_role(
RoleArn=role_arn, RoleSessionName=role_session_name, **assume_role_kwargs
)
def _assume_role_with_saml(
self,
sts_client: boto3.client,
role_arn: str,
assume_role_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
saml_config = self.extra_config["assume_role_with_saml"]
principal_arn = saml_config["principal_arn"]
idp_auth_method = saml_config["idp_auth_method"]
if idp_auth_method == "http_spegno_auth":
saml_assertion = self._fetch_saml_assertion_using_http_spegno_auth(
saml_config
)
else:
raise NotImplementedError(
f"idp_auth_method={idp_auth_method} in Connection {self.conn.conn_id} Extra."
'Currently only "http_spegno_auth" is supported, and must be specified.'
)
self.log.info("Doing sts_client.assume_role_with_saml to role_arn=%s", role_arn)
return sts_client.assume_role_with_saml(
RoleArn=role_arn,
PrincipalArn=principal_arn,
SAMLAssertion=saml_assertion,
**assume_role_kwargs,
)
def _fetch_saml_assertion_using_http_spegno_auth(
self, saml_config: Dict[str, Any]
) -> str:
import requests
# requests_gssapi will need paramiko > 2.6 since you'll need
# 'gssapi' not 'python-gssapi' from PyPi.
# https://github.com/paramiko/paramiko/pull/1311
import requests_gssapi
from lxml import etree
idp_url = saml_config["idp_url"]
self.log.info("idp_url= %s", idp_url)
idp_request_kwargs = saml_config["idp_request_kwargs"]
auth = requests_gssapi.HTTPSPNEGOAuth()
if "mutual_authentication" in saml_config:
mutual_auth = saml_config["mutual_authentication"]
if mutual_auth == "REQUIRED":
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.REQUIRED)
elif mutual_auth == "OPTIONAL":
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.OPTIONAL)
elif mutual_auth == "DISABLED":
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.DISABLED)
else:
raise NotImplementedError(
f"mutual_authentication={mutual_auth} in Connection {self.conn.conn_id} Extra."
'Currently "REQUIRED", "OPTIONAL" and "DISABLED" are supported.'
"(Exclude this setting will default to HTTPSPNEGOAuth() )."
)
# Query the IDP
idp_response = requests.get(idp_url, auth=auth, **idp_request_kwargs)
idp_response.raise_for_status()
# Assist with debugging. Note: contains sensitive info!
xpath = saml_config["saml_response_xpath"]
log_idp_response = (
"log_idp_response" in saml_config and saml_config["log_idp_response"]
)
if log_idp_response:
self.log.warning(
"The IDP response contains sensitive information, but log_idp_response is ON (%s).",
log_idp_response,
)
self.log.info("idp_response.content= %s", idp_response.content)
self.log.info("xpath= %s", xpath)
# Extract SAML Assertion from the returned HTML / XML
xml = etree.fromstring(idp_response.content)
saml_assertion = xml.xpath(xpath)
if isinstance(saml_assertion, list):
if len(saml_assertion) == 1:
saml_assertion = saml_assertion[0]
if not saml_assertion:
raise ValueError("Invalid SAML Assertion")
return saml_assertion
def _assume_role_with_web_identity(
self, role_arn, assume_role_kwargs, base_session
):
base_session = base_session or botocore.session.get_session()
client_creator = base_session.create_client
federation = self.extra_config.get("assume_role_with_web_identity_federation")
if federation == "google":
web_identity_token_loader = self._get_google_identity_token_loader()
else:
raise AirflowException(
f'Unsupported federation: {federation}. Currently only "google" is supported.'
)
fetcher = botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher(
client_creator=client_creator,
web_identity_token_loader=web_identity_token_loader,
role_arn=role_arn,
extra_args=assume_role_kwargs or {},
)
aws_creds = botocore.credentials.DeferredRefreshableCredentials(
method="assume-role-with-web-identity",
refresh_using=fetcher.fetch_credentials,
time_fetcher=lambda: datetime.datetime.now(tz=tzlocal()),
)
botocore_session = botocore.session.Session()
botocore_session._credentials = aws_creds # pylint: disable=protected-access
return botocore_session
def _get_google_identity_token_loader(self):
from airflow.providers.google.common.utils.id_token_credentials import (
get_default_id_token_credentials,
)
from google.auth.transport import requests as requests_transport
audience = self.extra_config.get(
"assume_role_with_web_identity_federation_audience"
)
google_id_token_credentials = get_default_id_token_credentials(
target_audience=audience
)
def web_identity_token_loader():
if not google_id_token_credentials.valid:
request_adapter = requests_transport.Request()
google_id_token_credentials.refresh(request=request_adapter)
return google_id_token_credentials.token
return web_identity_token_loader
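# A minimal sketch of the Connection "Extra" JSON that _SessionFactory reads when it
# assumes a role; every value below is a placeholder, not a credential from this repository.
#
#     {
#         "region_name": "us-east-1",
#         "role_arn": "arn:aws:iam::123456789012:role/example-role",
#         "assume_role_method": "assume_role",
#         "assume_role_kwargs": {"DurationSeconds": 3600},
#         "session_kwargs": {"profile_name": "example-profile"}
#     }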
class AwsBaseHook(BaseHook):
"""
Interact with AWS.
This class is a thin wrapper around the boto3 python library.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates.
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:type verify: Union[bool, str, None]
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:type region_name: Optional[str]
:param client_type: boto3.client client_type. Eg 's3', 'emr' etc
:type client_type: Optional[str]
:param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc
:type resource_type: Optional[str]
:param config: Configuration for botocore client.
(https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html)
:type config: Optional[botocore.client.Config]
"""
conn_name_attr = "aws_conn_id"
default_conn_name = "aws_default"
conn_type = "aws"
hook_name = "Amazon Web Services"
def __init__(
self,
aws_conn_id: Optional[str] = default_conn_name,
verify: Union[bool, str, None] = None,
region_name: Optional[str] = None,
client_type: Optional[str] = None,
resource_type: Optional[str] = None,
config: Optional[Config] = None,
) -> None:
super().__init__()
self.aws_conn_id = aws_conn_id
self.verify = verify
self.client_type = client_type
self.resource_type = resource_type
self.region_name = region_name
self.config = config
if not (self.client_type or self.resource_type):
raise AirflowException(
"Either client_type or resource_type must be provided."
)
def _get_credentials(
self, region_name: Optional[str]
) -> Tuple[boto3.session.Session, Optional[str]]:
if not self.aws_conn_id:
session = boto3.session.Session(region_name=region_name)
return session, None
self.log.info("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
try:
# Fetch the Airflow connection object
connection_object = self.get_connection(self.aws_conn_id)
extra_config = connection_object.extra_dejson
endpoint_url = extra_config.get("host")
# https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html#botocore.config.Config
if "config_kwargs" in extra_config:
self.log.info(
"Retrieving config_kwargs from Connection.extra_config['config_kwargs']: %s",
extra_config["config_kwargs"],
)
self.config = Config(**extra_config["config_kwargs"])
session = _SessionFactory(
conn=connection_object, region_name=region_name, config=self.config
).create_session()
return session, endpoint_url
except AirflowException:
self.log.warning("Unable to use Airflow Connection for credentials.")
self.log.info("Fallback on boto3 credential strategy")
# http://boto3.readthedocs.io/en/latest/guide/configuration.html
self.log.info(
"Creating session using boto3 credential strategy region_name=%s",
region_name,
)
session = boto3.session.Session(region_name=region_name)
return session, None
def get_client_type(
self,
client_type: str,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)
# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
config = self.config
return session.client(
client_type, endpoint_url=endpoint_url, config=config, verify=self.verify
)
def get_resource_type(
self,
resource_type: str,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)
# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
config = self.config
return session.resource(
resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify
)
@cached_property
def conn(self) -> Union[boto3.client, boto3.resource]:
"""
Get the underlying boto3 client/resource (cached)
:return: boto3.client or boto3.resource
:rtype: Union[boto3.client, boto3.resource]
"""
if self.client_type:
return self.get_client_type(self.client_type, region_name=self.region_name)
elif self.resource_type:
return self.get_resource_type(
self.resource_type, region_name=self.region_name
)
else:
# Rare possibility - subclasses have not specified a client_type or resource_type
raise NotImplementedError("Could not get boto3 connection!")
def get_conn(self) -> Union[boto3.client, boto3.resource]:
"""
Get the underlying boto3 client/resource (cached)
Implemented so that caching works as intended. It exists for compatibility
with subclasses that rely on a super().get_conn() method.
:return: boto3.client or boto3.resource
:rtype: Union[boto3.client, boto3.resource]
"""
# Compat shim
return self.conn
def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session:
"""Get the underlying boto3.session."""
session, _ = self._get_credentials(region_name)
return session
def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials:
"""
Get the underlying `botocore.Credentials` object.
This contains the following authentication attributes: access_key, secret_key and token.
"""
session, _ = self._get_credentials(region_name)
# Credentials are refreshable, so accessing your access key and
# secret key separately can lead to a race condition.
# See https://stackoverflow.com/a/36291428/8283373
return session.get_credentials().get_frozen_credentials()
def expand_role(self, role: str) -> str:
"""
If the IAM role is a role name, get the Amazon Resource Name (ARN) for the role.
If IAM role is already an IAM role ARN, no change is made.
:param role: IAM role name or ARN
:return: IAM role ARN
"""
if "/" in role:
return role
else:
return self.get_client_type("iam").get_role(RoleName=role)["Role"]["Arn"]
def _parse_s3_config(
config_file_name: str,
config_format: Optional[str] = "boto",
profile: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
"""
Parses a config file for s3 credentials. Can currently
parse boto, s3cmd.conf and AWS SDK config formats
:param config_file_name: path to the config file
:type config_file_name: str
:param config_format: config type. One of "boto", "s3cmd" or "aws".
Defaults to "boto"
:type config_format: str
:param profile: profile name in AWS type config file
:type profile: str
"""
config = configparser.ConfigParser()
if config.read(config_file_name): # pragma: no cover
sections = config.sections()
else:
raise AirflowException(f"Couldn't read {config_file_name}")
# Setting option names depending on file format
if config_format is None:
config_format = "boto"
conf_format = config_format.lower()
if conf_format == "boto": # pragma: no cover
if profile is not None and "profile " + profile in sections:
cred_section = "profile " + profile
else:
cred_section = "Credentials"
elif conf_format == "aws" and profile is not None:
cred_section = profile
else:
cred_section = "default"
# Option names
if conf_format in ("boto", "aws"): # pragma: no cover
key_id_option = "aws_access_key_id"
secret_key_option = "aws_secret_access_key"
# security_token_option = 'aws_security_token'
else:
key_id_option = "access_key"
secret_key_option = "secret_key"
# Actual Parsing
if cred_section not in sections:
raise AirflowException("This config file format is not recognized")
else:
try:
access_key = config.get(cred_section, key_id_option)
secret_key = config.get(cred_section, secret_key_option)
except Exception:
logging.warning("Option Error in parsing s3 config file")
raise
return access_key, secret_key
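def _example_base_hook_usage() -> None:
    # A minimal usage sketch, assuming an Airflow connection named "aws_default"
    # exists; the bucket name below is a placeholder.
    hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")
    s3_client = hook.get_conn()  # cached boto3 S3 client built from the connection
    s3_client.list_objects_v2(Bucket="example-bucket", MaxKeys=10)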

@ -0,0 +1,556 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
A client for AWS batch services
.. seealso::
- http://boto3.readthedocs.io/en/latest/guide/configuration.html
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html
- https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html
"""
from random import uniform
from time import sleep
from typing import Dict, List, Optional, Union
import botocore.client
import botocore.exceptions
import botocore.waiter
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.typing_compat import Protocol, runtime_checkable
# Add exceptions to pylint for the boto3 protocol only; ideally the boto3 library could
# provide protocols for all their dynamically generated classes (try to migrate this to a PR on botocore).
# Note that the use of invalid-name parameters should be restricted to the boto3 mappings only;
# all the Airflow wrappers of boto3 clients should not adopt invalid-names to match boto3.
# pylint: disable=invalid-name, unused-argument
@runtime_checkable
class AwsBatchProtocol(Protocol):
"""
A structured Protocol for ``boto3.client('batch') -> botocore.client.Batch``.
This is used for type hints on :py:meth:`.AwsBatchClient.client`; it covers
only the subset of client methods required.
.. seealso::
- https://mypy.readthedocs.io/en/latest/protocols.html
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html
"""
def describe_jobs(self, jobs: List[str]) -> Dict:
"""
Get job descriptions from AWS batch
:param jobs: a list of JobId to describe
:type jobs: List[str]
:return: an API response to describe jobs
:rtype: Dict
"""
...
def get_waiter(self, waiterName: str) -> botocore.waiter.Waiter:
"""
Get an AWS Batch service waiter
:param waiterName: The name of the waiter. The name should match
the name (including the casing) of the key name in the waiter
model file (typically this is CamelCasing).
:type waiterName: str
:return: a waiter object for the named AWS batch service
:rtype: botocore.waiter.Waiter
.. note::
AWS batch might not have any waiters (until botocore PR-1307 is released).
.. code-block:: python
import boto3
boto3.client('batch').waiter_names == []
.. seealso::
- https://boto3.amazonaws.com/v1/documentation/api/latest/guide/clients.html#waiters
- https://github.com/boto/botocore/pull/1307
"""
...
def submit_job(
self,
jobName: str,
jobQueue: str,
jobDefinition: str,
arrayProperties: Dict,
parameters: Dict,
containerOverrides: Dict,
tags: Dict,
) -> Dict:
"""
Submit a batch job
:param jobName: the name for the AWS batch job
:type jobName: str
:param jobQueue: the queue name on AWS Batch
:type jobQueue: str
:param jobDefinition: the job definition name on AWS Batch
:type jobDefinition: str
:param arrayProperties: the same parameter that boto3 will receive
:type arrayProperties: Dict
:param parameters: the same parameter that boto3 will receive
:type parameters: Dict
:param containerOverrides: the same parameter that boto3 will receive
:type containerOverrides: Dict
:param tags: the same parameter that boto3 will receive
:type tags: Dict
:return: an API response
:rtype: Dict
"""
...
def terminate_job(self, jobId: str, reason: str) -> Dict:
"""
Terminate a batch job
:param jobId: a job ID to terminate
:type jobId: str
:param reason: a reason to terminate job ID
:type reason: str
:return: an API response
:rtype: Dict
"""
...
# Note that the use of invalid-name parameters should be restricted to the boto3 mappings only;
# all the Airflow wrappers of boto3 clients should not adopt invalid-names to match boto3.
# pylint: enable=invalid-name, unused-argument
class AwsBatchClientHook(AwsBaseHook):
"""
A client for AWS batch services.
:param max_retries: exponential back-off retries, 4200 = 48 hours;
polling is only used when waiters is None
:type max_retries: Optional[int]
:param status_retries: number of HTTP retries to get job status, 10;
polling is only used when waiters is None
:type status_retries: Optional[int]
.. note::
Several methods use a default random delay to check or poll for job status, i.e.
``random.uniform(DEFAULT_DELAY_MIN, DEFAULT_DELAY_MAX)``
Using a random interval helps to avoid AWS API throttle limits
when many concurrent tasks request job-descriptions.
To modify the global defaults for the range of jitter allowed when a
random delay is used to check batch job status, modify these defaults, e.g.:
.. code-block::
AwsBatchClient.DEFAULT_DELAY_MIN = 0
AwsBatchClient.DEFAULT_DELAY_MAX = 5
When explicit delay values are used, a 1 second random jitter is applied to the
delay (e.g. a delay of 0 sec will be a ``random.uniform(0, 1)`` delay). It is
generally recommended that random jitter is added to API requests. A
convenience method is provided for this, e.g. to get a random delay of
10 sec +/- 5 sec: ``delay = AwsBatchClient.add_jitter(10, width=5, minima=0)``
.. seealso::
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
- https://docs.aws.amazon.com/general/latest/gr/api-retries.html
- https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
"""
MAX_RETRIES = 4200
STATUS_RETRIES = 10
# delays are in seconds
DEFAULT_DELAY_MIN = 1
DEFAULT_DELAY_MAX = 10
def __init__(
self,
*args,
max_retries: Optional[int] = None,
status_retries: Optional[int] = None,
**kwargs,
) -> None:
# https://github.com/python/mypy/issues/6799 hence type: ignore
super().__init__(client_type="batch", *args, **kwargs) # type: ignore
self.max_retries = max_retries or self.MAX_RETRIES
self.status_retries = status_retries or self.STATUS_RETRIES
@property
def client(
self,
) -> Union[AwsBatchProtocol, botocore.client.BaseClient]: # noqa: D402
"""
An AWS API client for batch services, like ``boto3.client('batch')``
:return: a boto3 'batch' client for the ``.region_name``
:rtype: Union[AwsBatchProtocol, botocore.client.BaseClient]
"""
return self.conn
def terminate_job(self, job_id: str, reason: str) -> Dict:
"""
Terminate a batch job
:param job_id: a job ID to terminate
:type job_id: str
:param reason: a reason to terminate job ID
:type reason: str
:return: an API response
:rtype: Dict
"""
response = self.get_conn().terminate_job(jobId=job_id, reason=reason)
self.log.info(response)
return response
def check_job_success(self, job_id: str) -> bool:
"""
Check the final status of the batch job; return True if the job
'SUCCEEDED', else raise an AirflowException
:param job_id: a batch job ID
:type job_id: str
:rtype: bool
:raises: AirflowException
"""
job = self.get_job_description(job_id)
job_status = job.get("status")
if job_status == "SUCCEEDED":
self.log.info("AWS batch job (%s) succeeded: %s", job_id, job)
return True
if job_status == "FAILED":
raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}")
if job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]:
raise AirflowException(f"AWS Batch job ({job_id}) is not complete: {job}")
raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}")
def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None:
"""
Wait for batch job to complete
:param job_id: a batch job ID
:type job_id: str
:param delay: a delay before polling for job status
:type delay: Optional[Union[int, float]]
:raises: AirflowException
"""
self.delay(delay)
self.poll_for_job_running(job_id, delay)
self.poll_for_job_complete(job_id, delay)
self.log.info("AWS Batch job (%s) has completed", job_id)
def poll_for_job_running(
self, job_id: str, delay: Union[int, float, None] = None
) -> None:
"""
Poll for job running. The statuses that indicate a job is running or
already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'.
So the status options that this will wait for are the transitions from:
'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'|'SUCCEEDED'|'FAILED'
The completed status options are included for cases where the status
changes too quickly for polling to detect a RUNNING status that moves
quickly from STARTING to RUNNING to completed (often a failure).
:param job_id: a batch job ID
:type job_id: str
:param delay: a delay before polling for job status
:type delay: Optional[Union[int, float]]
:raises: AirflowException
"""
self.delay(delay)
running_status = ["RUNNING", "SUCCEEDED", "FAILED"]
self.poll_job_status(job_id, running_status)
def poll_for_job_complete(
self, job_id: str, delay: Union[int, float, None] = None
) -> None:
"""
Poll for job completion. The statuses that indicate job completion
are: 'SUCCEEDED'|'FAILED'.
So the status options that this will wait for are the transitions from:
'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'
:param job_id: a batch job ID
:type job_id: str
:param delay: a delay before polling for job status
:type delay: Optional[Union[int, float]]
:raises: AirflowException
"""
self.delay(delay)
complete_status = ["SUCCEEDED", "FAILED"]
self.poll_job_status(job_id, complete_status)
def poll_job_status(self, job_id: str, match_status: List[str]) -> bool:
"""
Poll for job status using an exponential back-off strategy (with max_retries).
:param job_id: a batch job ID
:type job_id: str
:param match_status: a list of job statuses to match; the batch job statuses are:
'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
:type match_status: List[str]
:rtype: bool
:raises: AirflowException
"""
retries = 0
while True:
job = self.get_job_description(job_id)
job_status = job.get("status")
self.log.info(
"AWS Batch job (%s) check status (%s) in %s",
job_id,
job_status,
match_status,
)
if job_status in match_status:
return True
if retries >= self.max_retries:
raise AirflowException(
f"AWS Batch job ({job_id}) status checks exceed max_retries"
)
retries += 1
pause = self.exponential_delay(retries)
self.log.info(
"AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds",
job_id,
retries,
self.max_retries,
pause,
)
self.delay(pause)
def get_job_description(self, job_id: str) -> Dict:
"""
Get job description (using status_retries).
:param job_id: a batch job ID
:type job_id: str
:return: an API response for describe jobs
:rtype: Dict
:raises: AirflowException
"""
retries = 0
while True:
try:
response = self.get_conn().describe_jobs(jobs=[job_id])
return self.parse_job_description(job_id, response)
except botocore.exceptions.ClientError as err:
error = err.response.get("Error", {})
if error.get("Code") == "TooManyRequestsException":
pass # allow it to retry, if possible
else:
raise AirflowException(
f"AWS Batch job ({job_id}) description error: {err}"
)
retries += 1
if retries >= self.status_retries:
raise AirflowException(
"AWS Batch job ({}) description error: exceeded "
"status_retries ({})".format(job_id, self.status_retries)
)
pause = self.exponential_delay(retries)
self.log.info(
"AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds",
job_id,
retries,
self.status_retries,
pause,
)
self.delay(pause)
@staticmethod
def parse_job_description(job_id: str, response: Dict) -> Dict:
"""
Parse job description to extract description for job_id
:param job_id: a batch job ID
:type job_id: str
:param response: an API response for describe jobs
:type response: Dict
:return: an API response to describe job_id
:rtype: Dict
:raises: AirflowException
"""
jobs = response.get("jobs", [])
matching_jobs = [job for job in jobs if job.get("jobId") == job_id]
if len(matching_jobs) != 1:
raise AirflowException(
f"AWS Batch job ({job_id}) description error: response: {response}"
)
return matching_jobs[0]
@staticmethod
def add_jitter(
delay: Union[int, float],
width: Union[int, float] = 1,
minima: Union[int, float] = 0,
) -> float:
"""
Use delay +/- width for random jitter
Adding jitter to status polling can help to avoid
AWS batch API limits for monitoring batch jobs with
a high concurrency in Airflow tasks.
:param delay: number of seconds to pause;
delay is assumed to be a positive number
:type delay: Union[int, float]
:param width: delay +/- width for random jitter;
width is assumed to be a positive number
:type width: Union[int, float]
:param minima: minimum delay allowed;
minima is assumed to be a non-negative number
:type minima: Union[int, float]
:return: uniform(delay - width, delay + width) jitter
and it is a non-negative number
:rtype: float
"""
delay = abs(delay)
width = abs(width)
minima = abs(minima)
lower = max(minima, delay - width)
upper = delay + width
return uniform(lower, upper)
@staticmethod
def delay(delay: Union[int, float, None] = None) -> None:
"""
Pause execution for ``delay`` seconds.
:param delay: a delay to pause execution using ``time.sleep(delay)``;
a small 1 second jitter is applied to the delay.
:type delay: Optional[Union[int, float]]
.. note::
This method uses a default random delay, i.e.
``random.uniform(DEFAULT_DELAY_MIN, DEFAULT_DELAY_MAX)``;
using a random interval helps to avoid AWS API throttle limits
when many concurrent tasks request job-descriptions.
"""
if delay is None:
delay = uniform(
AwsBatchClientHook.DEFAULT_DELAY_MIN,
AwsBatchClientHook.DEFAULT_DELAY_MAX,
)
else:
delay = AwsBatchClientHook.add_jitter(delay)
sleep(delay)
@staticmethod
def exponential_delay(tries: int) -> float:
"""
An exponential back-off delay, with random jitter. There is a maximum
interval of 10 minutes (with random jitter between 3 and 10 minutes).
This is used in the :py:meth:`.poll_for_job_status` method.
:param tries: Number of tries
:type tries: int
:rtype: float
Examples of behavior:
.. code-block:: python
def exp(tries):
max_interval = 600.0 # 10 minutes in seconds
delay = 1 + pow(tries * 0.6, 2)
delay = min(max_interval, delay)
print(delay / 3, delay)
for tries in range(10):
exp(tries)
# 0.33 1.0
# 0.45 1.35
# 0.81 2.44
# 1.41 4.23
# 2.25 6.76
# 3.33 10.00
# 4.65 13.95
# 6.21 18.64
# 8.01 24.04
# 10.05 30.15
.. seealso::
- https://docs.aws.amazon.com/general/latest/gr/api-retries.html
- https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
"""
max_interval = 600.0 # results in 3 to 10 minute delay
delay = 1 + pow(tries * 0.6, 2)
delay = min(max_interval, delay)
return uniform(delay / 3, delay)
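def _example_batch_client_usage() -> None:
    # A minimal usage sketch, assuming an "aws_default" connection and an existing
    # AWS Batch job queue and job definition; all names below are placeholders.
    hook = AwsBatchClientHook(aws_conn_id="aws_default", max_retries=100)
    response = hook.client.submit_job(
        jobName="example-job",
        jobQueue="example-queue",
        jobDefinition="example-definition",
    )
    job_id = response["jobId"]
    hook.wait_for_job(job_id)       # poll RUNNING, then SUCCEEDED/FAILED, with jitter
    hook.check_job_success(job_id)  # raises AirflowException unless the job SUCCEEDED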

@ -0,0 +1,105 @@
{
"version": 2,
"waiters": {
"JobExists": {
"delay": 2,
"operation": "DescribeJobs",
"maxAttempts": 100,
"acceptors": [
{
"argument": "jobs[].status",
"expected": "SUBMITTED",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "jobs[].status",
"expected": "PENDING",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "jobs[].status",
"expected": "RUNNABLE",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "jobs[].status",
"expected": "STARTING",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "jobs[].status",
"expected": "RUNNING",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "jobs[].status",
"expected": "FAILED",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "jobs[].status",
"expected": "SUCCEEDED",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "jobs[].status",
"expected": "",
"matcher": "error",
"state": "failure"
}
]
},
"JobRunning": {
"delay": 5,
"operation": "DescribeJobs",
"maxAttempts": 100,
"acceptors": [
{
"argument": "jobs[].status",
"expected": "RUNNING",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "jobs[].status",
"expected": "FAILED",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "jobs[].status",
"expected": "SUCCEEDED",
"matcher": "pathAll",
"state": "success"
}
]
},
"JobComplete": {
"delay": 300,
"operation": "DescribeJobs",
"maxAttempts": 288,
"description": "Wait until job status is SUCCEEDED or FAILED",
"acceptors": [
{
"argument": "jobs[].status",
"expected": "SUCCEEDED",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "jobs[].status",
"expected": "FAILED",
"matcher": "pathAny",
"state": "failure"
}
]
}
}
}

@ -0,0 +1,242 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""
AWS batch service waiters
.. seealso::
- https://boto3.amazonaws.com/v1/documentation/api/latest/guide/clients.html#waiters
- https://github.com/boto/botocore/blob/develop/botocore/waiter.py
"""
import json
import sys
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Optional, Union
import botocore.client
import botocore.exceptions
import botocore.waiter
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchClientHook
class AwsBatchWaitersHook(AwsBatchClientHook):
"""
A utility to manage waiters for AWS batch services
Examples:
.. code-block:: python
import random
from airflow.providers.amazon.aws.hooks.batch_waiters import AwsBatchWaitersHook
# to inspect default waiters
waiters = AwsBatchWaitersHook()
config = waiters.default_config # type: Dict
waiter_names = waiters.list_waiters() # -> ["JobComplete", "JobExists", "JobRunning"]
# The default_config is a useful stepping stone to creating custom waiters, e.g.
custom_config = waiters.default_config # this is a deepcopy
# modify custom_config['waiters'] as necessary and get a new instance:
waiters = AwsBatchWaitersHook(waiter_config=custom_config)
waiters.waiter_config # check the custom configuration (this is a deepcopy)
waiters.list_waiters() # names of custom waiters
# During the init for AwsBatchWaitersHook, the waiter_config is used to build a waiter_model;
# and note that this only occurs during the class init, to avoid any accidental mutations
# of waiter_config leaking into the waiter_model.
waiters.waiter_model # -> botocore.waiter.WaiterModel object
# The waiter_model is combined with the waiters.client to get a specific waiter
# and the details of the config on that waiter can be further modified without any
# accidental impact on the generation of new waiters from the defined waiter_model, e.g.
waiters.get_waiter("JobExists").config.delay # -> 5
waiter = waiters.get_waiter("JobExists") # -> botocore.waiter.Batch.Waiter.JobExists object
waiter.config.delay = 10
waiters.get_waiter("JobExists").config.delay # -> 5 as defined by waiter_model
# To use a specific waiter, update the config and call the `wait()` method for jobId, e.g.
waiter = waiters.get_waiter("JobExists") # -> botocore.waiter.Batch.Waiter.JobExists object
waiter.config.delay = random.uniform(1, 10) # seconds
waiter.config.max_attempts = 10
waiter.wait(jobs=[jobId])
.. seealso::
- https://www.2ndwatch.com/blog/use-waiters-boto3-write/
- https://github.com/boto/botocore/blob/develop/botocore/waiter.py
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html#waiters
- https://github.com/boto/botocore/tree/develop/botocore/data/ec2/2016-11-15
- https://github.com/boto/botocore/issues/1915
:param waiter_config: a custom waiter configuration for AWS batch services
:type waiter_config: Optional[Dict]
:param aws_conn_id: connection id of AWS credentials / region name. If None,
credential boto3 strategy will be used
(http://boto3.readthedocs.io/en/latest/guide/configuration.html).
:type aws_conn_id: Optional[str]
:param region_name: region name to use in AWS client.
Override the AWS region in connection (if provided)
:type region_name: Optional[str]
"""
def __init__(self, *args, waiter_config: Optional[Dict] = None, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._default_config = None # type: Optional[Dict]
self._waiter_config = waiter_config or self.default_config
self._waiter_model = botocore.waiter.WaiterModel(self._waiter_config)
@property
def default_config(self) -> Dict:
"""
An immutable default waiter configuration
:return: a waiter configuration for AWS batch services
:rtype: Dict
"""
if self._default_config is None:
config_path = Path(__file__).with_name("batch_waiters.json").absolute()
with open(config_path) as config_file:
self._default_config = json.load(config_file)
return deepcopy(self._default_config) # avoid accidental mutation
@property
def waiter_config(self) -> Dict:
"""
An immutable waiter configuration for this instance; a ``deepcopy`` is returned by this
property. During the init for AwsBatchWaitersHook, the waiter_config is used to build a
waiter_model and this only occurs during the class init, to avoid any accidental
mutations of waiter_config leaking into the waiter_model.
:return: a waiter configuration for AWS batch services
:rtype: Dict
"""
return deepcopy(self._waiter_config) # avoid accidental mutation
@property
def waiter_model(self) -> botocore.waiter.WaiterModel:
"""
A configured waiter model used to generate waiters on AWS batch services.
:return: a waiter model for AWS batch services
:rtype: botocore.waiter.WaiterModel
"""
return self._waiter_model
def get_waiter(self, waiter_name: str) -> botocore.waiter.Waiter:
"""
Get an AWS Batch service waiter, using the configured ``.waiter_model``.
The ``.waiter_model`` is combined with the ``.client`` to get a specific waiter and
the properties of that waiter can be modified without any accidental impact on the
generation of new waiters from the ``.waiter_model``, e.g.
.. code-block::
waiters.get_waiter("JobExists").config.delay # -> 5
waiter = waiters.get_waiter("JobExists") # a new waiter object
waiter.config.delay = 10
waiters.get_waiter("JobExists").config.delay # -> 5 as defined by waiter_model
To use a specific waiter, update the config and call the `wait()` method for jobId, e.g.
.. code-block::
import random
waiter = waiters.get_waiter("JobExists") # a new waiter object
waiter.config.delay = random.uniform(1, 10) # seconds
waiter.config.max_attempts = 10
waiter.wait(jobs=[jobId])
:param waiter_name: The name of the waiter. The name should match
the name (including the casing) of the key name in the waiter
model file (typically this is CamelCasing); see ``.list_waiters``.
:type waiter_name: str
:return: a waiter object for the named AWS batch service
:rtype: botocore.waiter.Waiter
"""
return botocore.waiter.create_waiter_with_client(
waiter_name, self.waiter_model, self.client
)
def list_waiters(self) -> List[str]:
"""
List the waiters in a waiter configuration for AWS Batch services.
:return: waiter names for AWS batch services
:rtype: List[str]
"""
return self.waiter_model.waiter_names
def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None:
"""
Wait for batch job to complete. This assumes that the ``.waiter_model`` is configured
using some variation of the ``.default_config`` so that it can generate waiters with the
following names: "JobExists", "JobRunning" and "JobComplete".
:param job_id: a batch job ID
:type job_id: str
:param delay: A delay before polling for job status
:type delay: Union[int, float, None]
:raises: AirflowException
.. note::
This method adds a small random jitter to the ``delay`` (+/- 2 sec, >= 1 sec).
Using a random interval helps to avoid AWS API throttle limits when many
concurrent tasks request job-descriptions.
It also modifies the ``max_attempts`` to use the ``sys.maxsize``,
which allows Airflow to manage the timeout on waiting.
"""
self.delay(delay)
try:
waiter = self.get_waiter("JobExists")
waiter.config.delay = self.add_jitter(
waiter.config.delay, width=2, minima=1
)
waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow
waiter.wait(jobs=[job_id])
waiter = self.get_waiter("JobRunning")
waiter.config.delay = self.add_jitter(
waiter.config.delay, width=2, minima=1
)
waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow
waiter.wait(jobs=[job_id])
waiter = self.get_waiter("JobComplete")
waiter.config.delay = self.add_jitter(
waiter.config.delay, width=2, minima=1
)
waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow
waiter.wait(jobs=[job_id])
except (
botocore.exceptions.ClientError,
botocore.exceptions.WaiterError,
) as err:
raise AirflowException(err)
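def _example_custom_waiter_usage() -> None:
    # A minimal sketch of overriding the packaged batch_waiters.json configuration;
    # the delay value and job id below are placeholders.
    hook = AwsBatchWaitersHook(aws_conn_id="aws_default")
    config = hook.default_config  # deepcopy of batch_waiters.json
    config["waiters"]["JobComplete"]["delay"] = 30
    custom = AwsBatchWaitersHook(waiter_config=config, aws_conn_id="aws_default")
    custom.get_waiter("JobComplete").wait(jobs=["example-job-id"])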

@ -0,0 +1,79 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains AWS CloudFormation Hook"""
from typing import Optional
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from botocore.exceptions import ClientError
class AWSCloudFormationHook(AwsBaseHook):
"""
Interact with AWS CloudFormation.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, *args, **kwargs):
super().__init__(client_type="cloudformation", *args, **kwargs)
def get_stack_status(self, stack_name: str) -> Optional[str]:
"""Get stack status from CloudFormation."""
self.log.info("Poking for stack %s", stack_name)
try:
stacks = self.get_conn().describe_stacks(StackName=stack_name)["Stacks"]
return stacks[0]["StackStatus"]
except ClientError as e:
if "does not exist" in str(e):
return None
else:
raise e
def create_stack(self, stack_name: str, params: dict) -> None:
"""
Create stack in CloudFormation.
:param stack_name: stack_name.
:type stack_name: str
:param params: parameters to be passed to CloudFormation.
:type params: dict
"""
if "StackName" not in params:
params["StackName"] = stack_name
self.get_conn().create_stack(**params)
def delete_stack(self, stack_name: str, params: Optional[dict] = None) -> None:
"""
Delete stack in CloudFormation.
:param stack_name: stack_name.
:type stack_name: str
:param params: parameters to be passed to CloudFormation (optional).
:type params: dict
"""
params = params or {}
if "StackName" not in params:
params["StackName"] = stack_name
self.get_conn().delete_stack(**params)
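def _example_cloudformation_usage() -> None:
    # A minimal usage sketch; the stack name and template URL below are placeholders.
    hook = AWSCloudFormationHook(aws_conn_id="aws_default")
    hook.create_stack(
        stack_name="example-stack",
        params={
            "TemplateURL": "https://s3.amazonaws.com/example-bucket/template.yaml",
            "Parameters": [{"ParameterKey": "Environment", "ParameterValue": "dev"}],
        },
    )
    status = hook.get_stack_status("example-stack")  # e.g. "CREATE_IN_PROGRESS", or None
    hook.log.info("Stack status: %s", status)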

@ -0,0 +1,334 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Interact with AWS DataSync, using the AWS ``boto3`` library."""
import time
from typing import List, Optional
from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowTaskTimeout
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class AWSDataSyncHook(AwsBaseHook):
"""
Interact with AWS DataSync.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
:class:`~airflow.providers.amazon.aws.operators.datasync.AWSDataSyncOperator`
:param int wait_interval_seconds: Time to wait between two
consecutive calls to check TaskExecution status.
:raises ValueError: If wait_interval_seconds is not between 0 and 15*60 seconds.
"""
TASK_EXECUTION_INTERMEDIATE_STATES = (
"INITIALIZING",
"QUEUED",
"LAUNCHING",
"PREPARING",
"TRANSFERRING",
"VERIFYING",
)
TASK_EXECUTION_FAILURE_STATES = ("ERROR",)
TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)
def __init__(self, wait_interval_seconds: int = 5, *args, **kwargs) -> None:
super().__init__(client_type="datasync", *args, **kwargs) # type: ignore[misc]
self.locations: list = []
self.tasks: list = []
# wait_interval_seconds = 0 is used during unit tests
if wait_interval_seconds < 0 or wait_interval_seconds > 15 * 60:
raise ValueError(f"Invalid wait_interval_seconds {wait_interval_seconds}")
self.wait_interval_seconds = wait_interval_seconds
def create_location(self, location_uri: str, **create_location_kwargs) -> str:
r"""Creates a new location.
:param str location_uri: Location URI used to determine the location type (S3, SMB, NFS, EFS).
:param create_location_kwargs: Passed to ``boto.create_location_xyz()``.
See AWS boto3 datasync documentation.
:return str: LocationArn of the created Location.
:raises AirflowException: If location type (prefix from ``location_uri``) is invalid.
"""
typ = location_uri.split(":")[0]
if typ == "smb":
location = self.get_conn().create_location_smb(**create_location_kwargs)
elif typ == "s3":
location = self.get_conn().create_location_s3(**create_location_kwargs)
elif typ == "nfs":
location = self.get_conn().create_location_nfs(**create_location_kwargs)
elif typ == "efs":
location = self.get_conn().create_location_efs(**create_location_kwargs)
else:
raise AirflowException(f"Invalid location type: {typ}")
self._refresh_locations()
return location["LocationArn"]
def get_location_arns(
self,
location_uri: str,
case_sensitive: bool = False,
ignore_trailing_slash: bool = True,
) -> List[str]:
"""
Return all LocationArns which match a LocationUri.
:param str location_uri: Location URI to search for, eg ``s3://mybucket/mypath``
:param bool case_sensitive: Do a case sensitive search for location URI.
:param bool ignore_trailing_slash: Ignore / at the end of URI when matching.
:return: List of LocationArns.
:rtype: list(str)
:raises AirflowBadRequest: if ``location_uri`` is empty
"""
if not location_uri:
raise AirflowBadRequest("location_uri not specified")
if not self.locations:
self._refresh_locations()
result = []
if not case_sensitive:
location_uri = location_uri.lower()
if ignore_trailing_slash and location_uri.endswith("/"):
location_uri = location_uri[:-1]
for location_from_aws in self.locations:
location_uri_from_aws = location_from_aws["LocationUri"]
if not case_sensitive:
location_uri_from_aws = location_uri_from_aws.lower()
if ignore_trailing_slash and location_uri_from_aws.endswith("/"):
location_uri_from_aws = location_uri_from_aws[:-1]
if location_uri == location_uri_from_aws:
result.append(location_from_aws["LocationArn"])
return result
def _refresh_locations(self) -> None:
"""Refresh the local list of Locations."""
self.locations = []
next_token = None
while True:
if next_token:
locations = self.get_conn().list_locations(NextToken=next_token)
else:
locations = self.get_conn().list_locations()
self.locations.extend(locations["Locations"])
if "NextToken" not in locations:
break
next_token = locations["NextToken"]
def create_task(
self,
source_location_arn: str,
destination_location_arn: str,
**create_task_kwargs,
) -> str:
r"""Create a Task between the specified source and destination LocationArns.
:param str source_location_arn: Source LocationArn. Must exist already.
:param str destination_location_arn: Destination LocationArn. Must exist already.
:param create_task_kwargs: Passed to ``boto.create_task()``. See AWS boto3 datasync documentation.
:return: TaskArn of the created Task
:rtype: str
"""
task = self.get_conn().create_task(
SourceLocationArn=source_location_arn,
DestinationLocationArn=destination_location_arn,
**create_task_kwargs,
)
self._refresh_tasks()
return task["TaskArn"]
def update_task(self, task_arn: str, **update_task_kwargs) -> None:
r"""Update a Task.
:param str task_arn: The TaskArn to update.
:param update_task_kwargs: Passed to ``boto.update_task()``, See AWS boto3 datasync documentation.
"""
self.get_conn().update_task(TaskArn=task_arn, **update_task_kwargs)
def delete_task(self, task_arn: str) -> None:
r"""Delete a Task.
:param str task_arn: The TaskArn to delete.
"""
self.get_conn().delete_task(TaskArn=task_arn)
def _refresh_tasks(self) -> None:
"""Refreshes the local list of Tasks"""
self.tasks = []
next_token = None
while True:
if next_token:
tasks = self.get_conn().list_tasks(NextToken=next_token)
else:
tasks = self.get_conn().list_tasks()
self.tasks.extend(tasks["Tasks"])
if "NextToken" not in tasks:
break
next_token = tasks["NextToken"]
def get_task_arns_for_location_arns(
self,
source_location_arns: list,
destination_location_arns: list,
) -> list:
"""
Return a list of TaskArns which use any one of the specified
source LocationArns and any one of the specified destination LocationArns.
:param list source_location_arns: List of source LocationArns.
:param list destination_location_arns: List of destination LocationArns.
:return: list
:rtype: list(TaskArns)
:raises AirflowBadRequest: if ``source_location_arns`` or ``destination_location_arns`` are empty.
"""
if not source_location_arns:
raise AirflowBadRequest("source_location_arns not specified")
if not destination_location_arns:
raise AirflowBadRequest("destination_location_arns not specified")
if not self.tasks:
self._refresh_tasks()
result = []
for task in self.tasks:
task_arn = task["TaskArn"]
task_description = self.get_task_description(task_arn)
if task_description["SourceLocationArn"] in source_location_arns:
if (
task_description["DestinationLocationArn"]
in destination_location_arns
):
result.append(task_arn)
return result
def start_task_execution(self, task_arn: str, **kwargs) -> str:
r"""
Start a TaskExecution for the specified task_arn.
Each task can have at most one TaskExecution.
:param str task_arn: TaskArn
:return: TaskExecutionArn
:param kwargs: kwargs sent to ``boto3.start_task_execution()``
:rtype: str
:raises ClientError: If a TaskExecution is already busy running for this ``task_arn``.
:raises AirflowBadRequest: If ``task_arn`` is empty.
"""
if not task_arn:
raise AirflowBadRequest("task_arn not specified")
task_execution = self.get_conn().start_task_execution(
TaskArn=task_arn, **kwargs
)
return task_execution["TaskExecutionArn"]
def cancel_task_execution(self, task_execution_arn: str) -> None:
"""
Cancel a TaskExecution for the specified ``task_execution_arn``.
:param str task_execution_arn: TaskExecutionArn.
:raises AirflowBadRequest: If ``task_execution_arn`` is empty.
"""
if not task_execution_arn:
raise AirflowBadRequest("task_execution_arn not specified")
self.get_conn().cancel_task_execution(TaskExecutionArn=task_execution_arn)
def get_task_description(self, task_arn: str) -> dict:
"""
Get description for the specified ``task_arn``.
:param str task_arn: TaskArn
:return: AWS metadata about a task.
:rtype: dict
:raises AirflowBadRequest: If ``task_arn`` is empty.
"""
if not task_arn:
raise AirflowBadRequest("task_arn not specified")
return self.get_conn().describe_task(TaskArn=task_arn)
def describe_task_execution(self, task_execution_arn: str) -> dict:
"""
Get description for the specified ``task_execution_arn``.
:param str task_execution_arn: TaskExecutionArn
:return: AWS metadata about a task execution.
:rtype: dict
:raises AirflowBadRequest: If ``task_execution_arn`` is empty.
"""
return self.get_conn().describe_task_execution(
TaskExecutionArn=task_execution_arn
)
def get_current_task_execution_arn(self, task_arn: str) -> Optional[str]:
"""
Get current TaskExecutionArn (if one exists) for the specified ``task_arn``.
:param str task_arn: TaskArn
:return: CurrentTaskExecutionArn for this ``task_arn`` or None.
:rtype: str
:raises AirflowBadRequest: if ``task_arn`` is empty.
"""
if not task_arn:
raise AirflowBadRequest("task_arn not specified")
task_description = self.get_task_description(task_arn)
if "CurrentTaskExecutionArn" in task_description:
return task_description["CurrentTaskExecutionArn"]
return None
def wait_for_task_execution(
self, task_execution_arn: str, max_iterations: int = 2 * 180
) -> bool:
"""
Wait for Task Execution status to be complete (SUCCESS/ERROR).
The ``task_execution_arn`` must exist, or a boto3 ClientError will be raised.
:param str task_execution_arn: TaskExecutionArn
:param int max_iterations: Maximum number of iterations before timing out.
:return: Result of task execution.
:rtype: bool
:raises AirflowTaskTimeout: If maximum iterations is exceeded.
:raises AirflowBadRequest: If ``task_execution_arn`` is empty.
"""
if not task_execution_arn:
raise AirflowBadRequest("task_execution_arn not specified")
status = None
iterations = max_iterations
while status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
task_execution = self.get_conn().describe_task_execution(
TaskExecutionArn=task_execution_arn
)
status = task_execution["Status"]
self.log.info("status=%s", status)
iterations -= 1
if status in self.TASK_EXECUTION_FAILURE_STATES:
break
if status in self.TASK_EXECUTION_SUCCESS_STATES:
break
if iterations <= 0:
break
time.sleep(self.wait_interval_seconds)
if status in self.TASK_EXECUTION_SUCCESS_STATES:
return True
if status in self.TASK_EXECUTION_FAILURE_STATES:
return False
if iterations <= 0:
raise AirflowTaskTimeout("Max iterations exceeded!")
raise AirflowException(f"Unknown status: {status}") # Should never happen

@ -0,0 +1,67 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains the AWS DynamoDB hook"""
from typing import Iterable, List, Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class AwsDynamoDBHook(AwsBaseHook):
"""
Interact with AWS DynamoDB.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
:param table_keys: partition key and sort key
:type table_keys: list
:param table_name: target DynamoDB table
:type table_name: str
"""
def __init__(
self,
*args,
table_keys: Optional[List] = None,
table_name: Optional[str] = None,
**kwargs,
) -> None:
self.table_keys = table_keys
self.table_name = table_name
kwargs["resource_type"] = "dynamodb"
super().__init__(*args, **kwargs)
def write_batch_data(self, items: Iterable) -> bool:
"""Write batch items to DynamoDB table with provisioned throughout capacity."""
try:
table = self.get_conn().Table(self.table_name)
with table.batch_writer(overwrite_by_pkeys=self.table_keys) as batch:
for item in items:
batch.put_item(Item=item)
return True
except Exception as general_error:
raise AirflowException(
f"Failed to insert items in dynamodb, error: {str(general_error)}"
)
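
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# One way write_batch_data() might be called, e.g. from a PythonOperator
# callable. The connection id, table name and item shape are assumptions.
def _example_write_batch() -> None:
    hook = AwsDynamoDBHook(
        aws_conn_id="aws_default",      # assumed connection id
        table_name="example_table",     # hypothetical table
        table_keys=["item_id"],         # hypothetical partition key
    )
    items = [{"item_id": "1", "value": "a"}, {"item_id": "2", "value": "b"}]
    hook.write_batch_data(items)        # raises AirflowException on failure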

@ -0,0 +1,82 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import time
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class EC2Hook(AwsBaseHook):
"""
Interact with AWS EC2 Service.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, *args, **kwargs) -> None:
kwargs["resource_type"] = "ec2"
super().__init__(*args, **kwargs)
def get_instance(self, instance_id: str):
"""
Get EC2 instance by id and return it.
:param instance_id: id of the AWS EC2 instance
:type instance_id: str
:return: Instance object
:rtype: ec2.Instance
"""
return self.conn.Instance(id=instance_id)
def get_instance_state(self, instance_id: str) -> str:
"""
Get EC2 instance state by id and return it.
:param instance_id: id of the AWS EC2 instance
:type instance_id: str
:return: current state of the instance
:rtype: str
"""
return self.get_instance(instance_id=instance_id).state["Name"]
def wait_for_state(
self, instance_id: str, target_state: str, check_interval: float
) -> None:
"""
Wait for the EC2 instance until its state equals ``target_state``.
:param instance_id: id of the AWS EC2 instance
:type instance_id: str
:param target_state: target state of instance
:type target_state: str
:param check_interval: time in seconds to wait between
consecutive instance state checks until the operation is completed
:type check_interval: float
:return: None
:rtype: None
"""
instance_state = self.get_instance_state(instance_id=instance_id)
while instance_state != target_state:
self.log.info("instance state: %s", instance_state)
time.sleep(check_interval)
instance_state = self.get_instance_state(instance_id=instance_id)
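
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# Starting an instance through the underlying boto3 resource and then using
# wait_for_state() above to block until it is running. The instance id and
# connection id are placeholder assumptions.
def _example_start_and_wait(instance_id: str = "i-0123456789abcdef0") -> None:
    hook = EC2Hook(aws_conn_id="aws_default")
    instance = hook.get_instance(instance_id=instance_id)
    instance.start()  # boto3 ec2.Instance API
    hook.wait_for_state(
        instance_id=instance_id, target_state="running", check_interval=15.0
    )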

@ -0,0 +1,325 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from time import sleep
from typing import Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class ElastiCacheReplicationGroupHook(AwsBaseHook):
"""
Interact with AWS ElastiCache
:param max_retries: Max retries for checking availability of and deleting replication group
If this is not supplied then this is defaulted to 10
:type max_retries: int
:param exponential_back_off_factor: Multiplication factor for deciding next sleep time
If this is not supplied then this is defaulted to 1
:type exponential_back_off_factor: float
:param initial_poke_interval: Initial sleep time in seconds
If this is not supplied then this is defaulted to 60 seconds
:type initial_poke_interval: float
"""
TERMINAL_STATES = frozenset({"available", "create-failed", "deleting"})
def __init__(
self,
max_retries: int = 10,
exponential_back_off_factor: float = 1,
initial_poke_interval: float = 60,
*args,
**kwargs,
):
self.max_retries = max_retries
self.exponential_back_off_factor = exponential_back_off_factor
self.initial_poke_interval = initial_poke_interval
kwargs["client_type"] = "elasticache"
super().__init__(*args, **kwargs)
def create_replication_group(self, config: dict) -> dict:
"""
Call ElastiCache API for creating a replication group
:param config: Configuration for creating the replication group
:type config: dict
:return: Response from ElastiCache create replication group API
:rtype: dict
"""
return self.conn.create_replication_group(**config)
def delete_replication_group(self, replication_group_id: str) -> dict:
"""
Call ElastiCache API for deleting a replication group
:param replication_group_id: ID of replication group to delete
:type replication_group_id: str
:return: Response from ElastiCache delete replication group API
:rtype: dict
"""
return self.conn.delete_replication_group(
ReplicationGroupId=replication_group_id
)
def describe_replication_group(self, replication_group_id: str) -> dict:
"""
Call ElastiCache API for describing a replication group
:param replication_group_id: ID of replication group to describe
:type replication_group_id: str
:return: Response from ElastiCache describe replication group API
:rtype: dict
"""
return self.conn.describe_replication_groups(
ReplicationGroupId=replication_group_id
)
def get_replication_group_status(self, replication_group_id: str) -> str:
"""
Get current status of replication group
:param replication_group_id: ID of replication group to check for status
:type replication_group_id: str
:return: Current status of replication group
:rtype: str
"""
return self.describe_replication_group(replication_group_id)[
"ReplicationGroups"
][0]["Status"]
def is_replication_group_available(self, replication_group_id: str) -> bool:
"""
Helper for checking if replication group is available or not
:param replication_group_id: ID of replication group to check for availability
:type replication_group_id: str
:return: True if available else False
:rtype: bool
"""
return self.get_replication_group_status(replication_group_id) == "available"
def wait_for_availability(
self,
replication_group_id: str,
initial_sleep_time: Optional[float] = None,
exponential_back_off_factor: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Check if replication group is available or not by performing a describe over it
:param replication_group_id: ID of replication group to check for availability
:type replication_group_id: str
:param initial_sleep_time: Initial sleep time in seconds
If this is not supplied then this is defaulted to class level value
:type initial_sleep_time: float
:param exponential_back_off_factor: Multiplication factor for deciding next sleep time
If this is not supplied then this is defaulted to class level value
:type exponential_back_off_factor: float
:param max_retries: Max retries for checking availability of replication group
If this is not supplied then this is defaulted to class level value
:type max_retries: int
:return: True if replication is available else False
:rtype: bool
"""
sleep_time = initial_sleep_time or self.initial_poke_interval
exponential_back_off_factor = (
exponential_back_off_factor or self.exponential_back_off_factor
)
max_retries = max_retries or self.max_retries
num_tries = 0
status = "not-found"
stop_poking = False
while not stop_poking and num_tries <= max_retries:
status = self.get_replication_group_status(
replication_group_id=replication_group_id
)
stop_poking = status in self.TERMINAL_STATES
self.log.info(
"Current status of replication group with ID %s is %s",
replication_group_id,
status,
)
if not stop_poking:
num_tries += 1
# No point in sleeping if all tries have exhausted
if num_tries > max_retries:
break
self.log.info(
"Poke retry %s. Sleep time %s seconds. Sleeping...",
num_tries,
sleep_time,
)
sleep(sleep_time)
sleep_time *= exponential_back_off_factor
if status != "available":
self.log.warning(
'Replication group is not available. Current status is "%s"', status
)
return False
return True
def wait_for_deletion(
self,
replication_group_id: str,
initial_sleep_time: Optional[float] = None,
exponential_back_off_factor: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Helper for deleting a replication group ensuring it is either deleted or can't be deleted
:param replication_group_id: ID of replication to delete
:type replication_group_id: str
:param initial_sleep_time: Initial sleep time in seconds
If this is not supplied then this is defaulted to class level value
:type initial_sleep_time: float
:param exponential_back_off_factor: Multiplication factor for deciding next sleep time
If this is not supplied then this is defaulted to class level value
:type exponential_back_off_factor: float
:param max_retries: Max retries for checking availability of replication group
If this is not supplied then this is defaulted to class level value
:type max_retries: int
:return: Response from ElastiCache delete replication group API and flag to identify if deleted or not
:rtype: (dict, bool)
"""
deleted = False
sleep_time = initial_sleep_time or self.initial_poke_interval
exponential_back_off_factor = (
exponential_back_off_factor or self.exponential_back_off_factor
)
max_retries = max_retries or self.max_retries
num_tries = 0
response = None
while not deleted and num_tries <= max_retries:
try:
status = self.get_replication_group_status(
replication_group_id=replication_group_id
)
self.log.info(
"Current status of replication group with ID %s is %s",
replication_group_id,
status,
)
# Can only delete if status is `available`
# Status becomes `deleting` after this call so this will only run once
if status == "available":
self.log.info("Initiating delete and then wait for it to finish")
response = self.delete_replication_group(
replication_group_id=replication_group_id
)
except self.conn.exceptions.ReplicationGroupNotFoundFault:
self.log.info(
"Replication group with ID '%s' does not exist",
replication_group_id,
)
deleted = True
# This should never occur as we only issue a delete request when status is `available`
# which is a valid status for deletion. Still handling for safety.
except self.conn.exceptions.InvalidReplicationGroupStateFault as exp:
# status Error Response
# creating - Cache cluster <cluster_id> is not in a valid state to be deleted.
# deleting - Replication group <replication_group_id> has status deleting which is not valid
# for deletion.
# modifying - Replication group <replication_group_id> has status deleting which is not valid
# for deletion.
message = exp.response["Error"]["Message"]
self.log.warning(
"Received error message from AWS ElastiCache API : %s", message
)
if not deleted:
num_tries += 1
# No point in sleeping if all tries have exhausted
if num_tries > max_retries:
break
self.log.info(
"Poke retry %s. Sleep time %s seconds. Sleeping...",
num_tries,
sleep_time,
)
sleep(sleep_time)
sleep_time *= exponential_back_off_factor
return response, deleted
def ensure_delete_replication_group(
self,
replication_group_id: str,
initial_sleep_time: Optional[float] = None,
exponential_back_off_factor: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Delete a replication group ensuring it is either deleted or can't be deleted
:param replication_group_id: ID of replication to delete
:type replication_group_id: str
:param initial_sleep_time: Initial sleep time in seconds
If this is not supplied then this is defaulted to class level value
:type initial_sleep_time: float
:param exponential_back_off_factor: Multiplication factor for deciding next sleep time
If this is not supplied then this is defaulted to class level value
:type exponential_back_off_factor: float
:param max_retries: Max retries for checking availability of replication group
If this is not supplied then this is defaulted to class level value
:type max_retries: int
:return: Response from ElastiCache delete replication group API
:rtype: dict
:raises AirflowException: If replication group is not deleted
"""
self.log.info("Deleting replication group with ID %s", replication_group_id)
response, deleted = self.wait_for_deletion(
replication_group_id=replication_group_id,
initial_sleep_time=initial_sleep_time,
exponential_back_off_factor=exponential_back_off_factor,
max_retries=max_retries,
)
if not deleted:
raise AirflowException(
f'Replication group could not be deleted. Response "{response}"'
)
return response
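
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# Creating a replication group, waiting for it to become available, then
# tearing it down again. The configuration values below are placeholder
# assumptions; see the boto3 create_replication_group API for the full set.
def _example_replication_group_lifecycle() -> None:
    hook = ElastiCacheReplicationGroupHook(aws_conn_id="aws_default")
    config = {
        "ReplicationGroupId": "example-group",
        "ReplicationGroupDescription": "example group",
        "Engine": "redis",
        "CacheNodeType": "cache.t3.micro",
    }
    hook.create_replication_group(config=config)
    if hook.wait_for_availability("example-group"):
        # Raises AirflowException if the group cannot be deleted
        hook.ensure_delete_replication_group("example-group")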

@ -0,0 +1,101 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class EmrHook(AwsBaseHook):
"""
Interact with AWS EMR. emr_conn_id is only necessary for using the
create_job_flow method.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
conn_name_attr = "emr_conn_id"
default_conn_name = "emr_default"
conn_type = "emr"
hook_name = "Elastic MapReduce"
def __init__(
self, emr_conn_id: Optional[str] = default_conn_name, *args, **kwargs
) -> None:
self.emr_conn_id = emr_conn_id
kwargs["client_type"] = "emr"
super().__init__(*args, **kwargs)
def get_cluster_id_by_name(
self, emr_cluster_name: str, cluster_states: List[str]
) -> Optional[str]:
"""
Fetch the id of the EMR cluster with the given name and (optional) states.
Returns the id only when exactly one matching cluster is found; raises if more than one matches.
:param emr_cluster_name: Name of a cluster to find
:type emr_cluster_name: str
:param cluster_states: State(s) of cluster to find
:type cluster_states: list
:return: id of the EMR cluster
"""
response = self.get_conn().list_clusters(ClusterStates=cluster_states)
matching_clusters = list(
filter(
lambda cluster: cluster["Name"] == emr_cluster_name,
response["Clusters"],
)
)
if len(matching_clusters) == 1:
cluster_id = matching_clusters[0]["Id"]
self.log.info(
"Found cluster name = %s id = %s", emr_cluster_name, cluster_id
)
return cluster_id
elif len(matching_clusters) > 1:
raise AirflowException(
f"More than one cluster found for name {emr_cluster_name}"
)
else:
self.log.info("No cluster found for name %s", emr_cluster_name)
return None
def create_job_flow(self, job_flow_overrides: Dict[str, Any]) -> Dict[str, Any]:
"""
Creates a job flow using the config from the EMR connection.
Keys of the connection's JSON ``extra`` field may hold the arguments of the
boto3 ``run_job_flow`` method.
Overrides for this config may be passed as ``job_flow_overrides``.
"""
if not self.emr_conn_id:
raise AirflowException("emr_conn_id must be present to use create_job_flow")
emr_conn = self.get_connection(self.emr_conn_id)
config = emr_conn.extra_dejson.copy()
config.update(job_flow_overrides)
response = self.get_conn().run_job_flow(**config)
return response
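
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# Launching a job flow from the config stored in the EMR connection and then
# resolving the cluster id by name. The overrides and cluster name below are
# placeholder assumptions.
def _example_emr_job_flow() -> None:
    hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")
    response = hook.create_job_flow(
        {"Name": "example-cluster", "ReleaseLabel": "emr-6.2.0"}
    )
    job_flow_id = response["JobFlowId"]
    cluster_id = hook.get_cluster_id_by_name(
        emr_cluster_name="example-cluster",
        cluster_states=["STARTING", "BOOTSTRAPPING", "RUNNING", "WAITING"],
    )
    print("job flow:", job_flow_id, "resolved cluster:", cluster_id)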

@ -0,0 +1,80 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class GlacierHook(AwsBaseHook):
"""Hook for connection with Amazon Glacier"""
def __init__(self, aws_conn_id: str = "aws_default") -> None:
super().__init__(client_type="glacier")
self.aws_conn_id = aws_conn_id
def retrieve_inventory(self, vault_name: str) -> Dict[str, Any]:
"""
Initiate an Amazon Glacier inventory-retrieval job
:param vault_name: the Glacier vault on which job is executed
:type vault_name: str
"""
job_params = {"Type": "inventory-retrieval"}
self.log.info("Retrieving inventory for vault: %s", vault_name)
response = self.get_conn().initiate_job(
vaultName=vault_name, jobParameters=job_params
)
self.log.info("Initiated inventory-retrieval job for: %s", vault_name)
self.log.info("Retrieval Job ID: %s", response["jobId"])
return response
def retrieve_inventory_results(
self, vault_name: str, job_id: str
) -> Dict[str, Any]:
"""
Retrieve the results of an Amazon Glacier inventory-retrieval job
:param vault_name: the Glacier vault on which job is executed
:type vault_name: str
:param job_id: the job ID returned by retrieve_inventory()
:type job_id: str
"""
self.log.info("Retrieving the job results for vault: %s...", vault_name)
response = self.get_conn().get_job_output(vaultName=vault_name, jobId=job_id)
return response
def describe_job(self, vault_name: str, job_id: str) -> Dict[str, Any]:
"""
Retrieve the status of an Amazon S3 Glacier job, such as an
inventory-retrieval job
:param vault_name: the Glacier vault on which job is executed
:type vault_name: str
:param job_id: the job ID returned by retrieve_inventory()
:type job_id: str
"""
self.log.info("Retrieving status for vault: %s and job %s", vault_name, job_id)
response = self.get_conn().describe_job(vaultName=vault_name, jobId=job_id)
self.log.info(
"Job status: %s, code status: %s",
response["Action"],
response["StatusCode"],
)
return response
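
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# Kicking off an inventory-retrieval job, checking its status and, once the
# job has completed on the AWS side, fetching its output. The vault name is a
# placeholder assumption.
def _example_glacier_inventory(vault_name: str = "example-vault") -> dict:
    hook = GlacierHook(aws_conn_id="aws_default")
    job = hook.retrieve_inventory(vault_name=vault_name)
    status = hook.describe_job(vault_name=vault_name, job_id=job["jobId"])
    if status["Completed"]:
        return hook.retrieve_inventory_results(
            vault_name=vault_name, job_id=job["jobId"]
        )
    return status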

@ -0,0 +1,205 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import time
from typing import Dict, List, Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class AwsGlueJobHook(AwsBaseHook):
"""
Interact with AWS Glue - create job, trigger, crawler
:param s3_bucket: S3 bucket where logs and local etl script will be uploaded
:type s3_bucket: Optional[str]
:param job_name: unique job name per AWS account
:type job_name: Optional[str]
:param desc: job description
:type desc: Optional[str]
:param concurrent_run_limit: The maximum number of concurrent runs allowed for a job
:type concurrent_run_limit: int
:param script_location: path to etl script on s3
:type script_location: Optional[str]
:param retry_limit: Maximum number of times to retry this job if it fails
:type retry_limit: int
:param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job
:type num_of_dpus: int
:param region_name: aws region name (example: us-east-1)
:type region_name: Optional[str]
:param iam_role_name: AWS IAM Role for Glue Job Execution
:type iam_role_name: Optional[str]
:param create_job_kwargs: Extra arguments for Glue Job Creation
:type create_job_kwargs: Optional[dict]
"""
JOB_POLL_INTERVAL = 6 # polls job status after every JOB_POLL_INTERVAL seconds
def __init__(
self,
s3_bucket: Optional[str] = None,
job_name: Optional[str] = None,
desc: Optional[str] = None,
concurrent_run_limit: int = 1,
script_location: Optional[str] = None,
retry_limit: int = 0,
num_of_dpus: int = 10,
region_name: Optional[str] = None,
iam_role_name: Optional[str] = None,
create_job_kwargs: Optional[dict] = None,
*args,
**kwargs,
): # pylint: disable=too-many-arguments
self.job_name = job_name
self.desc = desc
self.concurrent_run_limit = concurrent_run_limit
self.script_location = script_location
self.retry_limit = retry_limit
self.num_of_dpus = num_of_dpus
self.region_name = region_name
self.s3_bucket = s3_bucket
self.role_name = iam_role_name
self.s3_glue_logs = "logs/glue-logs/"
self.create_job_kwargs = create_job_kwargs or {}
kwargs["client_type"] = "glue"
super().__init__(*args, **kwargs)
def list_jobs(self) -> List:
""":return: Lists of Jobs"""
conn = self.get_conn()
return conn.get_jobs()
def get_iam_execution_role(self) -> Dict:
""":return: iam role for job execution"""
iam_client = self.get_client_type("iam", self.region_name)
try:
glue_execution_role = iam_client.get_role(RoleName=self.role_name)
self.log.info("Iam Role Name: %s", self.role_name)
return glue_execution_role
except Exception as general_error:
self.log.error("Failed to create aws glue job, error: %s", general_error)
raise
def initialize_job(self, script_arguments: Optional[dict] = None) -> Dict[str, str]:
"""
Initializes the connection with AWS Glue and starts a job run.
:return: Response from the Glue ``start_job_run`` API, including the ``JobRunId``.
"""
glue_client = self.get_conn()
script_arguments = script_arguments or {}
try:
job_name = self.get_or_create_glue_job()
job_run = glue_client.start_job_run(
JobName=job_name, Arguments=script_arguments
)
return job_run
except Exception as general_error:
self.log.error("Failed to run aws glue job, error: %s", general_error)
raise
def get_job_state(self, job_name: str, run_id: str) -> str:
"""
Get state of the Glue job. The job state can be
running, finished, failed, stopped or timeout.
:param job_name: unique job name per AWS account
:type job_name: str
:param run_id: The job-run ID of the predecessor job run
:type run_id: str
:return: State of the Glue job
"""
glue_client = self.get_conn()
job_run = glue_client.get_job_run(
JobName=job_name, RunId=run_id, PredecessorsIncluded=True
)
job_run_state = job_run["JobRun"]["JobRunState"]
return job_run_state
def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]:
"""
Waits until the Glue job with ``job_name`` completes or
fails, and returns the final state when finished.
Raises AirflowException when the job fails.
:param job_name: unique job name per AWS account
:type job_name: str
:param run_id: The job-run ID of the predecessor job run
:type run_id: str
:return: Dict of JobRunState and JobRunId
"""
failed_states = ["FAILED", "TIMEOUT"]
finished_states = ["SUCCEEDED", "STOPPED"]
while True:
job_run_state = self.get_job_state(job_name, run_id)
if job_run_state in finished_states:
self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state)
return {"JobRunState": job_run_state, "JobRunId": run_id}
if job_run_state in failed_states:
job_error_message = (
"Exiting Job " + run_id + " Run State: " + job_run_state
)
self.log.info(job_error_message)
raise AirflowException(job_error_message)
else:
self.log.info(
"Polling for AWS Glue Job %s current run state with status %s",
job_name,
job_run_state,
)
time.sleep(self.JOB_POLL_INTERVAL)
def get_or_create_glue_job(self) -> str:
"""
Creates the Glue job if it does not already exist and returns its name.
:return: Name of the job
"""
glue_client = self.get_conn()
try:
get_job_response = glue_client.get_job(JobName=self.job_name)
self.log.info("Job Already exist. Returning Name of the job")
return get_job_response["Job"]["Name"]
except glue_client.exceptions.EntityNotFoundException:
self.log.info("Job doesnt exist. Now creating and running AWS Glue Job")
if self.s3_bucket is None:
raise AirflowException(
"Could not initialize glue job, error: Specify Parameter `s3_bucket`"
)
s3_log_path = f"s3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}"
execution_role = self.get_iam_execution_role()
try:
create_job_response = glue_client.create_job(
Name=self.job_name,
Description=self.desc,
LogUri=s3_log_path,
Role=execution_role["Role"]["RoleName"],
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
Command={"Name": "glueetl", "ScriptLocation": self.script_location},
MaxRetries=self.retry_limit,
AllocatedCapacity=self.num_of_dpus,
**self.create_job_kwargs,
)
return create_job_response["Name"]
except Exception as general_error:
self.log.error(
"Failed to create aws glue job, error: %s", general_error
)
raise
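
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# Creating/starting a Glue job run and blocking until it finishes. The bucket,
# script location, IAM role and job name below are placeholder assumptions.
def _example_run_glue_job() -> dict:
    hook = AwsGlueJobHook(
        job_name="example-job",
        desc="example Glue job",
        s3_bucket="example-bucket",
        script_location="s3://example-bucket/scripts/etl.py",
        iam_role_name="ExampleGlueRole",
        aws_conn_id="aws_default",
    )
    job_run = hook.initialize_job(
        script_arguments={"--input": "s3://example-bucket/in/"}
    )
    # Raises AirflowException if the run ends in FAILED or TIMEOUT
    return hook.job_completion("example-job", job_run["JobRunId"])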

@ -0,0 +1,142 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains AWS Glue Catalog Hook"""
from typing import Optional, Set
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class AwsGlueCatalogHook(AwsBaseHook):
"""
Interact with AWS Glue Catalog
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, *args, **kwargs):
super().__init__(client_type="glue", *args, **kwargs)
def get_partitions(
self,
database_name: str,
table_name: str,
expression: str = "",
page_size: Optional[int] = None,
max_items: Optional[int] = None,
) -> Set[tuple]:
"""
Retrieves the partition values for a table.
:param database_name: The name of the catalog database where the partitions reside.
:type database_name: str
:param table_name: The name of the partitions' table.
:type table_name: str
:param expression: An expression filtering the partitions to be returned.
Please see official AWS documentation for further information.
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
:type expression: str
:param page_size: pagination size
:type page_size: int
:param max_items: maximum items to return
:type max_items: int
:return: set of partition values where each value is a tuple since
a partition may be composed of multiple columns. For example:
``{('2018-01-01','1'), ('2018-01-01','2')}``
"""
config = {
"PageSize": page_size,
"MaxItems": max_items,
}
paginator = self.get_conn().get_paginator("get_partitions")
response = paginator.paginate(
DatabaseName=database_name,
TableName=table_name,
Expression=expression,
PaginationConfig=config,
)
partitions = set()
for page in response:
for partition in page["Partitions"]:
partitions.add(tuple(partition["Values"]))
return partitions
def check_for_partition(
self, database_name: str, table_name: str, expression: str
) -> bool:
"""
Checks whether a partition exists
:param database_name: Name of hive database (schema) @table belongs to
:type database_name: str
:param table_name: Name of hive table @partition belongs to
:type table_name: str
:param expression: Expression that matches the partitions to check for
(e.g. `a = 'b' AND c = 'd'`)
:type expression: str
:rtype: bool
>>> hook = AwsGlueCatalogHook()
>>> t = 'static_babynames_partitioned'
>>> hook.check_for_partition('airflow', t, "ds='2015-01-01'")
True
"""
partitions = self.get_partitions(
database_name, table_name, expression, max_items=1
)
return bool(partitions)
def get_table(self, database_name: str, table_name: str) -> dict:
"""
Get the information of the table
:param database_name: Name of hive database (schema) @table belongs to
:type database_name: str
:param table_name: Name of hive table
:type table_name: str
:rtype: dict
>>> hook = AwsGlueCatalogHook()
>>> r = hook.get_table('db', 'table_foo')
>>> r['Name'] == 'table_foo'
True
"""
result = self.get_conn().get_table(DatabaseName=database_name, Name=table_name)
return result["Table"]
def get_table_location(self, database_name: str, table_name: str) -> str:
"""
Get the physical location of the table
:param database_name: Name of hive database (schema) @table belongs to
:type database_name: str
:param table_name: Name of hive table
:type table_name: str
:return: the physical location of the table
:rtype: str
"""
table = self.get_table(database_name, table_name)
return table["StorageDescriptor"]["Location"]
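
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# Looking up partitions and the storage location for a catalog table. The
# database, table and partition expression are placeholder assumptions.
def _example_catalog_lookup() -> None:
    hook = AwsGlueCatalogHook(aws_conn_id="aws_default")
    if hook.check_for_partition("example_db", "example_table", "ds='2021-01-01'"):
        partitions = hook.get_partitions(
            "example_db", "example_table", "ds='2021-01-01'"
        )
        location = hook.get_table_location("example_db", "example_table")
        print(partitions, location)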

@ -0,0 +1,184 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from time import sleep
try:
from functools import cached_property
except ImportError:
from cached_property import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class AwsGlueCrawlerHook(AwsBaseHook):
"""
Interacts with AWS Glue Crawler.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, *args, **kwargs):
kwargs["client_type"] = "glue"
super().__init__(*args, **kwargs)
@cached_property
def glue_client(self):
""":return: AWS Glue client"""
return self.get_conn()
def has_crawler(self, crawler_name) -> bool:
"""
Checks if the crawler already exists
:param crawler_name: unique crawler name per AWS account
:type crawler_name: str
:return: Returns True if the crawler already exists and False if not.
"""
self.log.info("Checking if crawler already exists: %s", crawler_name)
try:
self.get_crawler(crawler_name)
return True
except self.glue_client.exceptions.EntityNotFoundException:
return False
def get_crawler(self, crawler_name: str) -> dict:
"""
Gets crawler configurations
:param crawler_name: unique crawler name per AWS account
:type crawler_name: str
:return: Nested dictionary of crawler configurations
"""
return self.glue_client.get_crawler(Name=crawler_name)["Crawler"]
def update_crawler(self, **crawler_kwargs) -> bool:
"""
Updates crawler configurations
:param crawler_kwargs: Keyword args that define the configurations used for the crawler
:type crawler_kwargs: any
:return: True if the crawler was updated and False otherwise
"""
crawler_name = crawler_kwargs["Name"]
current_crawler = self.get_crawler(crawler_name)
update_config = {
key: value
for key, value in crawler_kwargs.items()
if current_crawler[key] != crawler_kwargs[key]
}
if update_config != {}:
self.log.info("Updating crawler: %s", crawler_name)
self.glue_client.update_crawler(**crawler_kwargs)
self.log.info("Updated configurations: %s", update_config)
return True
else:
return False
def create_crawler(self, **crawler_kwargs) -> str:
"""
Creates an AWS Glue Crawler
:param crawler_kwargs: Keyword args that define the configurations used to create the crawler
:type crawler_kwargs: any
:return: Name of the crawler
"""
crawler_name = crawler_kwargs["Name"]
self.log.info("Creating crawler: %s", crawler_name)
return self.glue_client.create_crawler(**crawler_kwargs)["Crawler"]["Name"]
def start_crawler(self, crawler_name: str) -> dict:
"""
Triggers the AWS Glue crawler
:param crawler_name: unique crawler name per AWS account
:type crawler_name: str
:return: Empty dictionary
"""
self.log.info("Starting crawler %s", crawler_name)
crawler = self.glue_client.start_crawler(Name=crawler_name)
return crawler
def wait_for_crawler_completion(
self, crawler_name: str, poll_interval: int = 5
) -> str:
"""
Waits until Glue crawler completes and
returns the status of the latest crawl run.
Raises AirflowException if the crawler fails or is cancelled.
:param crawler_name: unique crawler name per AWS account
:type crawler_name: str
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status
:type poll_interval: int
:return: Crawler's status
"""
failed_status = ["FAILED", "CANCELLED"]
while True:
crawler = self.get_crawler(crawler_name)
crawler_state = crawler["State"]
if crawler_state == "READY":
self.log.info("State: %s", crawler_state)
self.log.info("crawler_config: %s", crawler)
crawler_status = crawler["LastCrawl"]["Status"]
if crawler_status in failed_status:
raise AirflowException(
f"Status: {crawler_status}"
) # pylint: disable=raising-format-tuple
else:
metrics = self.glue_client.get_crawler_metrics(
CrawlerNameList=[crawler_name]
)["CrawlerMetricsList"][0]
self.log.info("Status: %s", crawler_status)
self.log.info(
"Last Runtime Duration (seconds): %s",
metrics["LastRuntimeSeconds"],
)
self.log.info(
"Median Runtime Duration (seconds): %s",
metrics["MedianRuntimeSeconds"],
)
self.log.info("Tables Created: %s", metrics["TablesCreated"])
self.log.info("Tables Updated: %s", metrics["TablesUpdated"])
self.log.info("Tables Deleted: %s", metrics["TablesDeleted"])
return crawler_status
else:
self.log.info("Polling for AWS Glue crawler: %s ", crawler_name)
self.log.info("State: %s", crawler_state)
metrics = self.glue_client.get_crawler_metrics(
CrawlerNameList=[crawler_name]
)["CrawlerMetricsList"][0]
time_left = int(metrics["TimeLeftSeconds"])
if time_left > 0:
self.log.info("Estimated Time Left (seconds): %s", time_left)
else:
self.log.info("Crawler should finish soon")
sleep(poll_interval)
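
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# Creating the crawler if it is missing, starting it, and waiting for the
# crawl to finish. The crawler configuration below is a placeholder
# assumption; see the boto3 create_crawler API for the full parameter set.
def _example_run_crawler() -> str:
    hook = AwsGlueCrawlerHook(aws_conn_id="aws_default")
    config = {
        "Name": "example-crawler",
        "Role": "ExampleGlueRole",
        "DatabaseName": "example_db",
        "Targets": {"S3Targets": [{"Path": "s3://example-bucket/data/"}]},
    }
    if not hook.has_crawler(config["Name"]):
        hook.create_crawler(**config)
    hook.start_crawler(config["Name"])
    # Raises AirflowException if the crawl ends FAILED or CANCELLED
    return hook.wait_for_crawler_completion(config["Name"], poll_interval=10)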

@ -0,0 +1,50 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains AWS Firehose hook"""
from typing import Iterable
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class AwsFirehoseHook(AwsBaseHook):
"""
Interact with AWS Kinesis Firehose.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
:param delivery_stream: Name of the delivery stream
:type delivery_stream: str
"""
def __init__(self, delivery_stream: str, *args, **kwargs) -> None:
self.delivery_stream = delivery_stream
kwargs["client_type"] = "firehose"
super().__init__(*args, **kwargs)
def put_records(self, records: Iterable):
"""Write batch records to Kinesis Firehose"""
response = self.get_conn().put_record_batch(
DeliveryStreamName=self.delivery_stream, Records=records
)
return response
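
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# Pushing a small batch of records to a delivery stream. The stream name and
# record payloads are placeholder assumptions; each record carries a ``Data``
# blob as required by the Firehose put_record_batch API.
def _example_put_records() -> None:
    hook = AwsFirehoseHook(delivery_stream="example-stream", aws_conn_id="aws_default")
    records = [{"Data": b'{"id": 1}\n'}, {"Data": b'{"id": 2}\n'}]
    response = hook.put_records(records)
    print("Failed records:", response["FailedPutCount"])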

@ -0,0 +1,69 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains AWS Lambda hook"""
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class AwsLambdaHook(AwsBaseHook):
"""
Interact with AWS Lambda
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
:param function_name: AWS Lambda Function Name
:type function_name: str
:param log_type: Tail Invocation Request
:type log_type: str
:param qualifier: AWS Lambda Function Version or Alias Name
:type qualifier: str
:param invocation_type: AWS Lambda Invocation Type (RequestResponse, Event etc)
:type invocation_type: str
"""
def __init__(
self,
function_name: str,
log_type: str = "None",
qualifier: str = "$LATEST",
invocation_type: str = "RequestResponse",
*args,
**kwargs,
) -> None:
self.function_name = function_name
self.log_type = log_type
self.invocation_type = invocation_type
self.qualifier = qualifier
kwargs["client_type"] = "lambda"
super().__init__(*args, **kwargs)
def invoke_lambda(self, payload: str) -> str:
"""Invoke Lambda Function"""
response = self.conn.invoke(
FunctionName=self.function_name,
InvocationType=self.invocation_type,
LogType=self.log_type,
Payload=payload,
Qualifier=self.qualifier,
)
return response
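
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# Synchronously invoking a function and decoding its JSON result. The
# function name and payload are placeholder assumptions.
import json


def _example_invoke_lambda() -> dict:
    hook = AwsLambdaHook(function_name="example-function", aws_conn_id="aws_default")
    response = hook.invoke_lambda(payload=json.dumps({"action": "ping"}))
    # The boto3 invoke response carries the function output as a streaming body
    return json.loads(response["Payload"].read().decode("utf-8"))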

@ -0,0 +1,105 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains a hook (AwsLogsHook) with some very basic
functionality for interacting with AWS CloudWatch.
"""
from typing import Dict, Generator, Optional
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class AwsLogsHook(AwsBaseHook):
"""
Interact with AWS CloudWatch Logs
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "logs"
super().__init__(*args, **kwargs)
def get_log_events(
self,
log_group: str,
log_stream_name: str,
start_time: int = 0,
skip: int = 0,
start_from_head: bool = True,
) -> Generator:
"""
A generator for log items in a single stream. This will yield all the
items that are available at the current moment.
:param log_group: The name of the log group.
:type log_group: str
:param log_stream_name: The name of the specific stream.
:type log_stream_name: str
:param start_time: The time stamp value to start reading the logs from (default: 0).
:type start_time: int
:param skip: The number of log entries to skip at the start (default: 0).
This is for when there are multiple entries at the same timestamp.
:type skip: int
:param start_from_head: whether to start from the beginning (True) of the log or
at the end of the log (False).
:type start_from_head: bool
:rtype: dict
:return: | A CloudWatch log event with the following key-value pairs:
| 'timestamp' (int): The time in milliseconds of the event.
| 'message' (str): The log event data.
| 'ingestionTime' (int): The time in milliseconds the event was ingested.
"""
next_token = None
event_count = 1
while event_count > 0:
if next_token is not None:
token_arg: Optional[Dict[str, str]] = {"nextToken": next_token}
else:
token_arg = {}
response = self.get_conn().get_log_events(
logGroupName=log_group,
logStreamName=log_stream_name,
startTime=start_time,
startFromHead=start_from_head,
**token_arg,
)
events = response["events"]
event_count = len(events)
if event_count > skip:
events = events[skip:]
skip = 0
else:
skip -= event_count
events = []
yield from events
if "nextForwardToken" in response:
next_token = response["nextForwardToken"]
else:
return
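
# --- Hedged usage sketch (illustrative only, not part of the provider code) ---
# Iterating over a single CloudWatch log stream with the generator above.
# The log group and stream names are placeholder assumptions.
def _example_read_logs() -> None:
    hook = AwsLogsHook(aws_conn_id="aws_default")
    for event in hook.get_log_events(
        log_group="/aws/lambda/example-function",
        log_stream_name="2021/01/01/[$LATEST]abcdef",
        start_from_head=True,
    ):
        print(event["timestamp"], event["message"])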

@ -0,0 +1,131 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Interact with AWS Redshift, using the boto3 library."""
from typing import List, Optional
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class RedshiftHook(AwsBaseHook):
"""
Interact with AWS Redshift, using the boto3 library
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "redshift"
super().__init__(*args, **kwargs)
# TODO: Wrap create_cluster_snapshot
def cluster_status(self, cluster_identifier: str) -> str:
"""
Return status of a cluster
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
"""
try:
response = self.get_conn().describe_clusters(
ClusterIdentifier=cluster_identifier
)["Clusters"]
return response[0]["ClusterStatus"] if response else None
except self.get_conn().exceptions.ClusterNotFoundFault:
return "cluster_not_found"
def delete_cluster( # pylint: disable=invalid-name
self,
cluster_identifier: str,
skip_final_cluster_snapshot: bool = True,
final_cluster_snapshot_identifier: Optional[str] = None,
):
"""
Delete a cluster and optionally create a snapshot
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
:param skip_final_cluster_snapshot: determines cluster snapshot creation
:type skip_final_cluster_snapshot: bool
:param final_cluster_snapshot_identifier: name of final cluster snapshot
:type final_cluster_snapshot_identifier: str
"""
final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or ""
response = self.get_conn().delete_cluster(
ClusterIdentifier=cluster_identifier,
SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier,
)
return response["Cluster"] if response["Cluster"] else None
def describe_cluster_snapshots(
self, cluster_identifier: str
) -> Optional[List[str]]:
"""
Gets a list of snapshots for a cluster
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
"""
response = self.get_conn().describe_cluster_snapshots(
ClusterIdentifier=cluster_identifier
)
if "Snapshots" not in response:
return None
snapshots = response["Snapshots"]
snapshots = [snapshot for snapshot in snapshots if snapshot["Status"]]
snapshots.sort(key=lambda x: x["SnapshotCreateTime"], reverse=True)
return snapshots
def restore_from_cluster_snapshot(
self, cluster_identifier: str, snapshot_identifier: str
) -> str:
"""
Restores a cluster from its snapshot
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
:param snapshot_identifier: unique identifier for a snapshot of a cluster
:type snapshot_identifier: str
"""
response = self.get_conn().restore_from_cluster_snapshot(
ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier
)
return response["Cluster"] if response["Cluster"] else None
def create_cluster_snapshot(
self, snapshot_identifier: str, cluster_identifier: str
) -> str:
"""
Creates a snapshot of a cluster
:param snapshot_identifier: unique identifier for a snapshot of a cluster
:type snapshot_identifier: str
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
"""
response = self.get_conn().create_cluster_snapshot(
SnapshotIdentifier=snapshot_identifier,
ClusterIdentifier=cluster_identifier,
)
return response["Snapshot"] if response["Snapshot"] else None

@ -0,0 +1,973 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Interact with AWS S3, using the boto3 library."""
import fnmatch
import gzip as gz
import io
import re
import shutil
from functools import wraps
from inspect import signature
from io import BytesIO
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast
from urllib.parse import urlparse
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.helpers import chunks
from boto3.s3.transfer import S3Transfer, TransferConfig
from botocore.exceptions import ClientError
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def provide_bucket_name(func: T) -> T:
"""
Function decorator that provides a bucket name taken from the connection
in case no bucket name has been passed to the function.
"""
function_signature = signature(func)
@wraps(func)
def wrapper(*args, **kwargs) -> T:
bound_args = function_signature.bind(*args, **kwargs)
if "bucket_name" not in bound_args.arguments:
self = args[0]
if self.aws_conn_id:
connection = self.get_connection(self.aws_conn_id)
if connection.schema:
bound_args.arguments["bucket_name"] = connection.schema
return func(*bound_args.args, **bound_args.kwargs)
return cast(T, wrapper)
def unify_bucket_name_and_key(func: T) -> T:
"""
Function decorator that unifies bucket name and key taken from the key
in case no bucket name and at least a key has been passed to the function.
"""
function_signature = signature(func)
@wraps(func)
def wrapper(*args, **kwargs) -> T:
bound_args = function_signature.bind(*args, **kwargs)
def get_key_name() -> Optional[str]:
if "wildcard_key" in bound_args.arguments:
return "wildcard_key"
if "key" in bound_args.arguments:
return "key"
raise ValueError("Missing key parameter!")
key_name = get_key_name()
if key_name and "bucket_name" not in bound_args.arguments:
(
bound_args.arguments["bucket_name"],
bound_args.arguments[key_name],
) = S3Hook.parse_s3_url(bound_args.arguments[key_name])
return func(*bound_args.args, **bound_args.kwargs)
return cast(T, wrapper)
class S3Hook(AwsBaseHook):
"""
Interact with AWS S3, using the boto3 library.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
conn_type = "s3"
hook_name = "S3"
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "s3"
self.extra_args = {}
if "extra_args" in kwargs:
self.extra_args = kwargs["extra_args"]
if not isinstance(self.extra_args, dict):
raise ValueError(
f"extra_args '{self.extra_args!r}' must be of type {dict}"
)
del kwargs["extra_args"]
self.transfer_config = TransferConfig()
if "transfer_config_args" in kwargs:
transport_config_args = kwargs["transfer_config_args"]
if not isinstance(transport_config_args, dict):
raise ValueError(
f"transfer_config_args '{transport_config_args!r} must be of type {dict}"
)
self.transfer_config = TransferConfig(**transport_config_args)
del kwargs["transfer_config_args"]
super().__init__(*args, **kwargs)
@staticmethod
def parse_s3_url(s3url: str) -> Tuple[str, str]:
"""
Parses the S3 Url into a bucket name and key.
:param s3url: The S3 Url to parse.
:type s3url: str
:return: the parsed bucket name and key
:rtype: tuple of str
"""
parsed_url = urlparse(s3url)
if not parsed_url.netloc:
raise AirflowException(f'Please provide a bucket_name instead of "{s3url}"')
bucket_name = parsed_url.netloc
key = parsed_url.path.strip("/")
return bucket_name, key
@provide_bucket_name
def check_for_bucket(self, bucket_name: Optional[str] = None) -> bool:
"""
Check if bucket_name exists.
:param bucket_name: the name of the bucket
:type bucket_name: str
:return: True if it exists and False if not.
:rtype: bool
"""
try:
self.get_conn().head_bucket(Bucket=bucket_name)
return True
except ClientError as e:
self.log.error(e.response["Error"]["Message"])
return False
@provide_bucket_name
def get_bucket(self, bucket_name: Optional[str] = None) -> str:
"""
Returns a boto3.S3.Bucket object
:param bucket_name: the name of the bucket
:type bucket_name: str
:return: the bucket object to the bucket name.
:rtype: boto3.S3.Bucket
"""
s3_resource = self.get_resource_type("s3")
return s3_resource.Bucket(bucket_name)
@provide_bucket_name
def create_bucket(
self, bucket_name: Optional[str] = None, region_name: Optional[str] = None
) -> None:
"""
Creates an Amazon S3 bucket.
:param bucket_name: The name of the bucket
:type bucket_name: str
:param region_name: The name of the aws region in which to create the bucket.
:type region_name: str
"""
if not region_name:
region_name = self.get_conn().meta.region_name
if region_name == "us-east-1":
self.get_conn().create_bucket(Bucket=bucket_name)
else:
self.get_conn().create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": region_name},
)
@provide_bucket_name
def check_for_prefix(
self, prefix: str, delimiter: str, bucket_name: Optional[str] = None
) -> bool:
"""
Checks that a prefix exists in a bucket
:param bucket_name: the name of the bucket
:type bucket_name: str
:param prefix: a key prefix
:type prefix: str
:param delimiter: the delimiter marks key hierarchy.
:type delimiter: str
:return: False if the prefix does not exist in the bucket and True if it does.
:rtype: bool
"""
prefix = prefix + delimiter if prefix[-1] != delimiter else prefix
prefix_split = re.split(fr"(\w+[{delimiter}])$", prefix, 1)
previous_level = prefix_split[0]
plist = self.list_prefixes(bucket_name, previous_level, delimiter)
return prefix in plist
@provide_bucket_name
def list_prefixes(
self,
bucket_name: Optional[str] = None,
prefix: Optional[str] = None,
delimiter: Optional[str] = None,
page_size: Optional[int] = None,
max_items: Optional[int] = None,
) -> list:
"""
Lists prefixes in a bucket under prefix
:param bucket_name: the name of the bucket
:type bucket_name: str
:param prefix: a key prefix
:type prefix: str
:param delimiter: the delimiter marks key hierarchy.
:type delimiter: str
:param page_size: pagination size
:type page_size: int
:param max_items: maximum items to return
:type max_items: int
:return: a list of matched prefixes
:rtype: list
"""
prefix = prefix or ""
delimiter = delimiter or ""
config = {
"PageSize": page_size,
"MaxItems": max_items,
}
paginator = self.get_conn().get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name,
Prefix=prefix,
Delimiter=delimiter,
PaginationConfig=config,
)
prefixes = []
for page in response:
if "CommonPrefixes" in page:
for common_prefix in page["CommonPrefixes"]:
prefixes.append(common_prefix["Prefix"])
return prefixes
@provide_bucket_name
def list_keys(
self,
bucket_name: Optional[str] = None,
prefix: Optional[str] = None,
delimiter: Optional[str] = None,
page_size: Optional[int] = None,
max_items: Optional[int] = None,
) -> list:
"""
Lists keys in a bucket under prefix and not containing delimiter
:param bucket_name: the name of the bucket
:type bucket_name: str
:param prefix: a key prefix
:type prefix: str
:param delimiter: the delimiter marks key hierarchy.
:type delimiter: str
:param page_size: pagination size
:type page_size: int
:param max_items: maximum items to return
:type max_items: int
:return: a list of matched keys
:rtype: list
"""
prefix = prefix or ""
delimiter = delimiter or ""
config = {
"PageSize": page_size,
"MaxItems": max_items,
}
paginator = self.get_conn().get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name,
Prefix=prefix,
Delimiter=delimiter,
PaginationConfig=config,
)
keys = []
for page in response:
if "Contents" in page:
for k in page["Contents"]:
keys.append(k["Key"])
return keys
@provide_bucket_name
@unify_bucket_name_and_key
def check_for_key(self, key: str, bucket_name: Optional[str] = None) -> bool:
"""
Checks if a key exists in a bucket
:param key: S3 key that will point to the file
:type key: str
:param bucket_name: Name of the bucket in which the file is stored
:type bucket_name: str
:return: True if the key exists and False if not.
:rtype: bool
"""
try:
self.get_conn().head_object(Bucket=bucket_name, Key=key)
return True
except ClientError as e:
if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
return False
else:
raise e
@provide_bucket_name
@unify_bucket_name_and_key
def get_key(self, key: str, bucket_name: Optional[str] = None) -> S3Transfer:
"""
Returns a boto3.s3.Object
:param key: the path to the key
:type key: str
:param bucket_name: the name of the bucket
:type bucket_name: str
:return: the key object from the bucket
:rtype: boto3.s3.Object
"""
obj = self.get_resource_type("s3").Object(bucket_name, key)
obj.load()
return obj
@provide_bucket_name
@unify_bucket_name_and_key
def read_key(self, key: str, bucket_name: Optional[str] = None) -> str:
"""
Reads a key from S3
:param key: S3 key that will point to the file
:type key: str
:param bucket_name: Name of the bucket in which the file is stored
:type bucket_name: str
:return: the content of the key
:rtype: str
"""
obj = self.get_key(key, bucket_name)
return obj.get()["Body"].read().decode("utf-8")
@provide_bucket_name
@unify_bucket_name_and_key
def select_key(
self,
key: str,
bucket_name: Optional[str] = None,
expression: Optional[str] = None,
expression_type: Optional[str] = None,
input_serialization: Optional[Dict[str, Any]] = None,
output_serialization: Optional[Dict[str, Any]] = None,
) -> str:
"""
Reads a key with S3 Select.
:param key: S3 key that will point to the file
:type key: str
:param bucket_name: Name of the bucket in which the file is stored
:type bucket_name: str
:param expression: S3 Select expression
:type expression: str
:param expression_type: S3 Select expression type
:type expression_type: str
:param input_serialization: S3 Select input data serialization format
:type input_serialization: dict
:param output_serialization: S3 Select output data serialization format
:type output_serialization: dict
:return: retrieved subset of original data by S3 Select
:rtype: str
.. seealso::
For more details about S3 Select parameters:
http://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.select_object_content
"""
expression = expression or "SELECT * FROM S3Object"
expression_type = expression_type or "SQL"
if input_serialization is None:
input_serialization = {"CSV": {}}
if output_serialization is None:
output_serialization = {"CSV": {}}
response = self.get_conn().select_object_content(
Bucket=bucket_name,
Key=key,
Expression=expression,
ExpressionType=expression_type,
InputSerialization=input_serialization,
OutputSerialization=output_serialization,
)
return "".join(
event["Records"]["Payload"].decode("utf-8")
for event in response["Payload"]
if "Records" in event
)
@provide_bucket_name
@unify_bucket_name_and_key
def check_for_wildcard_key(
self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = ""
) -> bool:
"""
Checks that a key matching a wildcard expression exists in a bucket
:param wildcard_key: the path to the key
:type wildcard_key: str
:param bucket_name: the name of the bucket
:type bucket_name: str
:param delimiter: the delimiter marks key hierarchy
:type delimiter: str
:return: True if a key exists and False if not.
:rtype: bool
"""
return (
self.get_wildcard_key(
wildcard_key=wildcard_key, bucket_name=bucket_name, delimiter=delimiter
)
is not None
)
@provide_bucket_name
@unify_bucket_name_and_key
def get_wildcard_key(
self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = ""
) -> S3Transfer:
"""
Returns a boto3.s3.Object object matching the wildcard expression
:param wildcard_key: the path to the key
:type wildcard_key: str
:param bucket_name: the name of the bucket
:type bucket_name: str
:param delimiter: the delimiter marks key hierarchy
:type delimiter: str
:return: the key object from the bucket or None if none has been found.
:rtype: boto3.s3.Object
"""
prefix = re.split(r"[*]", wildcard_key, 1)[0]
key_list = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter)
key_matches = [k for k in key_list if fnmatch.fnmatch(k, wildcard_key)]
if key_matches:
return self.get_key(key_matches[0], bucket_name)
return None
@provide_bucket_name
@unify_bucket_name_and_key
def load_file(
self,
filename: str,
key: str,
bucket_name: Optional[str] = None,
replace: bool = False,
encrypt: bool = False,
gzip: bool = False,
acl_policy: Optional[str] = None,
) -> None:
"""
Loads a local file to S3
:param filename: name of the file to load.
:type filename: str
:param key: S3 key that will point to the file
:type key: str
:param bucket_name: Name of the bucket in which to store the file
:type bucket_name: str
:param replace: A flag to decide whether or not to overwrite the key
if it already exists. If replace is False and the key exists, an
error will be raised.
:type replace: bool
:param encrypt: If True, the file will be encrypted on the server-side
by S3 and will be stored in an encrypted form while at rest in S3.
:type encrypt: bool
:param gzip: If True, the file will be compressed locally
:type gzip: bool
:param acl_policy: String specifying the canned ACL policy for the file being
uploaded to the S3 bucket.
:type acl_policy: str
"""
if not replace and self.check_for_key(key, bucket_name):
raise ValueError(f"The key {key} already exists.")
extra_args = self.extra_args.copy()  # copy so per-call options do not mutate the hook's shared extra_args
if encrypt:
extra_args["ServerSideEncryption"] = "AES256"
if gzip:
with open(filename, "rb") as f_in:
filename_gz = f_in.name + ".gz"
with gz.open(filename_gz, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
filename = filename_gz
if acl_policy:
extra_args["ACL"] = acl_policy
client = self.get_conn()
client.upload_file(
filename,
bucket_name,
key,
ExtraArgs=extra_args,
Config=self.transfer_config,
)
@provide_bucket_name
@unify_bucket_name_and_key
def load_string(
self,
string_data: str,
key: str,
bucket_name: Optional[str] = None,
replace: bool = False,
encrypt: bool = False,
encoding: Optional[str] = None,
acl_policy: Optional[str] = None,
compression: Optional[str] = None,
) -> None:
"""
Loads a string to S3
This is provided as a convenience to drop a string in S3. It uses the
boto infrastructure to ship a file to s3.
:param string_data: str to set as content for the key.
:type string_data: str
:param key: S3 key that will point to the file
:type key: str
:param bucket_name: Name of the bucket in which to store the file
:type bucket_name: str
:param replace: A flag to decide whether or not to overwrite the key
if it already exists
:type replace: bool
:param encrypt: If True, the file will be encrypted on the server-side
by S3 and will be stored in an encrypted form while at rest in S3.
:type encrypt: bool
:param encoding: The encoding to use when converting the string to bytes
:type encoding: str
:param acl_policy: The string to specify the canned ACL policy for the
object to be uploaded
:type acl_policy: str
:param compression: Type of compression to use, currently only gzip is supported.
:type compression: str
"""
encoding = encoding or "utf-8"
bytes_data = string_data.encode(encoding)
# Compress string
available_compressions = ["gzip"]
if compression is not None and compression not in available_compressions:
raise NotImplementedError(
"Received {} compression type. String "
"can currently be compressed in {} "
"only.".format(compression, available_compressions)
)
if compression == "gzip":
bytes_data = gz.compress(bytes_data)
file_obj = io.BytesIO(bytes_data)
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)
file_obj.close()
@provide_bucket_name
@unify_bucket_name_and_key
def load_bytes(
self,
bytes_data: bytes,
key: str,
bucket_name: Optional[str] = None,
replace: bool = False,
encrypt: bool = False,
acl_policy: Optional[str] = None,
) -> None:
"""
Loads bytes to S3
This is provided as a convenience to drop bytes in S3. It uses the
boto infrastructure to ship a file to s3.
:param bytes_data: bytes to set as content for the key.
:type bytes_data: bytes
:param key: S3 key that will point to the file
:type key: str
:param bucket_name: Name of the bucket in which to store the file
:type bucket_name: str
:param replace: A flag to decide whether or not to overwrite the key
if it already exists
:type replace: bool
:param encrypt: If True, the file will be encrypted on the server-side
by S3 and will be stored in an encrypted form while at rest in S3.
:type encrypt: bool
:param acl_policy: The string to specify the canned ACL policy for the
object to be uploaded
:type acl_policy: str
"""
file_obj = io.BytesIO(bytes_data)
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)
file_obj.close()
@provide_bucket_name
@unify_bucket_name_and_key
def load_file_obj(
self,
file_obj: BytesIO,
key: str,
bucket_name: Optional[str] = None,
replace: bool = False,
encrypt: bool = False,
acl_policy: Optional[str] = None,
) -> None:
"""
Loads a file object to S3
:param file_obj: The file-like object to set as the content for the S3 key.
:type file_obj: file-like object
:param key: S3 key that will point to the file
:type key: str
:param bucket_name: Name of the bucket in which to store the file
:type bucket_name: str
:param replace: A flag that indicates whether to overwrite the key
if it already exists.
:type replace: bool
:param encrypt: If True, S3 encrypts the file on the server,
and the file is stored in encrypted form at rest in S3.
:type encrypt: bool
:param acl_policy: The string to specify the canned ACL policy for the
object to be uploaded
:type acl_policy: str
"""
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)
def _upload_file_obj(
self,
file_obj: BytesIO,
key: str,
bucket_name: Optional[str] = None,
replace: bool = False,
encrypt: bool = False,
acl_policy: Optional[str] = None,
) -> None:
if not replace and self.check_for_key(key, bucket_name):
raise ValueError(f"The key {key} already exists.")
extra_args = self.extra_args.copy()  # copy so per-call options do not mutate the hook's shared extra_args
if encrypt:
extra_args["ServerSideEncryption"] = "AES256"
if acl_policy:
extra_args["ACL"] = acl_policy
client = self.get_conn()
client.upload_fileobj(
file_obj,
bucket_name,
key,
ExtraArgs=extra_args,
Config=self.transfer_config,
)
def copy_object(
self,
source_bucket_key: str,
dest_bucket_key: str,
source_bucket_name: Optional[str] = None,
dest_bucket_name: Optional[str] = None,
source_version_id: Optional[str] = None,
acl_policy: Optional[str] = None,
) -> None:
"""
Creates a copy of an object that is already stored in S3.
Note: the S3 connection used here needs to have access to both
source and destination bucket/key.
:param source_bucket_key: The key of the source object.
It can be either full s3:// style url or relative path from root level.
When it's specified as a full s3:// url, please omit source_bucket_name.
:type source_bucket_key: str
:param dest_bucket_key: The key of the object to copy to.
The convention to specify `dest_bucket_key` is the same
as `source_bucket_key`.
:type dest_bucket_key: str
:param source_bucket_name: Name of the S3 bucket where the source object is in.
It should be omitted when `source_bucket_key` is provided as a full s3:// url.
:type source_bucket_name: str
:param dest_bucket_name: Name of the S3 bucket to where the object is copied.
It should be omitted when `dest_bucket_key` is provided as a full s3:// url.
:type dest_bucket_name: str
:param source_version_id: Version ID of the source object (OPTIONAL)
:type source_version_id: str
:param acl_policy: The string to specify the canned ACL policy for the
object to be copied which is private by default.
:type acl_policy: str
"""
acl_policy = acl_policy or "private"
if dest_bucket_name is None:
dest_bucket_name, dest_bucket_key = self.parse_s3_url(dest_bucket_key)
else:
parsed_url = urlparse(dest_bucket_key)
if parsed_url.scheme != "" or parsed_url.netloc != "":
raise AirflowException(
"If dest_bucket_name is provided, "
+ "dest_bucket_key should be relative path "
+ "from root level, rather than a full s3:// url"
)
if source_bucket_name is None:
source_bucket_name, source_bucket_key = self.parse_s3_url(source_bucket_key)
else:
parsed_url = urlparse(source_bucket_key)
if parsed_url.scheme != "" or parsed_url.netloc != "":
raise AirflowException(
"If source_bucket_name is provided, "
+ "source_bucket_key should be relative path "
+ "from root level, rather than a full s3:// url"
)
copy_source = {
"Bucket": source_bucket_name,
"Key": source_bucket_key,
"VersionId": source_version_id,
}
response = self.get_conn().copy_object(
Bucket=dest_bucket_name,
Key=dest_bucket_key,
CopySource=copy_source,
ACL=acl_policy,
)
return response
@provide_bucket_name
def delete_bucket(self, bucket_name: str, force_delete: bool = False) -> None:
"""
To delete an S3 bucket, delete all of the bucket's objects and then delete the bucket itself.
:param bucket_name: Bucket name
:type bucket_name: str
:param force_delete: Enable this to delete bucket even if not empty
:type force_delete: bool
:return: None
:rtype: None
"""
if force_delete:
bucket_keys = self.list_keys(bucket_name=bucket_name)
if bucket_keys:
self.delete_objects(bucket=bucket_name, keys=bucket_keys)
self.conn.delete_bucket(Bucket=bucket_name)
def delete_objects(self, bucket: str, keys: Union[str, list]) -> None:
"""
Delete keys from the bucket.
:param bucket: Name of the bucket in which you are going to delete object(s)
:type bucket: str
:param keys: The key(s) to delete from S3 bucket.
When ``keys`` is a string, it's supposed to be the key name of
the single object to delete.
When ``keys`` is a list, it's supposed to be the list of the
keys to delete.
:type keys: str or list
"""
if isinstance(keys, str):
keys = [keys]
s3 = self.get_conn()
# We can only send a maximum of 1000 keys per request.
# For details see:
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects
for chunk in chunks(keys, chunk_size=1000):
response = s3.delete_objects(
Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}
)
deleted_keys = [x["Key"] for x in response.get("Deleted", [])]
self.log.info("Deleted: %s", deleted_keys)
if "Errors" in response:
errors_keys = [x["Key"] for x in response.get("Errors", [])]
raise AirflowException(f"Errors when deleting: {errors_keys}")
@provide_bucket_name
@unify_bucket_name_and_key
def download_file(
self,
key: str,
bucket_name: Optional[str] = None,
local_path: Optional[str] = None,
) -> str:
"""
Downloads a file from the S3 location to the local file system.
:param key: The key path in S3.
:type key: str
:param bucket_name: The specific bucket to use.
:type bucket_name: Optional[str]
:param local_path: The local path to the downloaded file. If no path is provided it will use the
system's temporary directory.
:type local_path: Optional[str]
:return: the file name.
:rtype: str
"""
self.log.info(
"Downloading source S3 file from Bucket %s with path %s", bucket_name, key
)
if not self.check_for_key(key, bucket_name):
raise AirflowException(
f"The source file in Bucket {bucket_name} with path {key} does not exist"
)
s3_obj = self.get_key(key, bucket_name)
with NamedTemporaryFile(
dir=local_path, prefix="airflow_tmp_", delete=False
) as local_tmp_file:
s3_obj.download_fileobj(local_tmp_file)
return local_tmp_file.name
def generate_presigned_url(
self,
client_method: str,
params: Optional[dict] = None,
expires_in: int = 3600,
http_method: Optional[str] = None,
) -> Optional[str]:
"""
Generate a presigned url given a client, its method, and arguments
:param client_method: The client method to presign for.
:type client_method: str
:param params: The parameters normally passed to ClientMethod.
:type params: dict
:param expires_in: The number of seconds the presigned url is valid for.
By default it expires in an hour (3600 seconds).
:type expires_in: int
:param http_method: The http method to use on the generated url.
By default, the http method is whatever is used in the method's model.
:type http_method: str
:return: The presigned url.
:rtype: str
"""
s3_client = self.get_conn()
try:
return s3_client.generate_presigned_url(
ClientMethod=client_method,
Params=params,
ExpiresIn=expires_in,
HttpMethod=http_method,
)
except ClientError as e:
self.log.error(e.response["Error"]["Message"])
return None
@provide_bucket_name
def get_bucket_tagging(
self, bucket_name: Optional[str] = None
) -> Optional[List[Dict[str, str]]]:
"""
Gets a List of tags from a bucket.
:param bucket_name: The name of the bucket.
:type bucket_name: str
:return: A List containing the key/value pairs for the tags
:rtype: Optional[List[Dict[str, str]]]
"""
try:
s3_client = self.get_conn()
result = s3_client.get_bucket_tagging(Bucket=bucket_name)["TagSet"]
self.log.info("S3 Bucket Tag Info: %s", result)
return result
except ClientError as e:
self.log.error(e)
raise e
@provide_bucket_name
def put_bucket_tagging(
self,
tag_set: Optional[List[Dict[str, str]]] = None,
key: Optional[str] = None,
value: Optional[str] = None,
bucket_name: Optional[str] = None,
) -> None:
"""
Overwrites the existing TagSet with provided tags. Must provide either a TagSet or a key/value pair.
:param tag_set: A List containing the key/value pairs for the tags.
:type tag_set: List[Dict[str, str]]
:param key: The Key for the new TagSet entry.
:type key: str
:param value: The Value for the new TagSet entry.
:type value: str
:param bucket_name: The name of the bucket.
:type bucket_name: str
:return: None
:rtype: None
"""
self.log.info(
"S3 Bucket Tag Info:\tKey: %s\tValue: %s\tSet: %s", key, value, tag_set
)
if not tag_set:
tag_set = []
if key and value:
tag_set.append({"Key": key, "Value": value})
elif not tag_set or (key or value):
message = "put_bucket_tagging() requires either a predefined TagSet or a key/value pair."
self.log.error(message)
raise ValueError(message)
try:
s3_client = self.get_conn()
s3_client.put_bucket_tagging(
Bucket=bucket_name, Tagging={"TagSet": tag_set}
)
except ClientError as e:
self.log.error(e)
raise e
@provide_bucket_name
def delete_bucket_tagging(self, bucket_name: Optional[str] = None) -> None:
"""
Deletes all tags from a bucket.
:param bucket_name: The name of the bucket.
:type bucket_name: str
:return: None
:rtype: None
"""
s3_client = self.get_conn()
s3_client.delete_bucket_tagging(Bucket=bucket_name)
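# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the hook file above): a minimal,
# combined example of several of the S3Hook methods defined above. The import
# path assumes the installed apache-airflow-providers-amazon package (the same
# layout this file mirrors); the "aws_default" connection id, bucket name, and
# key are placeholders.
# ---------------------------------------------------------------------------
from airflow.providers.amazon.aws.hooks.s3 import S3Hook


def s3_round_trip_example() -> str:
    hook = S3Hook(aws_conn_id="aws_default")
    # Upload a small CSV string, overwriting any existing object at that key.
    hook.load_string(
        string_data="id,name\n1,example\n",
        key="examples/data.csv",
        bucket_name="my-example-bucket",
        replace=True,
    )
    # List keys under the prefix and read the object back as text.
    keys = hook.list_keys(bucket_name="my-example-bucket", prefix="examples/")
    print("keys under prefix:", keys)
    content = hook.read_key(key="examples/data.csv", bucket_name="my-example-bucket")
    # Generate a presigned GET URL that is valid for 10 minutes.
    url = hook.generate_presigned_url(
        client_method="get_object",
        params={"Bucket": "my-example-bucket", "Key": "examples/data.csv"},
        expires_in=600,
    )
    print("presigned url:", url)
    return content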


@@ -0,0 +1,71 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import base64
import json
from typing import Union
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class SecretsManagerHook(AwsBaseHook):
"""
Interact with Amazon SecretsManager Service.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, *args, **kwargs):
super().__init__(client_type="secretsmanager", *args, **kwargs)
def get_secret(self, secret_name: str) -> Union[str, bytes]:
"""
Retrieve a secret value from AWS Secrets Manager as a str or bytes,
reflecting the format in which it is stored in AWS Secrets Manager
:param secret_name: name of the secrets.
:type secret_name: str
:return: Union[str, bytes] with the information about the secrets
:rtype: Union[str, bytes]
"""
# Depending on whether the secret is a string or binary, one of
# these fields will be populated.
get_secret_value_response = self.get_conn().get_secret_value(
SecretId=secret_name
)
if "SecretString" in get_secret_value_response:
secret = get_secret_value_response["SecretString"]
else:
secret = base64.b64decode(get_secret_value_response["SecretBinary"])
return secret
def get_secret_as_dict(self, secret_name: str) -> dict:
"""
Retrieve secret value from AWS Secrets Manager in a dict representation
:param secret_name: name of the secrets.
:type secret_name: str
:return: dict with the information about the secrets
:rtype: dict
"""
return json.loads(self.get_secret(secret_name))
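# Usage sketch (illustrative, not part of the hook file above): fetching a secret
# with SecretsManagerHook. The import path assumes the installed
# apache-airflow-providers-amazon package; the connection id and secret name are
# placeholders.
from airflow.providers.amazon.aws.hooks.secrets_manager import SecretsManagerHook


def read_database_credentials() -> dict:
    hook = SecretsManagerHook(aws_conn_id="aws_default")
    # get_secret() returns str or bytes depending on how the secret was stored;
    # get_secret_as_dict() additionally json-decodes a string secret.
    raw = hook.get_secret("example/db-credentials")
    print("raw secret type:", type(raw))
    return hook.get_secret_as_dict("example/db-credentials")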

@@ -0,0 +1,99 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains AWS SES Hook"""
from typing import Any, Dict, Iterable, List, Optional, Union
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.email import build_mime_message
class SESHook(AwsBaseHook):
"""
Interact with Amazon Simple Email Service.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "ses"
super().__init__(*args, **kwargs)
def send_email( # pylint: disable=too-many-arguments
self,
mail_from: str,
to: Union[str, Iterable[str]],
subject: str,
html_content: str,
files: Optional[List[str]] = None,
cc: Optional[Union[str, Iterable[str]]] = None,
bcc: Optional[Union[str, Iterable[str]]] = None,
mime_subtype: str = "mixed",
mime_charset: str = "utf-8",
reply_to: Optional[str] = None,
return_path: Optional[str] = None,
custom_headers: Optional[Dict[str, Any]] = None,
) -> dict:
"""
Send email using Amazon Simple Email Service
:param mail_from: Email address to set as email's from
:param to: List of email addresses to set as email's to
:param subject: Email's subject
:param html_content: Content of email in HTML format
:param files: List of paths of files to be attached
:param cc: List of email addresses to set as email's CC
:param bcc: List of email addresses to set as email's BCC
:param mime_subtype: Can be used to specify the sub-type of the message. Default = mixed
:param mime_charset: Email's charset. Default = UTF-8.
:param reply_to: The email address to which replies will be sent. By default, replies
are sent to the original sender's email address.
:param return_path: The email address to which message bounces and complaints should be sent.
"Return-Path" is sometimes called "envelope from," "envelope sender," or "MAIL FROM."
:param custom_headers: Additional headers to add to the MIME message.
No validations are run on these values and they should be able to be encoded.
:return: Response from Amazon SES service with unique message identifier.
"""
ses_client = self.get_conn()
custom_headers = custom_headers or {}
if reply_to:
custom_headers["Reply-To"] = reply_to
if return_path:
custom_headers["Return-Path"] = return_path
message, recipients = build_mime_message(
mail_from=mail_from,
to=to,
subject=subject,
html_content=html_content,
files=files,
cc=cc,
bcc=bcc,
mime_subtype=mime_subtype,
mime_charset=mime_charset,
custom_headers=custom_headers,
)
return ses_client.send_raw_email(
Source=mail_from,
Destinations=recipients,
RawMessage={"Data": message.as_string()},
)
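# Usage sketch (illustrative, not part of the hook file above): sending a simple
# HTML email through SESHook. The import path assumes the installed
# apache-airflow-providers-amazon package; the addresses and connection id are
# placeholders, and the sender address must be verified in SES.
from airflow.providers.amazon.aws.hooks.ses import SESHook


def send_report_email() -> dict:
    hook = SESHook(aws_conn_id="aws_default")
    return hook.send_email(
        mail_from="reports@example.com",
        to=["team@example.com"],
        subject="Daily report",
        html_content="<h1>Report</h1><p>All pipelines succeeded.</p>",
        reply_to="noreply@example.com",
    )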

@@ -0,0 +1,96 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains AWS SNS hook"""
import json
from typing import Dict, Optional, Union
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
def _get_message_attribute(o):
if isinstance(o, bytes):
return {"DataType": "Binary", "BinaryValue": o}
if isinstance(o, str):
return {"DataType": "String", "StringValue": o}
if isinstance(o, (int, float)):
return {"DataType": "Number", "StringValue": str(o)}
if hasattr(o, "__iter__"):
return {"DataType": "String.Array", "StringValue": json.dumps(o)}
raise TypeError(
"Values in MessageAttributes must be one of bytes, str, int, float, or iterable; "
f"got {type(o)}"
)
class AwsSnsHook(AwsBaseHook):
"""
Interact with Amazon Simple Notification Service.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, *args, **kwargs):
super().__init__(client_type="sns", *args, **kwargs)
def publish_to_target(
self,
target_arn: str,
message: str,
subject: Optional[str] = None,
message_attributes: Optional[dict] = None,
):
"""
Publish a message to a topic or an endpoint.
:param target_arn: either a TopicArn or an EndpointArn
:type target_arn: str
:param message: the default message you want to send
:type message: str
:param subject: subject of message
:type subject: str
:param message_attributes: additional attributes to publish for message filtering. This should be
a flat dict; the DataType to be sent depends on the type of the value:
- bytes = Binary
- str = String
- int, float = Number
- iterable = String.Array
:type message_attributes: dict
"""
publish_kwargs: Dict[str, Union[str, dict]] = {
"TargetArn": target_arn,
"MessageStructure": "json",
"Message": json.dumps({"default": message}),
}
# Construct args this way because boto3 distinguishes between missing args and those set to None
if subject:
publish_kwargs["Subject"] = subject
if message_attributes:
publish_kwargs["MessageAttributes"] = {
key: _get_message_attribute(val)
for key, val in message_attributes.items()
}
return self.get_conn().publish(**publish_kwargs)
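# Usage sketch (illustrative, not part of the hook file above): publishing a
# message to an SNS topic with AwsSnsHook. The import path assumes the installed
# apache-airflow-providers-amazon package; the topic ARN, connection id, and
# attribute values are placeholders.
from airflow.providers.amazon.aws.hooks.sns import AwsSnsHook


def notify_pipeline_finished() -> dict:
    hook = AwsSnsHook(aws_conn_id="aws_default")
    return hook.publish_to_target(
        target_arn="arn:aws:sns:us-east-1:123456789012:pipeline-events",
        message="etl pipeline finished",
        subject="Pipeline status",
        # Attribute DataTypes are derived from the Python types
        # (str -> String, int/float -> Number, bytes -> Binary, iterable -> String.Array).
        message_attributes={"dag_id": "etl", "attempt": 1},
    )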

@@ -0,0 +1,87 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains AWS SQS hook"""
from typing import Dict, Optional
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class SQSHook(AwsBaseHook):
"""
Interact with Amazon Simple Queue Service.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "sqs"
super().__init__(*args, **kwargs)
def create_queue(self, queue_name: str, attributes: Optional[Dict] = None) -> Dict:
"""
Create queue using connection object
:param queue_name: name of the queue.
:type queue_name: str
:param attributes: additional attributes for the queue (default: None)
For details of the attributes parameter see :py:meth:`SQS.create_queue`
:type attributes: dict
:return: dict with the information about the queue
For details of the returned value see :py:meth:`SQS.create_queue`
:rtype: dict
"""
return self.get_conn().create_queue(
QueueName=queue_name, Attributes=attributes or {}
)
def send_message(
self,
queue_url: str,
message_body: str,
delay_seconds: int = 0,
message_attributes: Optional[Dict] = None,
) -> Dict:
"""
Send message to the queue
:param queue_url: queue url
:type queue_url: str
:param message_body: the contents of the message
:type message_body: str
:param delay_seconds: seconds to delay the message
:type delay_seconds: int
:param message_attributes: additional attributes for the message (default: None)
For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
:type message_attributes: dict
:return: dict with the information about the message sent
For details of the returned value see :py:meth:`botocore.client.SQS.send_message`
:rtype: dict
"""
return self.get_conn().send_message(
QueueUrl=queue_url,
MessageBody=message_body,
DelaySeconds=delay_seconds,
MessageAttributes=message_attributes or {},
)
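# Usage sketch (illustrative, not part of the hook file above): creating a queue
# and sending a message with SQSHook. The import path assumes the installed
# apache-airflow-providers-amazon package; the queue name, connection id, and
# message contents are placeholders.
from airflow.providers.amazon.aws.hooks.sqs import SQSHook


def enqueue_task_event() -> dict:
    hook = SQSHook(aws_conn_id="aws_default")
    # create_queue() returns the boto3 response, which includes the queue URL.
    queue = hook.create_queue("example-events")
    return hook.send_message(
        queue_url=queue["QueueUrl"],
        message_body='{"event": "task_done"}',
        delay_seconds=0,
        message_attributes={
            "source": {"DataType": "String", "StringValue": "airflow"}
        },
    )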

@@ -0,0 +1,82 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Optional, Union
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class StepFunctionHook(AwsBaseHook):
"""
Interact with an AWS Step Functions State Machine.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
def __init__(self, region_name: Optional[str] = None, *args, **kwargs) -> None:
kwargs["region_name"] = region_name  # forward the region to AwsBaseHook rather than silently dropping it
kwargs["client_type"] = "stepfunctions"
super().__init__(*args, **kwargs)
def start_execution(
self,
state_machine_arn: str,
name: Optional[str] = None,
state_machine_input: Union[dict, str, None] = None,
) -> str:
"""
Start Execution of the State Machine.
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution
:param state_machine_arn: AWS Step Function State Machine ARN
:type state_machine_arn: str
:param name: The name of the execution.
:type name: Optional[str]
:param state_machine_input: JSON data input to pass to the State Machine
:type state_machine_input: Union[Dict[str, any], str, None]
:return: Execution ARN
:rtype: str
"""
execution_args = {"stateMachineArn": state_machine_arn}
if name is not None:
execution_args["name"] = name
if state_machine_input is not None:
if isinstance(state_machine_input, str):
execution_args["input"] = state_machine_input
elif isinstance(state_machine_input, dict):
execution_args["input"] = json.dumps(state_machine_input)
self.log.info("Executing Step Function State Machine: %s", state_machine_arn)
response = self.conn.start_execution(**execution_args)
return response.get("executionArn")
def describe_execution(self, execution_arn: str) -> dict:
"""
Describes a State Machine Execution
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.describe_execution
:param execution_arn: ARN of the State Machine Execution
:type execution_arn: str
:return: Dict with Execution details
:rtype: dict
"""
return self.get_conn().describe_execution(executionArn=execution_arn)
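# Usage sketch (illustrative, not part of the hook file above): starting and
# inspecting a Step Functions execution with StepFunctionHook. The import path
# assumes the installed apache-airflow-providers-amazon package; the state
# machine ARN and connection id are placeholders.
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook


def run_state_machine() -> dict:
    hook = StepFunctionHook(aws_conn_id="aws_default")
    execution_arn = hook.start_execution(
        state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:example",
        state_machine_input={"date": "2021-01-01"},
    )
    # Check the execution status once; a sensor would normally handle the waiting.
    return hook.describe_execution(execution_arn)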

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
