commit
88398a1491
1031 changed files with 149779 additions and 0 deletions
@@ -0,0 +1,2 @@
||||
.DS_Store |
||||
*/.DS_Store |
||||
@@ -0,0 +1,46 @@
||||
FROM apache/airflow:2.0.2-python3.8 |
||||
USER root |
||||
RUN apt-get update \ |
||||
&& apt-get install -y --no-install-recommends \ |
||||
git \ |
||||
pylint \ |
||||
libpq-dev \ |
||||
python-dev \ |
||||
# gcc \ |
||||
&& apt-get remove python-cffi \ |
||||
&& apt-get autoremove -yqq --purge \ |
||||
&& apt-get clean \ |
||||
&& rm -rf /var/lib/apt/lists/* |
||||
|
||||
RUN /usr/local/bin/python -m pip install --upgrade pip |
||||
|
||||
# COPY --from=docker /usr/local/bin/docker /usr/local/bin/docker |
||||
RUN curl -sSL https://get.docker.com/ | sh |
||||
# COPY requirements.txt ./ |
||||
# RUN pip install --no-cache-dir -r requirements.txt |
||||
|
||||
# ENV PYTHONPATH "${PYTHONPATH}:/your/custom/path" |
||||
ENV AIRFLOW_UID=1000 |
||||
ENV AIRFLOW_GID=1000 |
||||
|
||||
RUN echo airflow ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/airflow |
||||
RUN chmod 0440 /etc/sudoers.d/airflow |
||||
RUN echo 'airflow:airflow' | chpasswd |
||||
RUN groupmod --gid ${AIRFLOW_GID} airflow |
||||
RUN groupmod --gid 998 docker |
||||
# adduser --quiet "airflow" --uid "${AIRFLOW_UID}" \ |
||||
# --gid "${AIRFLOW_GID}" \ |
||||
# --home "${AIRFLOW_USER_HOME_DIR}" |
||||
RUN usermod --gid "${AIRFLOW_GID}" --uid "${AIRFLOW_UID}" airflow |
||||
RUN usermod -aG docker airflow |
||||
RUN chown -R ${AIRFLOW_UID}:${AIRFLOW_GID} /home/airflow |
||||
RUN chown -R ${AIRFLOW_UID}:${AIRFLOW_GID} /opt/airflow |
||||
|
||||
WORKDIR /opt/airflow |
||||
|
||||
# RUN usermod -g ${AIRFLOW_GID} airflow -G airflow |
||||
|
||||
USER ${AIRFLOW_UID} |
||||
|
||||
COPY requirements.txt ./ |
||||
RUN pip install --no-cache-dir -r requirements.txt |
||||
File diff suppressed because it is too large
@@ -0,0 +1,11 @@
||||
STATE_COLORS = { |
||||
"queued": "darkgray", |
||||
"running": "blue", |
||||
"success": "cobalt", |
||||
"failed": "firebrick", |
||||
"up_for_retry": "yellow", |
||||
"up_for_reschedule": "turquoise", |
||||
"upstream_failed": "orange", |
||||
"skipped": "darkorchid", |
||||
"scheduled": "tan", |
||||
} |
||||
@@ -0,0 +1,5 @@
||||
from copy import deepcopy |
||||
|
||||
from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG |
||||
|
||||
LOGGING_CONFIG = deepcopy(DEFAULT_LOGGING_CONFIG) |
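# A minimal customization sketch, showing why the deepcopy matters: keys can be
# tweaked on the copy without mutating Airflow's shared default dict. The key path
# below mirrors DEFAULT_LOGGING_CONFIG; the DEBUG level is purely illustrative.
LOGGING_CONFIG["loggers"]["airflow.task"]["level"] = "DEBUG"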
||||
@@ -0,0 +1,23 @@
||||
from datetime import datetime |
||||
|
||||
from airflow import DAG |
||||
from airflow.hooks.base_hook import BaseHook |
||||
from airflow.operators.python_operator import PythonOperator |
||||
|
||||
|
||||
def get_secrets(**kwargs): |
||||
conn = BaseHook.get_connection(kwargs["my_conn_id"]) |
||||
print( |
||||
f"Password: {conn.password}, Login: {conn.login}, URI: {conn.get_uri()}, Host: {conn.host}" |
||||
) |
||||
|
||||
|
||||
with DAG( |
||||
"example_secrets_dags", start_date=datetime(2020, 1, 1), schedule_interval=None |
||||
) as dag: |
||||
|
||||
test_task = PythonOperator( |
||||
task_id="test-task", |
||||
python_callable=get_secrets, |
||||
op_kwargs={"my_conn_id": "smtp_default"}, |
||||
) |
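# Note: airflow.hooks.base_hook and airflow.operators.python_operator are Airflow 1.x
# module paths that 2.0.2 only keeps as deprecation shims. A sketch of the equivalent
# imports under the Airflow 2 layout (assuming Airflow >= 2.0):
from airflow.hooks.base import BaseHook
from airflow.operators.python import PythonOperator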
||||
@@ -0,0 +1,130 @@
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
"""Default configuration for the Airflow webserver""" |
||||
import os |
||||
|
||||
from flask_appbuilder.security.manager import AUTH_DB |
||||
|
||||
# from flask_appbuilder.security.manager import AUTH_LDAP |
||||
# from flask_appbuilder.security.manager import AUTH_OAUTH |
||||
# from flask_appbuilder.security.manager import AUTH_OID |
||||
# from flask_appbuilder.security.manager import AUTH_REMOTE_USER |
||||
|
||||
|
||||
basedir = os.path.abspath(os.path.dirname(__file__)) |
||||
|
||||
# Flask-WTF flag for CSRF |
||||
WTF_CSRF_ENABLED = True |
||||
|
||||
# ---------------------------------------------------- |
||||
# AUTHENTICATION CONFIG |
||||
# ---------------------------------------------------- |
||||
# For details on how to set up each of the following authentication methods, see
# http://flask-appbuilder.readthedocs.io/en/latest/security.html#authentication-methods
||||
|
||||
# The authentication type |
||||
# AUTH_OID : Is for OpenID |
||||
# AUTH_DB : Is for database |
||||
# AUTH_LDAP : Is for LDAP |
||||
# AUTH_REMOTE_USER : Is for using REMOTE_USER from web server |
||||
# AUTH_OAUTH : Is for OAuth |
||||
AUTH_TYPE = AUTH_DB |
||||
|
||||
# Uncomment to setup Full admin role name |
||||
# AUTH_ROLE_ADMIN = 'Admin' |
||||
|
||||
# Uncomment to setup Public role name, no authentication needed |
||||
# AUTH_ROLE_PUBLIC = 'Public' |
||||
|
||||
# Will allow user self registration |
||||
# AUTH_USER_REGISTRATION = True |
||||
|
||||
# Recaptcha is automatically enabled when user self-registration is active and the keys are provided
||||
# RECAPTCHA_PRIVATE_KEY = PRIVATE_KEY |
||||
# RECAPTCHA_PUBLIC_KEY = PUBLIC_KEY |
||||
|
||||
# Config for Flask-Mail necessary for user self registration |
||||
# MAIL_SERVER = 'smtp.gmail.com' |
||||
# MAIL_USE_TLS = True |
||||
# MAIL_USERNAME = 'yourappemail@gmail.com' |
||||
# MAIL_PASSWORD = 'passwordformail' |
||||
# MAIL_DEFAULT_SENDER = 'sender@gmail.com' |
||||
|
||||
# The default user self registration role |
||||
# AUTH_USER_REGISTRATION_ROLE = "Public" |
||||
|
||||
# When using OAuth Auth, uncomment to setup provider(s) info |
||||
# Google OAuth example: |
||||
# OAUTH_PROVIDERS = [{ |
||||
# 'name':'google', |
||||
# 'token_key':'access_token', |
||||
# 'icon':'fa-google', |
||||
# 'remote_app': { |
||||
# 'api_base_url':'https://www.googleapis.com/oauth2/v2/', |
||||
# 'client_kwargs':{ |
||||
# 'scope': 'email profile' |
||||
# }, |
||||
# 'access_token_url':'https://accounts.google.com/o/oauth2/token', |
||||
# 'authorize_url':'https://accounts.google.com/o/oauth2/auth', |
||||
# 'request_token_url': None, |
||||
# 'client_id': GOOGLE_KEY, |
||||
# 'client_secret': GOOGLE_SECRET_KEY, |
||||
# } |
||||
# }] |
||||
|
||||
# When using LDAP Auth, setup the ldap server |
||||
# AUTH_LDAP_SERVER = "ldap://ldapserver.new" |
||||
|
||||
# When using OpenID Auth, uncomment to setup OpenID providers. |
||||
# example for OpenID authentication |
||||
# OPENID_PROVIDERS = [ |
||||
# { 'name': 'Yahoo', 'url': 'https://me.yahoo.com' }, |
||||
# { 'name': 'AOL', 'url': 'http://openid.aol.com/<username>' }, |
||||
# { 'name': 'Flickr', 'url': 'http://www.flickr.com/<username>' }, |
||||
# { 'name': 'MyOpenID', 'url': 'https://www.myopenid.com' }] |
||||
|
||||
# ---------------------------------------------------- |
||||
# Theme CONFIG |
||||
# ---------------------------------------------------- |
||||
# Flask App Builder comes up with a number of predefined themes |
||||
# that you can use for Apache Airflow. |
||||
# http://flask-appbuilder.readthedocs.io/en/latest/customizing.html#changing-themes |
||||
# Please make sure to remove "navbar_color" configuration from airflow.cfg |
||||
# in order to fully utilize the theme. (or use that property in conjunction with theme) |
||||
# APP_THEME = "bootstrap-theme.css" # default bootstrap |
||||
# APP_THEME = "amelia.css" |
||||
# APP_THEME = "cerulean.css" |
||||
APP_THEME = "cosmo.css" |
||||
# APP_THEME = "cyborg.css" |
||||
# APP_THEME = "darkly.css" |
||||
# APP_THEME = "flatly.css" |
||||
# APP_THEME = "journal.css" |
||||
# APP_THEME = "lumen.css" |
||||
# APP_THEME = "paper.css" |
||||
# APP_THEME = "readable.css" |
||||
# APP_THEME = "sandstone.css" |
||||
# APP_THEME = "simplex.css" |
||||
# APP_THEME = "slate.css" |
||||
# APP_THEME = "solar.css" |
||||
# APP_THEME = "spacelab.css" |
||||
# APP_THEME = "superhero.css" |
||||
# APP_THEME = "united.css" |
||||
# APP_THEME = "yeti.css" |
||||
|
||||
# FAB_STATIC_FOLDER = |
||||
@@ -0,0 +1,264 @@
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
""" |
||||
Example Airflow DAG that shows the complex DAG structure. |
||||
""" |
||||
|
||||
from airflow import models |
||||
from airflow.models.baseoperator import chain |
||||
from airflow.operators.bash import BashOperator |
||||
from airflow.operators.python import PythonOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
with models.DAG( |
||||
dag_id="example_complex", |
||||
schedule_interval=None, |
||||
start_date=days_ago(1), |
||||
tags=["example", "example2", "example3"], |
||||
) as dag: |
||||
|
||||
# Create |
||||
create_entry_group = BashOperator( |
||||
task_id="create_entry_group", bash_command="echo create_entry_group" |
||||
) |
||||
|
||||
create_entry_group_result = BashOperator( |
||||
task_id="create_entry_group_result", |
||||
bash_command="echo create_entry_group_result", |
||||
) |
||||
|
||||
create_entry_group_result2 = BashOperator( |
||||
task_id="create_entry_group_result2", |
||||
bash_command="echo create_entry_group_result2", |
||||
) |
||||
|
||||
create_entry_gcs = BashOperator( |
||||
task_id="create_entry_gcs", bash_command="echo create_entry_gcs" |
||||
) |
||||
|
||||
create_entry_gcs_result = BashOperator( |
||||
task_id="create_entry_gcs_result", bash_command="echo create_entry_gcs_result" |
||||
) |
||||
|
||||
create_entry_gcs_result2 = BashOperator( |
||||
task_id="create_entry_gcs_result2", bash_command="echo create_entry_gcs_result2" |
||||
) |
||||
|
||||
create_tag = BashOperator(task_id="create_tag", bash_command="echo create_tag") |
||||
|
||||
create_tag_result = BashOperator( |
||||
task_id="create_tag_result", bash_command="echo create_tag_result" |
||||
) |
||||
|
||||
create_tag_result2 = BashOperator( |
||||
task_id="create_tag_result2", bash_command="echo create_tag_result2" |
||||
) |
||||
|
||||
create_tag_template = BashOperator( |
||||
task_id="create_tag_template", bash_command="echo create_tag_template" |
||||
) |
||||
|
||||
create_tag_template_result = BashOperator( |
||||
task_id="create_tag_template_result", |
||||
bash_command="echo create_tag_template_result", |
||||
) |
||||
|
||||
create_tag_template_result2 = BashOperator( |
||||
task_id="create_tag_template_result2", |
||||
bash_command="echo create_tag_template_result2", |
||||
) |
||||
|
||||
create_tag_template_field = BashOperator( |
||||
task_id="create_tag_template_field", |
||||
bash_command="echo create_tag_template_field", |
||||
) |
||||
|
||||
create_tag_template_field_result = BashOperator( |
||||
task_id="create_tag_template_field_result", |
||||
bash_command="echo create_tag_template_field_result", |
||||
) |
||||
|
||||
create_tag_template_field_result2 = BashOperator( |
||||
task_id="create_tag_template_field_result2", |
||||
bash_command="echo create_tag_template_field_result", |
||||
) |
||||
|
||||
# Delete |
||||
delete_entry = BashOperator( |
||||
task_id="delete_entry", bash_command="echo delete_entry" |
||||
) |
||||
create_entry_gcs >> delete_entry |
||||
|
||||
delete_entry_group = BashOperator( |
||||
task_id="delete_entry_group", bash_command="echo delete_entry_group" |
||||
) |
||||
create_entry_group >> delete_entry_group |
||||
|
||||
delete_tag = BashOperator(task_id="delete_tag", bash_command="echo delete_tag") |
||||
create_tag >> delete_tag |
||||
|
||||
delete_tag_template_field = BashOperator( |
||||
task_id="delete_tag_template_field", |
||||
bash_command="echo delete_tag_template_field", |
||||
) |
||||
|
||||
delete_tag_template = BashOperator( |
||||
task_id="delete_tag_template", bash_command="echo delete_tag_template" |
||||
) |
||||
|
||||
# Get |
||||
get_entry_group = BashOperator( |
||||
task_id="get_entry_group", bash_command="echo get_entry_group" |
||||
) |
||||
|
||||
get_entry_group_result = BashOperator( |
||||
task_id="get_entry_group_result", bash_command="echo get_entry_group_result" |
||||
) |
||||
|
||||
get_entry = BashOperator(task_id="get_entry", bash_command="echo get_entry") |
||||
|
||||
get_entry_result = BashOperator( |
||||
task_id="get_entry_result", bash_command="echo get_entry_result" |
||||
) |
||||
|
||||
get_tag_template = BashOperator( |
||||
task_id="get_tag_template", bash_command="echo get_tag_template" |
||||
) |
||||
|
||||
get_tag_template_result = BashOperator( |
||||
task_id="get_tag_template_result", bash_command="echo get_tag_template_result" |
||||
) |
||||
|
||||
# List |
||||
list_tags = BashOperator(task_id="list_tags", bash_command="echo list_tags") |
||||
|
||||
list_tags_result = BashOperator( |
||||
task_id="list_tags_result", bash_command="echo list_tags_result" |
||||
) |
||||
|
||||
# Lookup |
||||
lookup_entry = BashOperator( |
||||
task_id="lookup_entry", bash_command="echo lookup_entry" |
||||
) |
||||
|
||||
lookup_entry_result = BashOperator( |
||||
task_id="lookup_entry_result", bash_command="echo lookup_entry_result" |
||||
) |
||||
|
||||
# Rename |
||||
rename_tag_template_field = BashOperator( |
||||
task_id="rename_tag_template_field", |
||||
bash_command="echo rename_tag_template_field", |
||||
) |
||||
|
||||
# Search |
||||
search_catalog = PythonOperator( |
||||
task_id="search_catalog", python_callable=lambda: print("search_catalog") |
||||
) |
||||
|
||||
search_catalog_result = BashOperator( |
||||
task_id="search_catalog_result", bash_command="echo search_catalog_result" |
||||
) |
||||
|
||||
# Update |
||||
update_entry = BashOperator( |
||||
task_id="update_entry", bash_command="echo update_entry" |
||||
) |
||||
|
||||
update_tag = BashOperator(task_id="update_tag", bash_command="echo update_tag") |
||||
|
||||
update_tag_template = BashOperator( |
||||
task_id="update_tag_template", bash_command="echo update_tag_template" |
||||
) |
||||
|
||||
update_tag_template_field = BashOperator( |
||||
task_id="update_tag_template_field", |
||||
bash_command="echo update_tag_template_field", |
||||
) |
||||
|
||||
# Create |
||||
create_tasks = [ |
||||
create_entry_group, |
||||
create_entry_gcs, |
||||
create_tag_template, |
||||
create_tag_template_field, |
||||
create_tag, |
||||
] |
||||
chain(*create_tasks) |
||||
|
||||
create_entry_group >> delete_entry_group |
||||
create_entry_group >> create_entry_group_result |
||||
create_entry_group >> create_entry_group_result2 |
||||
|
||||
create_entry_gcs >> delete_entry |
||||
create_entry_gcs >> create_entry_gcs_result |
||||
create_entry_gcs >> create_entry_gcs_result2 |
||||
|
||||
create_tag_template >> delete_tag_template_field |
||||
create_tag_template >> create_tag_template_result |
||||
create_tag_template >> create_tag_template_result2 |
||||
|
||||
create_tag_template_field >> delete_tag_template_field |
||||
create_tag_template_field >> create_tag_template_field_result |
||||
create_tag_template_field >> create_tag_template_field_result2 |
||||
|
||||
create_tag >> delete_tag |
||||
create_tag >> create_tag_result |
||||
create_tag >> create_tag_result2 |
||||
|
||||
# Delete |
||||
delete_tasks = [ |
||||
delete_tag, |
||||
delete_tag_template_field, |
||||
delete_tag_template, |
||||
delete_entry_group, |
||||
delete_entry, |
||||
] |
||||
chain(*delete_tasks) |
||||
|
||||
# Get |
||||
create_tag_template >> get_tag_template >> delete_tag_template |
||||
get_tag_template >> get_tag_template_result |
||||
|
||||
create_entry_gcs >> get_entry >> delete_entry |
||||
get_entry >> get_entry_result |
||||
|
||||
create_entry_group >> get_entry_group >> delete_entry_group |
||||
get_entry_group >> get_entry_group_result |
||||
|
||||
# List |
||||
create_tag >> list_tags >> delete_tag |
||||
list_tags >> list_tags_result |
||||
|
||||
# Lookup |
||||
create_entry_gcs >> lookup_entry >> delete_entry |
||||
lookup_entry >> lookup_entry_result |
||||
|
||||
# Rename |
||||
create_tag_template_field >> rename_tag_template_field >> delete_tag_template_field |
||||
|
||||
# Search |
||||
chain(create_tasks, search_catalog, delete_tasks) |
||||
search_catalog >> search_catalog_result |
||||
|
||||
# Update |
||||
create_entry_gcs >> update_entry >> delete_entry |
||||
create_tag >> update_tag >> delete_tag |
||||
create_tag_template >> update_tag_template >> delete_tag_template |
||||
create_tag_template_field >> update_tag_template_field >> rename_tag_template_field |
||||
@@ -0,0 +1,93 @@
||||
""" |
||||
Code that goes along with the Airflow located at: |
||||
http://airflow.readthedocs.org/en/latest/tutorial.html |
||||
""" |
||||
from datetime import datetime, timedelta |
||||
|
||||
from airflow import DAG |
||||
|
||||
# Operators; we need this to operate! |
||||
from airflow.operators.bash import BashOperator |
||||
from airflow.operators.dummy_operator import DummyOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
default_args = { |
||||
"owner": "donaldrich", |
||||
"depends_on_past": False, |
||||
"start_date": datetime(2016, 7, 13), |
||||
"email": ["email@gmail.com"], |
||||
"email_on_failure": False, |
||||
"email_on_retry": False, |
||||
"retries": 1, |
||||
"retry_delay": timedelta(minutes=15), |
||||
} |
||||
|
||||
script_path = "/data/scripts" |
||||
data_path = "/data/data/archive" |
||||
|
||||
dag = DAG( |
||||
"zip_docker", |
||||
default_args=default_args, |
||||
description="A simple tutorial DAG", |
||||
schedule_interval=None, |
||||
start_date=days_ago(2), |
||||
tags=["zip", "docker"], |
||||
) |
||||
|
||||
with dag: |
||||
|
||||
start = DummyOperator(task_id="start", dag=dag) |
||||
|
||||
pull1 = BashOperator( |
||||
task_id="pull_chrome", |
||||
bash_command="sudo docker pull selenoid/chrome:latest", |
||||
# retries=3, |
||||
# dag=dag |
||||
) |
||||
|
||||
pull2 = BashOperator( |
||||
task_id="pull_recorder", |
||||
bash_command="sudo docker pull selenoid/video-recorder:latest-release", |
||||
# retries=3, |
||||
# dag=dag |
||||
) |
||||
|
||||
scrape = BashOperator( |
||||
task_id="scrape_listings", |
||||
bash_command="sh /data/scripts/scrape.sh -s target -k docker", |
||||
# retries=3, |
||||
# dag=dag |
||||
) |
||||
|
||||
cleanup = BashOperator( |
||||
task_id="cleanup", |
||||
bash_command="sh /data/scripts/post-scrape.sh -s target -k devops", |
||||
# retries=3, |
||||
# dag=dag |
||||
) |
||||
|
||||
end = DummyOperator(task_id="end", dag=dag) |
||||
|
||||
start >> [pull1, pull2] >> scrape >> cleanup >> end |
||||
|
||||
# scrape = BashOperator( |
||||
# task_id="scrape_listings", |
||||
# bash_command="python3 " + script_path + '/gather/zip.py -k "devops"', |
||||
# # retries=3, |
||||
# # dag=dag |
||||
# ) |
||||
|
||||
# cleanup = BashOperator( |
||||
# task_id="cleanup", |
||||
# bash_command=script_path + "/post-scrape.sh -s zip -k devops", |
||||
# # retries=3, |
||||
# # dag=dag |
||||
# ) |
||||
|
||||
# init >> [pre1,pre2] |
||||
|
||||
# init >> pre2 |
||||
|
||||
# pre2 >> scrape |
||||
|
||||
# scrape >> cleanup |
||||
@@ -0,0 +1,125 @@
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
""" |
||||
### Tutorial Documentation |
||||
Documentation that goes along with the Airflow tutorial located |
||||
[here](https://airflow.apache.org/tutorial.html) |
||||
""" |
||||
# [START tutorial] |
||||
# [START import_module] |
||||
from datetime import timedelta |
||||
from textwrap import dedent |
||||
|
||||
# The DAG object; we'll need this to instantiate a DAG |
||||
from airflow import DAG |
||||
|
||||
# Operators; we need this to operate! |
||||
from airflow.operators.bash import BashOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
# [END import_module] |
||||
|
||||
# [START default_args] |
||||
# These args will get passed on to each operator |
||||
# You can override them on a per-task basis during operator initialization |
||||
default_args = { |
||||
"owner": "airflow", |
||||
"depends_on_past": False, |
||||
"email": ["airflow@example.com"], |
||||
"email_on_failure": False, |
||||
"email_on_retry": False, |
||||
"retries": 1, |
||||
"retry_delay": timedelta(minutes=5), |
||||
# 'queue': 'bash_queue', |
||||
# 'pool': 'backfill', |
||||
# 'priority_weight': 10, |
||||
# 'end_date': datetime(2016, 1, 1), |
||||
# 'wait_for_downstream': False, |
||||
# 'dag': dag, |
||||
# 'sla': timedelta(hours=2), |
||||
# 'execution_timeout': timedelta(seconds=300), |
||||
# 'on_failure_callback': some_function, |
||||
# 'on_success_callback': some_other_function, |
||||
# 'on_retry_callback': another_function, |
||||
# 'sla_miss_callback': yet_another_function, |
||||
# 'trigger_rule': 'all_success' |
||||
} |
||||
# [END default_args] |
||||
|
||||
# [START instantiate_dag] |
||||
with DAG( |
||||
"debug", |
||||
default_args=default_args, |
||||
description="A simple tutorial DAG", |
||||
schedule_interval=None, |
||||
start_date=days_ago(2), |
||||
tags=["example"], |
||||
) as dag: |
||||
# [END instantiate_dag] |
||||
|
||||
cleanup = BashOperator( |
||||
task_id="cleanup", |
||||
bash_command="sudo sh /data/scripts/post-scrape.sh -s target -k devops", |
||||
) |
||||
|
||||
# t1, t2 and t3 are examples of tasks created by instantiating operators |
||||
# [START basic_task] |
||||
# t1 = BashOperator( |
||||
# task_id="print_date", |
||||
# bash_command="date", |
||||
# ) |
||||
|
||||
# t2 = BashOperator( |
||||
# task_id="sleep", |
||||
# depends_on_past=False, |
||||
# bash_command="sleep 5", |
||||
# retries=3, |
||||
# ) |
||||
# # [END basic_task] |
||||
|
||||
# # [START documentation] |
||||
# dag.doc_md = __doc__ |
||||
|
||||
cleanup.doc_md = dedent(""" |
||||
#### Task Documentation |
||||
You can document your task using the attributes `doc_md` (markdown), |
||||
`doc` (plain text), `doc_rst`, `doc_json`, `doc_yaml` which gets |
||||
rendered in the UI's Task Instance Details page. |
||||
 |
||||
""") |
||||
# [END documentation] |
||||
|
||||
# [START jinja_template] |
||||
templated_command = dedent(""" |
||||
{% for i in range(5) %} |
||||
echo "{{ ds }}" |
||||
echo "{{ macros.ds_add(ds, 7)}}" |
||||
echo "{{ params.my_param }}" |
||||
{% endfor %} |
||||
""") |
||||
|
||||
t3 = BashOperator( |
||||
task_id="templated", |
||||
depends_on_past=False, |
||||
bash_command=templated_command, |
||||
params={"my_param": "Parameter I passed in"}, |
||||
) |
||||
# [END jinja_template] |
||||
|
||||
cleanup |
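    # Sketch of the "override them on a per-task basis" note from default_args:
    # a hypothetical task that overrides the DAG-level retry settings at
    # construction time (task_id, command, and values are illustrative).
    flaky = BashOperator(
        task_id="flaky_step",
        bash_command="exit 0",
        retries=5,                         # overrides the default of 1
        retry_delay=timedelta(minutes=1),  # overrides the default of 5 minutes
    )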
||||
# [END tutorial] |
||||
@@ -0,0 +1,101 @@
||||
import random |
||||
from time import sleep |
||||
|
||||
from selenium import webdriver |
||||
from selenium.webdriver import ActionChains |
||||
from selenium.webdriver.common.by import By |
||||
from selenium.webdriver.support import expected_conditions as EC |
||||
from selenium.webdriver.support.wait import WebDriverWait |
||||
|
||||
# INSTANTIATE DRIVER
||||
|
||||
|
||||
def configure_driver(settings, user_agent): |
||||
|
||||
options = webdriver.ChromeOptions() |
||||
|
||||
# Options that try to convince the site that we are a normal browser.
||||
options.add_experimental_option("excludeSwitches", ["enable-automation"]) |
||||
options.add_experimental_option("useAutomationExtension", False) |
||||
options.add_experimental_option( |
||||
"prefs", |
||||
{ |
||||
"download.default_directory": settings.files_path, |
||||
"download.prompt_for_download": False, |
||||
"download.directory_upgrade": True, |
||||
"safebrowsing_for_trusted_sources_enabled": False, |
||||
"safebrowsing.enabled": False, |
||||
}, |
||||
) |
||||
options.add_argument("--incognito") |
||||
options.add_argument("start-maximized") |
||||
options.add_argument("user-agent={user_agent}") |
||||
options.add_argument("disable-blink-features") |
||||
options.add_argument("disable-blink-features=AutomationControlled") |
||||
|
||||
driver = webdriver.Chrome( |
||||
options=options, executable_path=settings.chromedriver_path |
||||
) |
||||
|
||||
# #Overwrite the webdriver property |
||||
# driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", |
||||
# { |
||||
# "source": """ |
||||
# Object.defineProperty(navigator, 'webdriver', { |
||||
# get: () => undefined |
||||
# }) |
||||
# """ |
||||
# }) |
||||
|
||||
driver.execute_cdp_cmd("Network.enable", {}) |
||||
|
||||
# Overwrite the User-Agent header |
||||
driver.execute_cdp_cmd( |
||||
"Network.setExtraHTTPHeaders", {"headers": {"User-Agent": user_agent}} |
||||
) |
||||
|
||||
driver.command_executor._commands["send_command"] = ( |
||||
"POST", |
||||
"/session/$sessionId/chromium/send_command", |
||||
) |
||||
|
||||
return driver |
||||
|
||||
|
||||
# SOLVE THE CAPTCHA
||||
def solve_wait_recaptcha(driver): |
||||
|
||||
###### Move to reCAPTCHA Iframe |
||||
WebDriverWait(driver, 5).until( |
||||
EC.frame_to_be_available_and_switch_to_it( |
||||
( |
||||
By.CSS_SELECTOR, |
||||
"iframe[src^='https://www.google.com/recaptcha/api2/anchor?']", |
||||
) |
||||
) |
||||
) |
||||
|
||||
check_selector = "span.recaptcha-checkbox.goog-inline-block.recaptcha-checkbox-unchecked.rc-anchor-checkbox" |
||||
|
||||
captcha_check = driver.find_element_by_css_selector(check_selector) |
||||
|
||||
###### Random delay before hover & click the checkbox |
||||
sleep(random.uniform(3, 6)) |
||||
ActionChains(driver).move_to_element(captcha_check).perform() |
||||
|
||||
###### Hover it |
||||
sleep(random.uniform(0.5, 1)) |
||||
ActionChains(driver).move_to_element(captcha_check).perform()
||||
|
||||
###### Random delay before click the checkbox |
||||
sleep(random.uniform(0.5, 1)) |
||||
driver.execute_script("arguments[0].click()", captcha_check) |
||||
|
||||
###### Wait for recaptcha to be in solved state |
||||
elem = None |
||||
while elem is None: |
||||
try: |
||||
elem = driver.find_element_by_class_name("recaptcha-checkbox-checked") |
||||
except Exception:  # keep polling until the checkbox reaches the checked state
||||
pass |
||||
sleep(5) |
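# Usage sketch for configure_driver(): it only reads `files_path` and
# `chromedriver_path` from its settings argument, so any object with those two
# attributes works; the paths and user agent below are illustrative.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_settings = SimpleNamespace(
        files_path="/tmp/downloads",
        chromedriver_path="/usr/local/bin/chromedriver",
    )
    demo_driver = configure_driver(demo_settings, user_agent="Mozilla/5.0 (X11; Linux x86_64)")
    demo_driver.quit()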
||||
@@ -0,0 +1,27 @@
||||
import mechanicalsoup |
||||
|
||||
browser = mechanicalsoup.StatefulBrowser() |
||||
browser.open("http://httpbin.org/") |
||||
|
||||
print(browser.url) |
||||
browser.follow_link("forms") |
||||
print(browser.url) |
||||
print(browser.page) |
||||
|
||||
browser.select_form('form[action="/post"]') |
||||
browser["custname"] = "Me" |
||||
browser["custtel"] = "00 00 0001" |
||||
browser["custemail"] = "nobody@example.com" |
||||
browser["size"] = "medium" |
||||
browser["topping"] = "onion" |
||||
browser["topping"] = ("bacon", "cheese") |
||||
browser["comments"] = "This pizza looks really good :-)" |
||||
|
||||
# Uncomment to launch a real web browser on the current page. |
||||
# browser.launch_browser() |
||||
|
||||
# Display a summary of the filled-in form
||||
browser.form.print_summary() |
||||
|
||||
response = browser.submit_selected() |
||||
print(response.text) |
||||
@@ -0,0 +1,270 @@
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
"""Airflow logging settings""" |
||||
|
||||
import os |
||||
from pathlib import Path |
||||
from typing import Any, Dict, Union |
||||
from urllib.parse import urlparse |
||||
|
||||
from airflow.configuration import conf |
||||
from airflow.exceptions import AirflowException |
||||
|
||||
# TODO: Logging format and level should be configured |
||||
# in this file instead of from airflow.cfg. Currently |
||||
# there are other log format and level configurations in |
||||
# settings.py and cli.py. Please see AIRFLOW-1455. |
||||
LOG_LEVEL: str = conf.get('logging', 'LOGGING_LEVEL').upper() |
||||
|
||||
|
||||
# Flask appbuilder's info level log is very verbose, |
||||
# so it's set to 'WARN' by default. |
||||
FAB_LOG_LEVEL: str = conf.get('logging', 'FAB_LOGGING_LEVEL').upper() |
||||
|
||||
LOG_FORMAT: str = conf.get('logging', 'LOG_FORMAT') |
||||
|
||||
COLORED_LOG_FORMAT: str = conf.get('logging', 'COLORED_LOG_FORMAT') |
||||
|
||||
COLORED_LOG: bool = conf.getboolean('logging', 'COLORED_CONSOLE_LOG') |
||||
|
||||
COLORED_FORMATTER_CLASS: str = conf.get('logging', 'COLORED_FORMATTER_CLASS') |
||||
|
||||
BASE_LOG_FOLDER: str = conf.get('logging', 'BASE_LOG_FOLDER') |
||||
|
||||
PROCESSOR_LOG_FOLDER: str = conf.get('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY') |
||||
|
||||
DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get('logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION') |
||||
|
||||
FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_FILENAME_TEMPLATE') |
||||
|
||||
PROCESSOR_FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_PROCESSOR_FILENAME_TEMPLATE') |
||||
|
||||
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { |
||||
'version': 1, |
||||
'disable_existing_loggers': False, |
||||
'formatters': { |
||||
'airflow': {'format': LOG_FORMAT}, |
||||
'airflow_coloured': { |
||||
'format': COLORED_LOG_FORMAT if COLORED_LOG else LOG_FORMAT, |
||||
'class': COLORED_FORMATTER_CLASS if COLORED_LOG else 'logging.Formatter', |
||||
}, |
||||
}, |
||||
'handlers': { |
||||
'console': { |
||||
'class': 'airflow.utils.log.logging_mixin.RedirectStdHandler', |
||||
'formatter': 'airflow_coloured', |
||||
'stream': 'sys.stdout', |
||||
}, |
||||
'task': { |
||||
'class': 'airflow.utils.log.file_task_handler.FileTaskHandler', |
||||
'formatter': 'airflow', |
||||
'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER), |
||||
'filename_template': FILENAME_TEMPLATE, |
||||
}, |
||||
'processor': { |
||||
'class': 'airflow.utils.log.file_processor_handler.FileProcessorHandler', |
||||
'formatter': 'airflow', |
||||
'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER), |
||||
'filename_template': PROCESSOR_FILENAME_TEMPLATE, |
||||
}, |
||||
}, |
||||
'loggers': { |
||||
'airflow.processor': { |
||||
'handlers': ['processor'], |
||||
'level': LOG_LEVEL, |
||||
'propagate': False, |
||||
}, |
||||
'airflow.task': { |
||||
'handlers': ['task'], |
||||
'level': LOG_LEVEL, |
||||
'propagate': False, |
||||
}, |
||||
'flask_appbuilder': { |
||||
'handlers': ['console'],
||||
'level': FAB_LOG_LEVEL, |
||||
'propagate': True, |
||||
}, |
||||
}, |
||||
'root': { |
||||
'handlers': ['console'], |
||||
'level': LOG_LEVEL, |
||||
}, |
||||
} |
||||
|
||||
EXTRA_LOGGER_NAMES: str = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None) |
||||
if EXTRA_LOGGER_NAMES: |
||||
new_loggers = { |
||||
logger_name.strip(): { |
||||
'handlers': ['console'],
||||
'level': LOG_LEVEL, |
||||
'propagate': True, |
||||
} |
||||
for logger_name in EXTRA_LOGGER_NAMES.split(",") |
||||
} |
||||
DEFAULT_LOGGING_CONFIG['loggers'].update(new_loggers) |
||||
|
||||
DEFAULT_DAG_PARSING_LOGGING_CONFIG: Dict[str, Dict[str, Dict[str, Any]]] = { |
||||
'handlers': { |
||||
'processor_manager': { |
||||
'class': 'logging.handlers.RotatingFileHandler', |
||||
'formatter': 'airflow', |
||||
'filename': DAG_PROCESSOR_MANAGER_LOG_LOCATION, |
||||
'mode': 'a', |
||||
'maxBytes': 104857600, # 100MB |
||||
'backupCount': 5, |
||||
} |
||||
}, |
||||
'loggers': { |
||||
'airflow.processor_manager': { |
||||
'handlers': ['processor_manager'], |
||||
'level': LOG_LEVEL, |
||||
'propagate': False, |
||||
} |
||||
}, |
||||
} |
||||
|
||||
# Only update the handlers and loggers when CONFIG_PROCESSOR_MANAGER_LOGGER is set. |
||||
# This is to avoid exceptions when initializing RotatingFileHandler multiple times |
||||
# in multiple processes. |
||||
if os.environ.get('CONFIG_PROCESSOR_MANAGER_LOGGER') == 'True': |
||||
DEFAULT_LOGGING_CONFIG['handlers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers']) |
||||
DEFAULT_LOGGING_CONFIG['loggers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['loggers']) |
||||
|
||||
# Manually create log directory for processor_manager handler as RotatingFileHandler |
||||
# will only create file but not the directory. |
||||
processor_manager_handler_config: Dict[str, Any] = DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'][ |
||||
'processor_manager' |
||||
] |
||||
directory: str = os.path.dirname(processor_manager_handler_config['filename']) |
||||
Path(directory).mkdir(parents=True, exist_ok=True, mode=0o755) |
||||
|
||||
################## |
||||
# Remote logging # |
||||
################## |
||||
|
||||
REMOTE_LOGGING: bool = conf.getboolean('logging', 'remote_logging') |
||||
|
||||
if REMOTE_LOGGING: |
||||
|
||||
ELASTICSEARCH_HOST: str = conf.get('elasticsearch', 'HOST') |
||||
|
||||
# Storage bucket URL for remote logging |
||||
# S3 buckets should start with "s3://" |
||||
# Cloudwatch log groups should start with "cloudwatch://" |
||||
# GCS buckets should start with "gs://" |
||||
# WASB buckets should start with "wasb" |
||||
# just to help Airflow select correct handler |
||||
REMOTE_BASE_LOG_FOLDER: str = conf.get('logging', 'REMOTE_BASE_LOG_FOLDER') |
||||
|
||||
if REMOTE_BASE_LOG_FOLDER.startswith('s3://'): |
||||
S3_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { |
||||
'task': { |
||||
'class': 'airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler', |
||||
'formatter': 'airflow', |
||||
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), |
||||
's3_log_folder': REMOTE_BASE_LOG_FOLDER, |
||||
'filename_template': FILENAME_TEMPLATE, |
||||
}, |
||||
} |
||||
|
||||
DEFAULT_LOGGING_CONFIG['handlers'].update(S3_REMOTE_HANDLERS) |
||||
elif REMOTE_BASE_LOG_FOLDER.startswith('cloudwatch://'): |
||||
CLOUDWATCH_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { |
||||
'task': { |
||||
'class': 'airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler', |
||||
'formatter': 'airflow', |
||||
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), |
||||
'log_group_arn': urlparse(REMOTE_BASE_LOG_FOLDER).netloc, |
||||
'filename_template': FILENAME_TEMPLATE, |
||||
}, |
||||
} |
||||
|
||||
DEFAULT_LOGGING_CONFIG['handlers'].update(CLOUDWATCH_REMOTE_HANDLERS) |
||||
elif REMOTE_BASE_LOG_FOLDER.startswith('gs://'): |
||||
key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None) |
||||
GCS_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { |
||||
'task': { |
||||
'class': 'airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler', |
||||
'formatter': 'airflow', |
||||
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), |
||||
'gcs_log_folder': REMOTE_BASE_LOG_FOLDER, |
||||
'filename_template': FILENAME_TEMPLATE, |
||||
'gcp_key_path': key_path, |
||||
}, |
||||
} |
||||
|
||||
DEFAULT_LOGGING_CONFIG['handlers'].update(GCS_REMOTE_HANDLERS) |
||||
elif REMOTE_BASE_LOG_FOLDER.startswith('wasb'): |
||||
WASB_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = { |
||||
'task': { |
||||
'class': 'airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler', |
||||
'formatter': 'airflow', |
||||
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), |
||||
'wasb_log_folder': REMOTE_BASE_LOG_FOLDER, |
||||
'wasb_container': 'airflow-logs', |
||||
'filename_template': FILENAME_TEMPLATE, |
||||
'delete_local_copy': False, |
||||
}, |
||||
} |
||||
|
||||
DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS) |
||||
elif REMOTE_BASE_LOG_FOLDER.startswith('stackdriver://'): |
||||
key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None) |
||||
# stackdriver:///airflow-tasks => airflow-tasks |
||||
log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:] |
||||
STACKDRIVER_REMOTE_HANDLERS = { |
||||
'task': { |
||||
'class': 'airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler', |
||||
'formatter': 'airflow', |
||||
'name': log_name, |
||||
'gcp_key_path': key_path, |
||||
} |
||||
} |
||||
|
||||
DEFAULT_LOGGING_CONFIG['handlers'].update(STACKDRIVER_REMOTE_HANDLERS) |
||||
elif ELASTICSEARCH_HOST: |
||||
ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get('elasticsearch', 'LOG_ID_TEMPLATE') |
||||
ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get('elasticsearch', 'END_OF_LOG_MARK') |
||||
ELASTICSEARCH_FRONTEND: str = conf.get('elasticsearch', 'frontend') |
||||
ELASTICSEARCH_WRITE_STDOUT: bool = conf.getboolean('elasticsearch', 'WRITE_STDOUT') |
||||
ELASTICSEARCH_JSON_FORMAT: bool = conf.getboolean('elasticsearch', 'JSON_FORMAT') |
||||
ELASTICSEARCH_JSON_FIELDS: str = conf.get('elasticsearch', 'JSON_FIELDS') |
||||
|
||||
ELASTIC_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = { |
||||
'task': { |
||||
'class': 'airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler', |
||||
'formatter': 'airflow', |
||||
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), |
||||
'log_id_template': ELASTICSEARCH_LOG_ID_TEMPLATE, |
||||
'filename_template': FILENAME_TEMPLATE, |
||||
'end_of_log_mark': ELASTICSEARCH_END_OF_LOG_MARK, |
||||
'host': ELASTICSEARCH_HOST, |
||||
'frontend': ELASTICSEARCH_FRONTEND, |
||||
'write_stdout': ELASTICSEARCH_WRITE_STDOUT, |
||||
'json_format': ELASTICSEARCH_JSON_FORMAT, |
||||
'json_fields': ELASTICSEARCH_JSON_FIELDS, |
||||
}, |
||||
} |
||||
|
||||
DEFAULT_LOGGING_CONFIG['handlers'].update(ELASTIC_REMOTE_HANDLERS) |
||||
else: |
||||
raise AirflowException( |
||||
"Incorrect remote log configuration. Please check the configuration of option 'host' in " |
||||
"section 'elasticsearch' if you are using Elasticsearch. In the other case, " |
||||
"'remote_base_log_folder' option in the 'logging' section." |
||||
) |
||||
@@ -0,0 +1,28 @@
||||
from airflow.plugins_manager import AirflowPlugin |
||||
from flask_admin.base import MenuLink |
||||
|
||||
github = MenuLink( |
||||
category="Astronomer", |
||||
name="Airflow Instance Github Repo", |
||||
url="https://github.com/astronomerio/astronomer-dags", |
||||
) |
||||
|
||||
astronomer_home = MenuLink( |
||||
category="Astronomer", name="Astronomer Home", url="https://www.astronomer.io/" |
||||
) |
||||
|
||||
airflow_plugins = MenuLink(
||||
category="Astronomer", |
||||
name="Airflow Plugins", |
||||
url="https://github.com/airflow-plugins", |
||||
) |
||||
|
||||
# Defining the plugin class |
||||
class AirflowTestPlugin(AirflowPlugin): |
||||
name = "AstronomerMenuLinks" |
||||
operators = [] |
||||
flask_blueprints = [] |
||||
hooks = [] |
||||
executors = [] |
||||
admin_views = [] |
||||
menu_links = [github, astronomer_home, airflow_plugins]
||||
@@ -0,0 +1,32 @@
||||
""" |
||||
Singer |
||||
This example shows how to use Singer within Airflow using a custom operator: |
||||
- SingerOperator |
||||
https://github.com/airflow-plugins/singer_plugin/blob/master/operators/singer_operator.py#L5 |
||||
A complete list of Taps and Targets can be found in the Singer.io Github org: |
||||
https://github.com/singer-io |
||||
""" |
||||
|
||||
|
||||
from datetime import datetime |
||||
|
||||
from airflow import DAG |
||||
from airflow.operators import SingerOperator |
||||
|
||||
default_args = { |
||||
"start_date": datetime(2018, 2, 22), |
||||
"retries": 0, |
||||
"email": [], |
||||
"email_on_failure": True, |
||||
"email_on_retry": False, |
||||
} |
||||
|
||||
dag = DAG( |
||||
"__singer__fixerio_to_csv", schedule_interval="@hourly", default_args=default_args |
||||
) |
||||
|
||||
with dag: |
||||
|
||||
singer = SingerOperator(task_id="singer", tap="fixerio", target="csv") |
||||
|
||||
singer |
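# Singer taps and targets compose as a shell pipeline (the tap writes JSON lines to
# stdout, the target reads them from stdin). A hedged sketch of the same fixerio->csv
# pipeline using the stock BashOperator, in case the SingerOperator plugin is not
# installed; the config path is illustrative and assumes the tap-fixerio and
# target-csv packages are available on the worker:
from airflow.operators.bash import BashOperator

singer_via_bash = BashOperator(
    task_id="singer_via_bash",
    bash_command="tap-fixerio --config /path/to/fixerio.json | target-csv",
    dag=dag,
)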
||||
@@ -0,0 +1,181 @@
||||
""" |
||||
Headless Site Navigation and File Download (Using Selenium) to S3 |
||||
This example demonstrates using Selenium (via Firefox/GeckoDriver) to: |
||||
1) Log into a website w/ credentials stored in connection labeled 'selenium_conn_id' |
||||
2) Download a file (initiated on login) |
||||
3) Transform the CSV into JSON formatting |
||||
4) Append the current data to each record |
||||
5) Load the corresponding file into S3 |
||||
To use this DAG, you will need to have the following installed: |
||||
[XVFB](https://www.x.org/archive/X11R7.6/doc/man/man1/Xvfb.1.xhtml) |
||||
[GeckoDriver](https://github.com/mozilla/geckodriver/releases/download) |
||||
selenium==3.11.0 |
||||
xvfbwrapper==0.2.9 |
||||
""" |
||||
import csv |
||||
import datetime |
||||
import json |
||||
import logging |
||||
import os |
||||
import time |
||||
from datetime import datetime, timedelta |
||||
|
||||
import boa |
||||
import requests |
||||
from airflow import DAG |
||||
from airflow.models import Connection |
||||
from airflow.operators.dummy_operator import DummyOperator |
||||
from airflow.operators.python_operator import PythonOperator, PythonVirtualenvOperator |
||||
from airflow.providers.docker.operators.docker import DockerOperator |
||||
from airflow.utils.dates import days_ago |
||||
from airflow.utils.db import provide_session |
||||
from bs4 import BeautifulSoup |
||||
from selenium import webdriver |
||||
from selenium.webdriver.chrome.options import Options |
||||
from selenium.webdriver.chrome.service import Service |
||||
from selenium.webdriver.common.by import By |
||||
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities |
||||
from selenium.webdriver.common.keys import Keys |
||||
from tools.network import ( |
||||
browser_config, |
||||
browser_feature, |
||||
driver_object, |
||||
ip_status, |
||||
proxy_ip, |
||||
selenoid_status, |
||||
uc_test, |
||||
user_agent, |
||||
vpn_settings, |
||||
) |
||||
|
||||
from config.webdriver import browser_capabilities, browser_options |
||||
|
||||
default_args = { |
||||
"start_date": days_ago(2), |
||||
"email": [], |
||||
"email_on_failure": True, |
||||
"email_on_retry": False, |
||||
# 'retries': 2, |
||||
"retry_delay": timedelta(minutes=5), |
||||
"catchup": False, |
||||
} |
||||
|
||||
|
||||
# def hello_world_py(): |
||||
# # selenium_conn_id = kwargs.get('templates_dict', None).get('selenium_conn_id', None) |
||||
# # filename = kwargs.get('templates_dict', None).get('filename', None) |
||||
# # s3_conn_id = kwargs.get('templates_dict', None).get('s3_conn_id', None) |
||||
# # s3_bucket = kwargs.get('templates_dict', None).get('s3_bucket', None) |
||||
# # s3_key = kwargs.get('templates_dict', None).get('s3_key', None) |
||||
# # date = kwargs.get('templates_dict', None).get('date', None) |
||||
# # module_name = kwargs.get('templates_dict', None).get('module', None) |
||||
|
||||
|
||||
# module = "anon_browser_test" |
||||
|
||||
# chrome_options = browser_options() |
||||
# capabilities = browser_capabilities(module) |
||||
# logging.info('Assembling driver') |
||||
# driver = webdriver.Remote( |
||||
# command_executor="http://192.168.1.101:4444/wd/hub", |
||||
# options=chrome_options, |
||||
# desired_capabilities=capabilities, |
||||
# ) |
||||
# logging.info('proxy IP') |
||||
# proxy_ip() |
||||
# logging.info('driver') |
||||
# vpn_settings() |
||||
# logging.info('driver') |
||||
# selenoid_status() |
||||
# logging.info('driver') |
||||
# ip_status(driver) |
||||
# logging.info('driver') |
||||
# browser_config(driver) |
||||
# logging.info('driver') |
||||
# user_agent(driver) |
||||
# logging.info('driver') |
||||
# driver_object(driver) |
||||
# logging.info('driver') |
||||
# browser_feature(driver) |
||||
# logging.info('driver') |
||||
|
||||
# driver.quit() |
||||
# logging.info('driver') |
||||
# uc_test() |
||||
# # print("Finished") |
||||
# return 'Whatever you return gets printed in the logs' |
||||
|
||||
# DAG object needed by the PythonVirtualenvOperator task defined further below.
dag = DAG(
    'anon_browser_test',
    schedule_interval='@daily',
    default_args=default_args,
    catchup=False,
)
||||
|
||||
# dummy_operator = DummyOperator(task_id="dummy_task", retries=3, dag=dag) |
||||
|
||||
# selenium = PythonOperator( |
||||
# task_id='anon_browser_test', |
||||
# python_callable=hello_world_py, |
||||
# templates_dict={"module": "anon_browser_test"}, |
||||
# dag=dag |
||||
# # "s3_bucket": S3_BUCKET, |
||||
# # "s3_key": S3_KEY, |
||||
# # "date": date} |
||||
# # provide_context=True |
||||
# ) |
||||
|
||||
# t1 = DockerOperator( |
||||
# # api_version='1.19', |
||||
# # docker_url='tcp://localhost:2375', # Set your docker URL |
||||
# command='/bin/sleep 30', |
||||
# image='selenoid/chrome:latest', |
||||
# # network_mode='bridge', |
||||
# task_id='chrome', |
||||
# dag=dag, |
||||
# ) |
||||
|
||||
# t2 = DockerOperator( |
||||
# # api_version='1.19', |
||||
# # docker_url='tcp://localhost:2375', # Set your docker URL |
||||
# command='/bin/sleep 30', |
||||
# image='selenoid/video-recorder:latest-release', |
||||
# # network_mode='bridge', |
||||
# task_id='video_recorder', |
||||
# dag=dag, |
||||
# ) |
||||
|
||||
# [START howto_operator_python_venv] |
||||
def callable_virtualenv(): |
||||
""" |
||||
Example function that will be performed in a virtual environment. |
||||
Importing inside the function (rather than at the module level) ensures that the
library is not imported before it is installed in the virtual environment.
||||
""" |
||||
from time import sleep |
||||
|
||||
from colorama import Back, Fore, Style |
||||
|
||||
print(Fore.RED + "some red text") |
||||
print(Back.GREEN + "and with a green background") |
||||
print(Style.DIM + "and in dim text") |
||||
print(Style.RESET_ALL) |
||||
for _ in range(10): |
||||
print(Style.DIM + "Please wait...", flush=True) |
||||
sleep(10) |
||||
print("Finished") |
||||
|
||||
|
||||
virtualenv_task = PythonVirtualenvOperator( |
||||
task_id="virtualenv_python", |
||||
python_callable=callable_virtualenv, |
||||
requirements=["colorama==0.4.0"], |
||||
system_site_packages=False, |
||||
dag=dag, |
||||
) |
||||
# [END howto_operator_python_venv] |
||||
|
||||
# selenium >> dummy_operator |
||||
# dummy_operator >> virtualenv_task |
||||
# t1 >> selenium |
||||
# t2 >> selenium |
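# Steps 3-4 from the docstring (transform the downloaded CSV to JSON and append the
# current date) are not implemented by the active code above; a minimal stdlib-only
# sketch, with illustrative file paths:
def csv_to_dated_json(csv_path="/tmp/export.csv", json_path="/tmp/export.json"):
    records = []
    with open(csv_path, newline="") as source:
        for row in csv.DictReader(source):
            row["loaded_at"] = datetime.now().strftime("%Y-%m-%d")
            records.append(row)
    with open(json_path, "w") as sink:
        json.dump(records, sink)
    return json_path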
||||
@@ -0,0 +1,62 @@
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
from datetime import timedelta |
||||
|
||||
from airflow import DAG |
||||
from airflow.operators.bash import BashOperator |
||||
from airflow.providers.docker.operators.docker import DockerOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
default_args = { |
||||
"owner": "airflow", |
||||
"depends_on_past": False, |
||||
"email": ["airflow@example.com"], |
||||
"email_on_failure": False, |
||||
"email_on_retry": False, |
||||
"retries": 1, |
||||
"retry_delay": timedelta(minutes=5), |
||||
} |
||||
|
||||
dag = DAG( |
||||
"docker_sample", |
||||
default_args=default_args, |
||||
schedule_interval=timedelta(minutes=10), |
||||
start_date=days_ago(2), |
||||
) |
||||
|
||||
t1 = BashOperator(task_id="print_date", bash_command="date", dag=dag) |
||||
|
||||
t2 = BashOperator(task_id="sleep", bash_command="sleep 5", retries=3, dag=dag) |
||||
|
||||
t3 = DockerOperator( |
||||
api_version="1.19", |
||||
docker_url="tcp://localhost:2375", # Set your docker URL |
||||
command="/bin/sleep 30", |
||||
image="centos:latest", |
||||
network_mode="bridge", |
||||
task_id="docker_op_tester", |
||||
dag=dag, |
||||
) |
||||
|
||||
|
||||
t4 = BashOperator(task_id="print_hello", bash_command='echo "hello world!!!"', dag=dag) |
||||
|
||||
|
||||
t1 >> t2 |
||||
t1 >> t3 |
||||
t3 >> t4 |
||||
@@ -0,0 +1,114 @@
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
# pylint: disable=missing-function-docstring |
||||
|
||||
# [START tutorial] |
||||
# [START import_module] |
||||
import json |
||||
|
||||
from airflow.decorators import dag, task |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
# [END import_module] |
||||
|
||||
# [START default_args] |
||||
# These args will get passed on to each operator |
||||
# You can override them on a per-task basis during operator initialization |
||||
default_args = { |
||||
"owner": "airflow", |
||||
} |
||||
# [END default_args] |
||||
|
||||
|
||||
# [START instantiate_dag] |
||||
@dag( |
||||
default_args=default_args, |
||||
schedule_interval=None, |
||||
start_date=days_ago(2), |
||||
tags=["example"], |
||||
) |
||||
def tutorial_taskflow_api_etl(): |
||||
""" |
||||
### TaskFlow API Tutorial Documentation |
||||
This is a simple ETL data pipeline example which demonstrates the use of |
||||
the TaskFlow API using three simple tasks for Extract, Transform, and Load. |
||||
Documentation that goes along with the Airflow TaskFlow API tutorial is |
||||
located |
||||
[here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html) |
||||
""" |
||||
# [END instantiate_dag] |
||||
|
||||
# [START extract] |
||||
@task() |
||||
def extract(): |
||||
""" |
||||
#### Extract task |
||||
A simple Extract task to get data ready for the rest of the data |
||||
pipeline. In this case, getting data is simulated by reading from a |
||||
hardcoded JSON string. |
||||
""" |
||||
data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' |
||||
|
||||
order_data_dict = json.loads(data_string) |
||||
return order_data_dict |
||||
|
||||
# [END extract] |
||||
|
||||
# [START transform] |
||||
@task(multiple_outputs=True) |
||||
def transform(order_data_dict: dict): |
||||
""" |
||||
#### Transform task |
||||
A simple Transform task which takes in the collection of order data and |
||||
computes the total order value. |
||||
""" |
||||
total_order_value = 0 |
||||
|
||||
for value in order_data_dict.values(): |
||||
total_order_value += value |
||||
|
||||
return {"total_order_value": total_order_value} |
||||
|
||||
# [END transform] |
||||
|
||||
# [START load] |
||||
@task() |
||||
def load(total_order_value: float): |
||||
""" |
||||
#### Load task |
||||
A simple Load task which takes in the result of the Transform task and |
||||
instead of saving it to end user review, just prints it out. |
||||
""" |
||||
|
||||
print(f"Total order value is: {total_order_value:.2f}") |
||||
|
||||
# [END load] |
||||
|
||||
# [START main_flow] |
||||
order_data = extract() |
||||
order_summary = transform(order_data) |
||||
load(order_summary["total_order_value"]) |
||||
# [END main_flow] |
||||
|
||||
|
||||
# [START dag_invocation] |
||||
tutorial_etl_dag = tutorial_taskflow_api_etl() |
||||
# [END dag_invocation] |
||||
|
||||
# [END tutorial] |
||||
@@ -0,0 +1,61 @@
||||
from os import getenv |
||||
|
||||
from sqlalchemy import VARCHAR, create_engine |
||||
from sqlalchemy.engine.url import URL |
||||
|
||||
|
||||
def _psql_insert_copy(table, conn, keys, data_iter): |
||||
""" |
||||
Execute SQL statement inserting data |
||||
|
||||
Parameters |
||||
---------- |
||||
table : pandas.io.sql.SQLTable |
||||
conn : sqlalchemy.engine.Engine or sqlalchemy.engine.Connection |
||||
keys : list of str |
||||
Column names |
||||
data_iter : Iterable that iterates the values to be inserted |
||||
""" |
||||
# Alternative to_sql() *method* for DBs that support COPY FROM |
||||
import csv |
||||
from io import StringIO |
||||
|
||||
# gets a DBAPI connection that can provide a cursor |
||||
dbapi_conn = conn.connection |
||||
with dbapi_conn.cursor() as cur: |
||||
s_buf = StringIO() |
||||
writer = csv.writer(s_buf) |
||||
writer.writerows(data_iter) |
||||
s_buf.seek(0) |
||||
|
||||
columns = ', '.join('"{}"'.format(k) for k in keys) |
||||
if table.schema: |
||||
table_name = '{}.{}'.format(table.schema, table.name) |
||||
else: |
||||
table_name = table.name |
||||
|
||||
sql = 'COPY {} ({}) FROM STDIN WITH CSV'.format( |
||||
table_name, columns) |
||||
cur.copy_expert(sql=sql, file=s_buf) |
||||
|
||||
|
||||
def load_df_into_db(data_frame, schema: str, table: str) -> None: |
||||
engine = create_engine(URL( |
||||
username=getenv('DBT_POSTGRES_USER'), |
||||
password=getenv('DBT_POSTGRES_PASSWORD'), |
||||
host=getenv('DBT_POSTGRES_HOST'), |
||||
port=getenv('DBT_POSTGRES_PORT'), |
||||
database=getenv('DBT_POSTGRES_DB'), |
||||
drivername='postgres' |
||||
)) |
||||
with engine.connect() as cursor: |
||||
cursor.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}') |
||||
data_frame.to_sql( |
||||
schema=schema, |
||||
name=table, |
||||
dtype=VARCHAR, |
||||
if_exists='append', |
||||
con=engine, |
||||
method=_psql_insert_copy, |
||||
index=False |
||||
) |
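# Usage sketch: the DataFrame, schema, and table below are purely illustrative, and
# the DBT_POSTGRES_* environment variables read above must be set (pandas is an
# implied dependency of data_frame.to_sql).
if __name__ == "__main__":
    import pandas as pd

    listings = pd.DataFrame({"listing_id": ["1001", "1002"], "title": ["DevOps Engineer", "SRE"]})
    load_df_into_db(listings, schema="raw", table="listings")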
||||
@@ -0,0 +1,126 @@
||||
""" |
||||
Code that goes along with the Airflow located at: |
||||
http://airflow.readthedocs.org/en/latest/tutorial.html |
||||
""" |
||||
from datetime import datetime, timedelta |
||||
from textwrap import dedent |
||||
|
||||
from airflow import DAG |
||||
|
||||
# Operators; we need this to operate! |
||||
from airflow.operators.bash import BashOperator |
||||
from airflow.operators.dummy_operator import DummyOperator |
||||
from airflow.operators.python_operator import PythonOperator, PythonVirtualenvOperator |
||||
from airflow.utils.dates import days_ago |
||||
from airflow.utils.task_group import TaskGroup |
||||
from utils.network import selenoid_status, vpn_settings |
||||
from airflow.providers.postgres.operators.postgres import PostgresOperator |
||||
from utils.notify import PushoverOperator |
||||
from utils.postgres import sqlLoad |
||||
|
||||
default_args = { |
||||
"owner": "donaldrich", |
||||
"depends_on_past": False, |
||||
"start_date": datetime(2016, 7, 13), |
||||
"email": ["email@gmail.com"], |
||||
"email_on_failure": False, |
||||
"email_on_retry": False, |
||||
"retries": 1, |
||||
"retry_delay": timedelta(minutes=15), |
||||
# "on_failure_callback": some_function, |
||||
# "on_success_callback": some_other_function, |
||||
# "on_retry_callback": another_function, |
||||
} |
||||
|
||||
# [START default_args] |
||||
# default_args = { |
||||
# 'owner': 'airflow', |
||||
# 'depends_on_past': False, |
||||
# 'email': ['airflow@example.com'], |
||||
# 'email_on_failure': False, |
||||
# 'email_on_retry': False, |
||||
# 'retries': 1, |
||||
# 'retry_delay': timedelta(minutes=5), |
||||
# 'queue': 'bash_queue', |
||||
# 'pool': 'backfill', |
||||
# 'priority_weight': 10, |
||||
# 'end_date': datetime(2016, 1, 1), |
||||
# 'wait_for_downstream': False, |
||||
# 'dag': dag, |
||||
# 'sla': timedelta(hours=2), |
||||
# 'execution_timeout': timedelta(seconds=300), |
||||
# 'on_failure_callback': some_function, |
||||
# 'on_success_callback': some_other_function, |
||||
# 'on_retry_callback': another_function, |
||||
# 'sla_miss_callback': yet_another_function, |
||||
# 'trigger_rule': 'all_success' |
||||
# } |
||||
# [END default_args] |
||||
|
||||
dag = DAG( |
||||
"recon", |
||||
default_args=default_args, |
||||
description="A simple tutorial DAG", |
||||
schedule_interval=None, |
||||
start_date=days_ago(2), |
||||
tags=["target"], |
||||
) |
||||
|
||||
with dag: |
||||
|
||||
# start = DummyOperator(task_id="start", dag=dag) |
||||
|
||||
with TaskGroup("pull_images") as start2: |
||||
|
||||
pull1 = BashOperator( |
||||
task_id="chrome_latest", |
||||
bash_command="sudo docker pull selenoid/chrome:latest", |
||||
) |
||||
|
||||
pull2 = BashOperator( |
||||
task_id="video_recorder", |
||||
bash_command="sudo docker pull selenoid/video-recorder:latest-release", |
||||
) |
||||
|
||||
t1 = PythonOperator( |
||||
task_id="selenoid_status", |
||||
python_callable=selenoid_status, |
||||
) |
||||
[pull1, pull2] >> t1 |
||||
start2.doc_md = dedent(""" |
||||
#### Task Documentation |
||||
You can document your task using the attributes `doc_md` (markdown), |
||||
`doc` (plain text), `doc_rst`, `doc_json`, `doc_yaml` which gets |
||||
rendered in the UI's Task Instance Details page. |
||||
 |
||||
""") |
||||
|
||||
scrape = BashOperator( |
||||
task_id="scrape_listings", |
||||
bash_command="python3 /data/scripts/scrapers/zip-scrape.py -k devops", |
||||
) |
||||
|
||||
sql_load = PythonOperator(task_id="sql_load", python_callable=sqlLoad)  # avoid shadowing the imported callable |
||||
|
||||
# get_birth_date = PostgresOperator( |
||||
# task_id="get_birth_date", |
||||
# postgres_conn_id="postgres_default", |
||||
# sql="sql/birth_date.sql", |
||||
# params={"begin_date": "2020-01-01", "end_date": "2020-12-31"}, |
||||
# ) |
||||
|
||||
cleanup = BashOperator( |
||||
task_id="cleanup", |
||||
bash_command="sudo sh /data/scripts/post-scrape.sh -s zip -k devops", |
||||
) |
||||
|
||||
success_notify = PushoverOperator( |
||||
task_id="finished", |
||||
title="Airflow Complete", |
||||
message="We did it!", |
||||
) |
||||
|
||||
end = DummyOperator(task_id="end", dag=dag) |
||||
|
||||
start2 >> scrape >> sql_load >> cleanup >> success_notify >> end |
||||
@ -0,0 +1,41 @@ |
||||
from datetime import datetime, timedelta |
||||
|
||||
from airflow import DAG |
||||
from airflow.operators.python_operator import PythonOperator |
||||
|
||||
# from airflow.hooks.S3_hook import S3Hook |
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook |
||||
|
||||
DEFAULT_ARGS = { |
||||
"owner": "Airflow", |
||||
"depends_on_past": False, |
||||
"start_date": datetime(2020, 1, 13), |
||||
"email": ["airflow@example.com"], |
||||
"email_on_failure": False, |
||||
"email_on_retry": False, |
||||
"retries": 1, |
||||
"retry_delay": timedelta(minutes=5), |
||||
} |
||||
|
||||
dag = DAG("create_date_dimension", default_args=DEFAULT_ARGS, schedule_interval="@once") |
||||
|
||||
# Create a task to call your processing function |
||||
def write_text_file(ds, **kwargs): |
||||
with open("/data/data/temp.txt", "w") as fp: |
||||
# Add file generation/processing step here, E.g.: |
||||
# fp.write(ds) |
||||
|
||||
# Upload generated file to Minio |
||||
# s3 = S3Hook('local_minio') |
||||
s3 = S3Hook(aws_conn_id="minio") |
||||
s3.load_file("/data/data/temp.txt", key=f"my-test-file.txt", bucket_name="airflow") |
||||
|
||||
|
||||
# Create a task to call your processing function |
||||
|
||||
t1 = PythonOperator( |
||||
task_id="generate_and_upload_to_s3", |
||||
provide_context=True, |
||||
python_callable=write_text_file, |
||||
dag=dag, |
||||
) |
||||
@ -0,0 +1,154 @@ |
||||
""" |
||||
Headless Site Navigation and File Download (Using Selenium) to S3 |
||||
|
||||
This example demonstrates using Selenium (via Firefox/GeckoDriver) to: |
||||
1) Log into a website w/ credentials stored in connection labeled 'selenium_conn_id' |
||||
2) Download a file (initiated on login) |
||||
3) Transform the CSV into JSON formatting |
||||
4) Append the current data to each record |
||||
5) Load the corresponding file into S3 |
||||
|
||||
To use this DAG, you will need to have the following installed: |
||||
[XVFB](https://www.x.org/archive/X11R7.6/doc/man/man1/Xvfb.1.xhtml) |
||||
[GeckoDriver](https://github.com/mozilla/geckodriver/releases/download) |
||||
|
||||
selenium==3.11.0 |
||||
xvfbwrapper==0.2.9 |
||||
""" |
||||
# import boa |
||||
import csv |
||||
import json |
||||
import logging |
||||
import os |
||||
import time |
||||
from datetime import datetime, timedelta |
||||
|
||||
from airflow import DAG |
||||
from airflow.hooks import S3Hook |
||||
from airflow.models import Connection |
||||
from airflow.operators.dummy_operator import DummyOperator |
||||
from airflow.operators.python_operator import PythonOperator |
||||
from airflow.utils.db import provide_session |
||||
from selenium import webdriver |
||||
from selenium.webdriver.common.keys import Keys |
||||
from selenium.webdriver.firefox.options import Options |
||||
from xvfbwrapper import Xvfb |
||||
|
||||
S3_CONN_ID = "" |
||||
S3_BUCKET = "" |
||||
S3_KEY = "" |
||||
|
||||
date = "{{ ds }}" |
||||
|
||||
default_args = { |
||||
"start_date": datetime(2018, 2, 10, 0, 0), |
||||
"email": [], |
||||
"email_on_failure": True, |
||||
"email_on_retry": False, |
||||
"retries": 2, |
||||
"retry_delay": timedelta(minutes=5), |
||||
"catchup": False, |
||||
} |
||||
|
||||
dag = DAG( |
||||
"selenium_extraction_to_s3", |
||||
schedule_interval="@daily", |
||||
default_args=default_args, |
||||
catchup=False, |
||||
) |
||||
|
||||
|
||||
def imap_py(**kwargs): |
||||
selenium_conn_id = kwargs.get("templates_dict", None).get("selenium_conn_id", None) |
||||
filename = kwargs.get("templates_dict", None).get("filename", None) |
||||
s3_conn_id = kwargs.get("templates_dict", None).get("s3_conn_id", None) |
||||
s3_bucket = kwargs.get("templates_dict", None).get("s3_bucket", None) |
||||
s3_key = kwargs.get("templates_dict", None).get("s3_key", None) |
||||
date = kwargs.get("templates_dict", None).get("date", None) |
||||
|
||||
@provide_session |
||||
def get_conn(conn_id, session=None): |
||||
conn = session.query(Connection).filter(Connection.conn_id == conn_id).first() |
||||
return conn |
||||
|
||||
url = get_conn(selenium_conn_id).host |
||||
email = get_conn(selenium_conn_id).user |
||||
pwd = get_conn(selenium_conn_id).password |
||||
|
||||
vdisplay = Xvfb() |
||||
vdisplay.start() |
||||
caps = webdriver.DesiredCapabilities.FIREFOX |
||||
caps["marionette"] = True |
||||
|
||||
profile = webdriver.FirefoxProfile() |
||||
profile.set_preference("browser.download.manager.showWhenStarting", False) |
||||
profile.set_preference("browser.helperApps.neverAsk.saveToDisk", "text/csv") |
||||
|
||||
logging.info("Profile set...") |
||||
options = Options() |
||||
options.set_headless(headless=True) |
||||
logging.info("Options set...") |
||||
logging.info("Initializing Driver...") |
||||
driver = webdriver.Firefox( |
||||
firefox_profile=profile, firefox_options=options, capabilities=caps |
||||
) |
||||
logging.info("Driver Intialized...") |
||||
driver.get(url) |
||||
logging.info("Authenticating...") |
||||
elem = driver.find_element_by_id("email") |
||||
elem.send_keys(email) |
||||
elem = driver.find_element_by_id("password") |
||||
elem.send_keys(pwd) |
||||
elem.send_keys(Keys.RETURN) |
||||
|
||||
logging.info("Successfully authenticated.") |
||||
|
||||
sleep_time = 15 |
||||
|
||||
logging.info("Downloading File....Sleeping for {} Seconds.".format(str(sleep_time))) |
||||
time.sleep(sleep_time) |
||||
|
||||
driver.close() |
||||
vdisplay.stop() |
||||
|
||||
dest_s3 = S3Hook(s3_conn_id=s3_conn_id) |
||||
|
||||
os.chdir("/root/Downloads") |
||||
|
||||
csvfile = open(filename, "r") |
||||
|
||||
output_json = "file.json" |
||||
|
||||
with open(output_json, "w") as jsonfile: |
||||
reader = csv.DictReader(csvfile) |
||||
|
||||
for row in reader: |
||||
# row = dict((boa.constrict(k), v) for k, v in row.items()) |
||||
row["run_date"] = date |
||||
json.dump(row, jsonfile) |
||||
jsonfile.write("\n") |
||||
|
||||
dest_s3.load_file( |
||||
filename=output_json, key=s3_key, bucket_name=s3_bucket, replace=True |
||||
) |
||||
|
||||
dest_s3.connection.close() |
||||
|
||||
|
||||
with dag: |
||||
|
||||
kick_off_dag = DummyOperator(task_id="kick_off_dag") |
||||
|
||||
selenium = PythonOperator( |
||||
task_id="selenium_retrieval_to_s3", |
||||
python_callable=imap_py, |
||||
templates_dict={ |
||||
"s3_conn_id": S3_CONN_ID, |
||||
"s3_bucket": S3_BUCKET, |
||||
"s3_key": S3_KEY, |
||||
"date": date, |
||||
}, |
||||
provide_context=True, |
||||
) |
||||
|
||||
kick_off_dag >> selenium |
||||
@ -0,0 +1,65 @@ |
||||
from datetime import timedelta |
||||
|
||||
from airflow import DAG |
||||
from airflow.operators.bash import BashOperator |
||||
from airflow.operators.dummy_operator import DummyOperator |
||||
from airflow.providers.docker.operators.docker import DockerOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
default_args = { |
||||
"owner": "airflow", |
||||
"depends_on_past": False, |
||||
"email": ["airflow@example.com"], |
||||
"email_on_failure": False, |
||||
"email_on_retry": False, |
||||
"retries": 1, |
||||
"retry_delay": timedelta(minutes=5), |
||||
} |
||||
|
||||
dag = DAG( |
||||
"selenoid_setup", |
||||
default_args=default_args, |
||||
schedule_interval="@daily", |
||||
start_date=days_ago(2), |
||||
) |
||||
|
||||
|
||||
with dag: |
||||
|
||||
kick_off_dag = DummyOperator(task_id="kick_off_dag") |
||||
|
||||
t1 = DockerOperator( |
||||
# api_version='1.19', |
||||
# docker_url='tcp://localhost:2375', # Set your docker URL |
||||
command="/bin/sleep 30", |
||||
image="selenoid/chrome:latest", |
||||
network_mode="bridge", |
||||
task_id="selenoid1", |
||||
# dag=dag, |
||||
) |
||||
|
||||
t2 = DockerOperator( |
||||
# api_version='1.19', |
||||
# docker_url='tcp://localhost:2375', # Set your docker URL |
||||
command="/bin/sleep 30", |
||||
image="selenoid/video-recorder:latest-release", |
||||
network_mode="bridge", |
||||
task_id="selenoid2", |
||||
# dag=dag, |
||||
) |
||||
|
||||
scrape = BashOperator( |
||||
task_id="pull_selenoid_video-recorder", |
||||
bash_command="docker pull selenoid/video-recorder:latest-release", |
||||
# retries=3, |
||||
# dag=dag |
||||
) |
||||
|
||||
scrape2 = BashOperator( |
||||
task_id="pull_selenoid_chrome", |
||||
bash_command="docker pull selenoid/chrome:latest", |
||||
# retries=3, |
||||
# dag=dag |
||||
) |
||||
|
||||
scrape >> t1 >> t2 |
||||
@ -0,0 +1 @@ |
||||
SELECT * FROM pet WHERE birth_date BETWEEN SYMMETRIC '{{ params.begin_date }}' AND '{{ params.end_date }}'; |
||||
@ -0,0 +1,122 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
""" |
||||
This is an example DAG for the use of the SqliteOperator. |
||||
In this example, we create two tasks that execute in sequence. |
||||
The first task runs a SQL command defined inline in the SqliteOperator, |
||||
which, when triggered, executes against the connected SQLite database. |
||||
The second task is similar but instead loads the SQL command from an external file. |
||||
""" |
||||
import apprise |
||||
from airflow import DAG |
||||
from airflow.operators.bash import BashOperator |
||||
from airflow.operators.dummy_operator import DummyOperator |
||||
from airflow.operators.python_operator import PythonOperator |
||||
from airflow.providers.docker.operators.docker import DockerOperator |
||||
from airflow.providers.sqlite.operators.sqlite import SqliteOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
default_args = {"owner": "airflow"} |
||||
|
||||
dag = DAG( |
||||
dag_id="example_sqlite", |
||||
default_args=default_args, |
||||
schedule_interval="@hourly", |
||||
start_date=days_ago(2), |
||||
tags=["example"], |
||||
) |
||||
|
||||
with dag: |
||||
|
||||
# [START howto_operator_sqlite] |
||||
|
||||
# Example of creating a task that calls a common CREATE TABLE sql command. |
||||
# def notify(): |
||||
|
||||
# # Create an Apprise instance |
||||
# apobj = apprise.Apprise() |
||||
|
||||
# # Add all of the notification services by their server url. |
||||
# # A sample email notification: |
||||
# # apobj.add('mailto://myuserid:mypass@gmail.com') |
||||
|
||||
# # A sample pushbullet notification |
||||
# apobj.add('pover://aejghiy6af1bshe8mbdksmkzeon3ip@umjiu36dxwwaj8pnfx3n6y2xbm3ssx') |
||||
|
||||
# # Then notify these services any time you desire. The below would |
||||
# # notify all of the services loaded into our Apprise object. |
||||
# apobj.notify( |
||||
# body='what a great notification service!', |
||||
# title='my notification title', |
||||
# ) |
||||
# return apobj |
||||
|
||||
# apprise = PythonOperator( |
||||
# task_id="apprise", |
||||
# python_callable=notify, |
||||
# dag=dag |
||||
# # "s3_bucket": S3_BUCKET, |
||||
# # "s3_key": S3_KEY, |
||||
# # "date": date} |
||||
# # provide_context=True |
||||
# ) |
||||
|
||||
# t2 = DockerOperator( |
||||
# task_id='docker_command', |
||||
# image='selenoid/chrome:latest', |
||||
# api_version='auto', |
||||
# auto_remove=False, |
||||
# command="/bin/sleep 30", |
||||
# docker_url="unix://var/run/docker.sock", |
||||
# network_mode="bridge" |
||||
# ) |
||||
|
||||
start = DummyOperator(task_id="start") |
||||
|
||||
docker = BashOperator( |
||||
task_id="pull_selenoid_2", |
||||
bash_command="sudo docker pull selenoid/chrome:latest", |
||||
# retries=3, |
||||
# dag=dag |
||||
) |
||||
|
||||
pre2 = BashOperator( |
||||
task_id="apprise", |
||||
bash_command="apprise -vv -t 'my title' -b 'my notification body' pover://umjiu36dxwwaj8pnfx3n6y2xbm3ssx@aejghiy6af1bshe8mbdksmkzeon3ip", |
||||
# retries=3, |
||||
# dag=dag |
||||
) |
||||
|
||||
# t3 = DockerOperator( |
||||
# api_version="1.19", |
||||
# docker_url="tcp://localhost:2375", # Set your docker URL |
||||
# command="/bin/sleep 30", |
||||
# image="centos:latest", |
||||
# network_mode="bridge", |
||||
# task_id="docker_op_tester", |
||||
# # dag=dag, |
||||
# ) |
||||
|
||||
end = DummyOperator(task_id="end") |
||||
|
||||
|
||||
start >> docker >> pre2 >> end |
||||
# [END howto_operator_sqlite_external_file] |
||||
|
||||
# create_table_sqlite_task >> external_create_table_sqlite_task |
||||
# |
||||
@ -0,0 +1,137 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
# pylint: disable=missing-function-docstring |
||||
""" |
||||
### ETL DAG Tutorial Documentation |
||||
This ETL DAG is compatible with Airflow 1.10.x (specifically tested with 1.10.12) and is referenced |
||||
as part of the documentation that goes along with the Airflow Functional DAG tutorial located |
||||
[here](https://airflow.apache.org/tutorial_decorated_flows.html) |
||||
""" |
||||
# [START tutorial] |
||||
# [START import_module] |
||||
import json |
||||
from textwrap import dedent |
||||
|
||||
# The DAG object; we'll need this to instantiate a DAG |
||||
from airflow import DAG |
||||
|
||||
# Operators; we need this to operate! |
||||
from airflow.operators.python import PythonOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
# [END import_module] |
||||
|
||||
# [START default_args] |
||||
# These args will get passed on to each operator |
||||
# You can override them on a per-task basis during operator initialization |
||||
default_args = { |
||||
"owner": "donaldrich", |
||||
"email": ["email@gmail.com"], |
||||
} |
||||
# [END default_args] |
||||
|
||||
# [START instantiate_dag] |
||||
with DAG( |
||||
"tutorial_etl_dag", |
||||
default_args=default_args, |
||||
description="ETL DAG tutorial", |
||||
schedule_interval=None, |
||||
start_date=days_ago(2), |
||||
tags=["example"], |
||||
) as dag: |
||||
# [END instantiate_dag] |
||||
# [START documentation] |
||||
dag.doc_md = __doc__ |
||||
|
||||
# [END documentation] |
||||
|
||||
# [START extract_function] |
||||
def extract(**kwargs): |
||||
ti = kwargs["ti"] |
||||
data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' |
||||
ti.xcom_push("order_data", data_string) |
||||
|
||||
# [END extract_function] |
||||
|
||||
# [START transform_function] |
||||
def transform(**kwargs): |
||||
ti = kwargs["ti"] |
||||
extract_data_string = ti.xcom_pull(task_ids="extract", |
||||
key="order_data") |
||||
order_data = json.loads(extract_data_string) |
||||
|
||||
total_order_value = 0 |
||||
for value in order_data.values(): |
||||
total_order_value += value |
||||
|
||||
total_value = {"total_order_value": total_order_value} |
||||
total_value_json_string = json.dumps(total_value) |
||||
ti.xcom_push("total_order_value", total_value_json_string) |
||||
|
||||
# [END transform_function] |
||||
|
||||
# [START load_function] |
||||
def load(**kwargs): |
||||
ti = kwargs["ti"] |
||||
total_value_string = ti.xcom_pull(task_ids="transform", |
||||
key="total_order_value") |
||||
total_order_value = json.loads(total_value_string) |
||||
|
||||
print(total_order_value) |
||||
|
||||
# [END load_function] |
||||
|
||||
# [START main_flow] |
||||
extract_task = PythonOperator( |
||||
task_id="extract", |
||||
python_callable=extract, |
||||
) |
||||
extract_task.doc_md = dedent("""\ |
||||
#### Extract task |
||||
A simple Extract task to get data ready for the rest of the data pipeline. |
||||
In this case, getting data is simulated by reading from a hardcoded JSON string. |
||||
This data is then put into xcom, so that it can be processed by the next task. |
||||
""") |
||||
|
||||
transform_task = PythonOperator( |
||||
task_id="transform", |
||||
python_callable=transform, |
||||
) |
||||
transform_task.doc_md = dedent("""\ |
||||
#### Transform task |
||||
A simple Transform task which takes in the collection of order data from xcom |
||||
and computes the total order value. |
||||
This computed value is then put into xcom, so that it can be processed by the next task. |
||||
""") |
||||
|
||||
load_task = PythonOperator( |
||||
task_id="load", |
||||
python_callable=load, |
||||
) |
||||
load_task.doc_md = dedent("""\ |
||||
#### Load task |
||||
A simple Load task which takes in the result of the Transform task, by reading it |
||||
from xcom and instead of saving it to end user review, just prints it out. |
||||
""") |
||||
|
||||
extract_task >> transform_task >> load_task |
||||
|
||||
# [END main_flow] |
||||
|
||||
# [END tutorial] |
||||
@ -0,0 +1,32 @@ |
||||
import os |
||||
from datetime import datetime |
||||
from os import environ |
||||
|
||||
from airflow import DAG |
||||
from airflow.hooks.base_hook import BaseHook |
||||
from airflow.operators.python_operator import PythonOperator |
||||
|
||||
os.environ[ |
||||
"AIRFLOW__SECRETS__BACKEND" |
||||
] = "airflow.providers.hashicorp.secrets.vault.VaultBackend" |
||||
os.environ[ |
||||
"AIRFLOW__SECRETS__BACKEND_KWARGS"] = '{"connections_path": "myapp", "mount_point": "secret", "auth_type": "token", "token": "token", "url": "http://vault:8200"}' |
||||
|
||||
|
||||
def get_secrets(**kwargs): |
||||
conn = BaseHook.get_connection(kwargs["my_conn_id"]) |
||||
print("Password:", {conn.password}) |
||||
print(" Login:", {conn.login}) |
||||
print(" URI:", {conn.get_uri()}) |
||||
print("Host:", {conn.host}) |
||||
|
||||
|
||||
with DAG( |
||||
"vault_example", start_date=datetime(2020, 1, 1), schedule_interval=None |
||||
) as dag: |
||||
|
||||
test_task = PythonOperator( |
||||
task_id="test-task", |
||||
python_callable=get_secrets, |
||||
op_kwargs={"my_conn_id": "smtp_default"}, |
||||
) |
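# --- Hedged companion note (not part of the original file) ---
# With the backend kwargs above (connections_path="myapp", mount_point="secret"),
# Airflow's VaultBackend is expected to look up the connection at
# secret/myapp/<conn_id> using a "conn_uri" key. An assumed CLI invocation:
#
#   vault kv put secret/myapp/smtp_default conn_uri="smtps://user:password@mail.example.com:465"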
||||
@ -0,0 +1,186 @@ |
||||
version: "3.7" |
||||
|
||||
x-airflow-common: &airflow-common |
||||
image: donaldrich/airflow:latest |
||||
# build: . |
||||
environment: &airflow-common-env |
||||
AIRFLOW__CORE__STORE_SERIALIZED_DAGS: "True" |
||||
AIRFLOW__CORE__STORE_DAG_CODE: "True" |
||||
AIRFLOW__CORE__EXECUTOR: "CeleryExecutor" |
||||
AIRFLOW__CORE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow@postgres-dev:5432/airflow |
||||
AIRFLOW__CELERY__RESULT_BACKEND: db+postgresql://airflow:airflow@postgres-dev:5432/airflow |
||||
AIRFLOW__CELERY__BROKER_URL: redis://:@redis:6379/0 |
||||
AIRFLOW__CORE__FERNET_KEY: "" |
||||
AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION: "false" |
||||
AIRFLOW__CORE__LOAD_EXAMPLES: "false" |
||||
# AIRFLOW__CORE__PARALLELISM: 4 |
||||
# AIRFLOW__CORE__DAG_CONCURRENCY: 4 |
||||
# AIRFLOW__CORE__MAX_ACTIVE_RUNS_PER_DAG: 4 |
||||
AIRFLOW_UID: "1000" |
||||
AIRFLOW_GID: "0" |
||||
_AIRFLOW_DB_UPGRADE: "true" |
||||
_AIRFLOW_WWW_USER_CREATE: "true" |
||||
_AIRFLOW_WWW_USER_USERNAME: "{{ user }}" |
||||
_AIRFLOW_WWW_USER_PASSWORD: "{{ password }}" |
||||
PYTHONPATH: "/data:$$PYTHONPATH" |
||||
user: "1000:0" |
||||
volumes: |
||||
- "./airflow:/opt/airflow" |
||||
- "./data:/data" |
||||
- "/var/run/docker.sock:/var/run/docker.sock" |
||||
# - "/usr/bin/docker:/bin/docker:ro" |
||||
networks: |
||||
- proxy |
||||
- backend |
||||
|
||||
services: |
||||
airflow-init: |
||||
<<: *airflow-common |
||||
container_name: airflow-init |
||||
environment: |
||||
<<: *airflow-common-env |
||||
# depends_on: |
||||
# - airflow-db |
||||
command: bash -c "airflow db init && airflow db upgrade && airflow users create --role Admin --username {{ user }} --email {{ email }} --firstname Don --lastname Aldrich --password {{ password }}" |
||||
|
||||
airflow-webserver: |
||||
<<: *airflow-common |
||||
container_name: airflow-webserver |
||||
hostname: airflow-webserver |
||||
command: webserver |
||||
# command: > |
||||
# bash -c 'if [[ -z "$$AIRFLOW__API__AUTH_BACKEND" ]] && [[ $$(pip show -f apache-airflow | grep basic_auth.py) ]]; |
||||
# then export AIRFLOW__API__AUTH_BACKEND=airflow.api.auth.backend.basic_auth ; |
||||
# else export AIRFLOW__API__AUTH_BACKEND=airflow.api.auth.backend.default ; fi && |
||||
# { airflow create_user "$$@" || airflow users create "$$@" ; } && |
||||
# { airflow sync_perm || airflow sync-perm ;} && |
||||
# airflow webserver' -- -r Admin -u admin -e admin@example.com -f admin -l user -p admin |
||||
healthcheck: |
||||
test: ["CMD", "curl", "--fail", "http://localhost:8080/health"] |
||||
interval: 10s |
||||
timeout: 10s |
||||
retries: 5 |
||||
restart: always |
||||
privileged: true |
||||
depends_on: |
||||
- airflow-scheduler |
||||
environment: |
||||
<<: *airflow-common-env |
||||
labels: |
||||
traefik.http.services.airflow.loadbalancer.server.port: "8080" |
||||
traefik.enable: "true" |
||||
traefik.http.routers.airflow.entrypoints: "https" |
||||
traefik.http.routers.airflow.tls.certResolver: "cloudflare" |
||||
traefik.http.routers.airflow.rule: "Host(`airflow.{{ domain }}.com`)" |
||||
traefik.http.routers.airflow.middlewares: "ip-whitelist@file" |
||||
traefik.http.routers.airflow.service: "airflow" |
||||
|
||||
airflow-scheduler: |
||||
<<: *airflow-common |
||||
command: scheduler |
||||
container_name: airflow-scheduler |
||||
hostname: airflow-scheduler |
||||
restart: always |
||||
depends_on: |
||||
- airflow-init |
||||
environment: |
||||
<<: *airflow-common-env |
||||
# # entrypoint: sh -c '/app/scripts/wait-for postgres:5432 -- airflow db init && airflow scheduler' |
||||
|
||||
airflow-worker: |
||||
<<: *airflow-common |
||||
command: celery worker |
||||
restart: always |
||||
container_name: airflow-worker |
||||
hostname: airflow-worker |
||||
|
||||
airflow-queue: |
||||
<<: *airflow-common |
||||
command: celery flower |
||||
container_name: airflow-queue |
||||
hostname: airflow-queue |
||||
ports: |
||||
- 5555:5555 |
||||
healthcheck: |
||||
test: ["CMD", "curl", "--fail", "http://localhost:5555/"] |
||||
interval: 10s |
||||
timeout: 10s |
||||
retries: 5 |
||||
restart: always |
||||
|
||||
dbt: |
||||
image: fishtownanalytics/dbt:0.19.1 |
||||
container_name: dbt |
||||
volumes: |
||||
- "/home/{{ user }}/projects/jobfunnel:/data" |
||||
- "/home/{{ user }}/projects/jobfunnel/dbt:/usr/app" |
||||
# - "dbt-db:/var/lib/postgresql/data" |
||||
ports: |
||||
- "8081" |
||||
networks: |
||||
- proxy |
||||
- backend |
||||
command: docs serve --project-dir /data/transform |
||||
working_dir: "/data/transform" |
||||
environment: |
||||
DBT_SCHEMA: dbt |
||||
DBT_RAW_DATA_SCHEMA: dbt_raw_data |
||||
DBT_PROFILES_DIR: "/data/transform/profile" |
||||
# DBT_PROJECT_DIR: "/data/transform" |
||||
# DBT_PROFILES_DIR: "/data" |
||||
# DBT_POSTGRES_PASSWORD: dbt |
||||
# DBT_POSTGRES_USER : dbt |
||||
# DBT_POSTGRES_DB : dbt |
||||
# DBT_DBT_SCHEMA: dbt |
||||
# DBT_DBT_RAW_DATA_SCHEMA: dbt_raw_data |
||||
# DBT_POSTGRES_HOST: dbt-db |
||||
|
||||
meltano: |
||||
image: meltano/meltano:latest |
||||
container_name: meltano |
||||
volumes: |
||||
- "/home/{{ user }}/projects/jobfunnel:/project" |
||||
ports: |
||||
- "5000:5000" |
||||
networks: |
||||
- proxy |
||||
- backend |
||||
# environment: |
||||
# MELTANO_UI_SERVER_NAME: "etl.{{ domain }}.com" |
||||
# MELTANO_DATABASE_URI: "postgresql://meltano:meltano@meltano-db:5432/meltano" |
||||
# MELTANO_WEBAPP_POSTGRES_URL: localhost |
||||
# MELTANO_WEBAPP_POSTGRES_DB=meltano |
||||
# MELTANO_WEBAPP_POSTGRES_USER=meltano |
||||
# MELTANO_WEBAPP_POSTGRES_PASSWORD=meltano |
||||
# MELTANO_WEBAPP_POSTGRES_PORT=5501 |
||||
# MELTANO_WEBAPP_LOG_PATH="/tmp/meltano.log" |
||||
# MELTANO_WEBAPP_API_URL="http://localhost:5000" |
||||
# LOG_PATH="/tmp/meltano.log" |
||||
# API_URL="http://localhost:5000" |
||||
# MELTANO_MODEL_DIR="./model" |
||||
# MELTANO_TRANSFORM_DIR="./transform" |
||||
# MELTANO_UI_SESSION_COOKIE_DOMAIN: "etl.{{ domain }}.com" |
||||
# MELTANO_UI_SESSION_COOKIE_SECURE: "true" |
||||
labels: |
||||
traefik.http.services.meltano.loadbalancer.server.port: "5000" |
||||
traefik.enable: "true" |
||||
traefik.http.routers.meltano.entrypoints: "https" |
||||
traefik.http.routers.meltano.tls.certResolver: "cloudflare" |
||||
traefik.http.routers.meltano.rule: "Host(`etl.{{ domain }}.com`)" |
||||
traefik.http.routers.meltano.middlewares: "secured@file,ip-whitelist@file" |
||||
traefik.http.routers.meltano.service: "meltano" |
||||
|
||||
volumes: |
||||
tmp_airflow: |
||||
driver: local |
||||
airflow-db: |
||||
driver: local |
||||
dbt-db: |
||||
driver: local |
||||
networks: |
||||
proxy: |
||||
external: true |
||||
backend: |
||||
external: true |
||||
@ -0,0 +1,49 @@ |
||||
import csv |
||||
import logging |
||||
import os |
||||
import platform |
||||
import random |
||||
import re |
||||
import time |
||||
|
||||
# from webdriver_manager.chrome import ChromeDriverManager |
||||
from datetime import datetime |
||||
from urllib.request import urlopen |
||||
|
||||
# import pyautogui |
||||
import pandas as pd |
||||
import yaml |
||||
from bs4 import BeautifulSoup |
||||
|
||||
# import mouseinfo |
||||
from selenium import webdriver |
||||
from selenium.common.exceptions import NoSuchElementException, TimeoutException |
||||
from selenium.webdriver.chrome.options import Options |
||||
from selenium.webdriver.common.by import By |
||||
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities |
||||
from selenium.webdriver.common.keys import Keys |
||||
from selenium.webdriver.support import expected_conditions as EC |
||||
from selenium.webdriver.support.ui import WebDriverWait |
||||
|
||||
|
||||
def setupLogger(): |
||||
dt = datetime.strftime(datetime.now(), "%m_%d_%y %H_%M_%S ") |
||||
|
||||
if not os.path.isdir("./logs"): |
||||
os.mkdir("./logs") |
||||
|
||||
# TODO need to check if there is a log dir available or not |
||||
logging.basicConfig( |
||||
filename=("./logs/" + str(dt) + "applyJobs.log"), |
||||
filemode="w", |
||||
format="%(asctime)s::%(name)s::%(levelname)s::%(message)s", |
||||
datefmt="./logs/%d-%b-%y %H:%M:%S", |
||||
) |
||||
log = logging.getLogger(__name__) |
||||
log.setLevel(logging.DEBUG) |
||||
c_handler = logging.StreamHandler() |
||||
c_handler.setLevel(logging.DEBUG) |
||||
c_format = logging.Formatter( |
||||
"%(asctime)s - %(levelname)s - %(message)s", "%H:%M:%S" |
||||
) |
||||
c_handler.setFormatter(c_format) |
||||
log.addHandler(c_handler) |
||||
@ -0,0 +1,145 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
# |
||||
import json |
||||
import re |
||||
from typing import Any, Dict, Optional |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.http.hooks.http import HttpHook |
||||
|
||||
|
||||
class DiscordWebhookHook(HttpHook): |
||||
""" |
||||
This hook allows you to post messages to Discord using incoming webhooks. |
||||
Takes a Discord connection ID with a default relative webhook endpoint. The |
||||
default endpoint can be overridden using the webhook_endpoint parameter |
||||
(https://discordapp.com/developers/docs/resources/webhook). |
||||
Each Discord webhook can be pre-configured to use a specific username and |
||||
avatar_url. You can override these defaults in this hook. |
||||
:param http_conn_id: Http connection ID with host as "https://discord.com/api/" and |
||||
default webhook endpoint in the extra field in the form of |
||||
{"webhook_endpoint": "webhooks/{webhook.id}/{webhook.token}"} |
||||
:type http_conn_id: str |
||||
:param webhook_endpoint: Discord webhook endpoint in the form of |
||||
"webhooks/{webhook.id}/{webhook.token}" |
||||
:type webhook_endpoint: str |
||||
:param message: The message you want to send to your Discord channel |
||||
(max 2000 characters) |
||||
:type message: str |
||||
:param username: Override the default username of the webhook |
||||
:type username: str |
||||
:param avatar_url: Override the default avatar of the webhook |
||||
:type avatar_url: str |
||||
:param tts: Is a text-to-speech message |
||||
:type tts: bool |
||||
:param proxy: Proxy to use to make the Discord webhook call |
||||
:type proxy: str |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
http_conn_id: Optional[str] = None, |
||||
webhook_endpoint: Optional[str] = None, |
||||
message: str = "", |
||||
username: Optional[str] = None, |
||||
avatar_url: Optional[str] = None, |
||||
tts: bool = False, |
||||
proxy: Optional[str] = None, |
||||
*args: Any, |
||||
**kwargs: Any, |
||||
) -> None: |
||||
super().__init__(*args, **kwargs) |
||||
self.http_conn_id: Any = http_conn_id |
||||
self.webhook_endpoint = self._get_webhook_endpoint( |
||||
http_conn_id, webhook_endpoint |
||||
) |
||||
self.message = message |
||||
self.username = username |
||||
self.avatar_url = avatar_url |
||||
self.tts = tts |
||||
self.proxy = proxy |
||||
|
||||
def _get_webhook_endpoint( |
||||
self, http_conn_id: Optional[str], webhook_endpoint: Optional[str] |
||||
) -> str: |
||||
""" |
||||
Given a Discord http_conn_id, return the default webhook endpoint or override if a |
||||
webhook_endpoint is manually supplied. |
||||
:param http_conn_id: The provided connection ID |
||||
:param webhook_endpoint: The manually provided webhook endpoint |
||||
:return: Webhook endpoint (str) to use |
||||
""" |
||||
if webhook_endpoint: |
||||
endpoint = webhook_endpoint |
||||
elif http_conn_id: |
||||
conn = self.get_connection(http_conn_id) |
||||
extra = conn.extra_dejson |
||||
endpoint = extra.get("webhook_endpoint", "") |
||||
else: |
||||
raise AirflowException( |
||||
"Cannot get webhook endpoint: No valid Discord webhook endpoint or http_conn_id supplied." |
||||
) |
||||
|
||||
# make sure endpoint matches the expected Discord webhook format |
||||
if not re.match("^webhooks/[0-9]+/[a-zA-Z0-9_-]+$", endpoint): |
||||
raise AirflowException( |
||||
'Expected Discord webhook endpoint in the form of "webhooks/{webhook.id}/{webhook.token}".' |
||||
) |
||||
|
||||
return endpoint |
||||
|
||||
def _build_discord_payload(self) -> str: |
||||
""" |
||||
Construct the Discord JSON payload. All relevant parameters are combined here |
||||
to a valid Discord JSON payload. |
||||
:return: Discord payload (str) to send |
||||
""" |
||||
payload: Dict[str, Any] = {} |
||||
|
||||
if self.username: |
||||
payload["username"] = self.username |
||||
if self.avatar_url: |
||||
payload["avatar_url"] = self.avatar_url |
||||
|
||||
payload["tts"] = self.tts |
||||
|
||||
if len(self.message) <= 2000: |
||||
payload["content"] = self.message |
||||
else: |
||||
raise AirflowException( |
||||
"Discord message length must be 2000 or fewer characters." |
||||
) |
||||
|
||||
return json.dumps(payload) |
||||
|
||||
def execute(self) -> None: |
||||
"""Execute the Discord webhook call""" |
||||
proxies = {} |
||||
if self.proxy: |
||||
# we only need https proxy for Discord |
||||
proxies = {"https": self.proxy} |
||||
|
||||
discord_payload = self._build_discord_payload() |
||||
|
||||
self.run( |
||||
endpoint=self.webhook_endpoint, |
||||
data=discord_payload, |
||||
headers={"Content-type": "application/json"}, |
||||
extra_options={"proxies": proxies}, |
||||
) |
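# --- Hedged usage sketch (not part of the original file) ---
# Posting a message through an assumed HTTP connection "discord_default" whose
# host is "https://discord.com/api/" and whose extra field holds the
# webhook_endpoint. Run this from code that can reach the Airflow metadata DB.
#
# hook = DiscordWebhookHook(
#     http_conn_id="discord_default",   # assumed connection id
#     message="Scrape finished successfully.",
#     username="airflow-bot",           # optional override
# )
# hook.execute()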
||||
@ -0,0 +1,108 @@ |
||||
""" |
||||
There are two ways to authenticate the Google Analytics Hook. |
||||
|
||||
If you have already obtained an OAUTH token, place it in the password field |
||||
of the relevant connection. |
||||
|
||||
If you don't have an OAUTH token, you may authenticate by passing a |
||||
'client_secrets' object to the extras section of the relevant connection. This |
||||
object will expect the following fields and use them to generate an OAUTH token |
||||
on execution. |
||||
|
||||
"type": "service_account", |
||||
"project_id": "example-project-id", |
||||
"private_key_id": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", |
||||
"private_key": "-----BEGIN PRIVATE KEY-----\nXXXXX\n-----END PRIVATE KEY-----\n", |
||||
"client_email": "google-analytics@{PROJECT_ID}.iam.gserviceaccount.com", |
||||
"client_id": "XXXXXXXXXXXXXXXXXXXXXX", |
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth", |
||||
"token_uri": "https://accounts.google.com/o/oauth2/token", |
||||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", |
||||
"client_x509_cert_url": "{CERT_URL}" |
||||
|
||||
More details can be found here: |
||||
https://developers.google.com/api-client-library/python/guide/aaa_client_secrets |
||||
""" |
||||
|
||||
import time |
||||
|
||||
from airflow.hooks.base_hook import BaseHook |
||||
from apiclient.discovery import build |
||||
from oauth2client.client import AccessTokenCredentials |
||||
from oauth2client.service_account import ServiceAccountCredentials |
||||
|
||||
|
||||
class GoogleHook(BaseHook): |
||||
def __init__(self, google_conn_id="google_default"): |
||||
self.google_analytics_conn_id = google_conn_id |
||||
self.connection = self.get_connection(google_conn_id) |
||||
|
||||
# Default to None so get_service_object() can safely check for it. |
||||
self.client_secrets = self.connection.extra_dejson.get("client_secrets", None) |
||||
|
||||
def get_service_object(self, api_name, api_version, scopes=None): |
||||
if self.connection.password: |
||||
credentials = AccessTokenCredentials( |
||||
self.connection.password, "Airflow/1.0" |
||||
) |
||||
elif self.client_secrets: |
||||
credentials = ServiceAccountCredentials.from_json_keyfile_dict( |
||||
self.client_secrets, scopes |
||||
) |
||||
|
||||
return build(api_name, api_version, credentials=credentials) |
||||
|
||||
def get_analytics_report( |
||||
self, |
||||
view_id, |
||||
since, |
||||
until, |
||||
sampling_level, |
||||
dimensions, |
||||
metrics, |
||||
page_size, |
||||
include_empty_rows, |
||||
): |
||||
analytics = self.get_service_object( |
||||
"analyticsreporting", |
||||
"v4", |
||||
["https://www.googleapis.com/auth/analytics.readonly"], |
||||
) |
||||
|
||||
reportRequest = { |
||||
"viewId": view_id, |
||||
"dateRanges": [{"startDate": since, "endDate": until}], |
||||
"samplingLevel": sampling_level or "LARGE", |
||||
"dimensions": dimensions, |
||||
"metrics": metrics, |
||||
"pageSize": page_size or 1000, |
||||
"includeEmptyRows": include_empty_rows or False, |
||||
} |
||||
|
||||
response = ( |
||||
analytics.reports() |
||||
.batchGet(body={"reportRequests": [reportRequest]}) |
||||
.execute() |
||||
) |
||||
|
||||
if response.get("reports"): |
||||
report = response["reports"][0] |
||||
rows = report.get("data", {}).get("rows", []) |
||||
|
||||
while report.get("nextPageToken"): |
||||
time.sleep(1) |
||||
reportRequest.update({"pageToken": report["nextPageToken"]}) |
||||
response = ( |
||||
analytics.reports() |
||||
.batchGet(body={"reportRequests": [reportRequest]}) |
||||
.execute() |
||||
) |
||||
report = response["reports"][0] |
||||
rows.extend(report.get("data", {}).get("rows", [])) |
||||
|
||||
if report["data"]: |
||||
report["data"]["rows"] = rows |
||||
|
||||
return report |
||||
else: |
||||
return {} |
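# --- Hedged usage sketch (not part of the original file) ---
# Pulling a simple sessions report; "google_default" and the view_id are
# placeholders, and the connection must carry either an OAuth token (password
# field) or a client_secrets object (extras).
#
# hook = GoogleHook(google_conn_id="google_default")
# report = hook.get_analytics_report(
#     view_id="12345678",                      # placeholder GA view id
#     since="2021-01-01",
#     until="2021-01-31",
#     sampling_level=None,                     # falls back to "LARGE"
#     dimensions=[{"name": "ga:date"}],
#     metrics=[{"expression": "ga:sessions"}],
#     page_size=None,                          # falls back to 1000
#     include_empty_rows=None,                 # falls back to False
# )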
||||
@ -0,0 +1,95 @@ |
||||
import logging |
||||
import logging as log |
||||
import os |
||||
import time |
||||
|
||||
import docker |
||||
from airflow.hooks.base_hook import BaseHook |
||||
from selenium import webdriver |
||||
from selenium.webdriver.chrome.options import Options |
||||
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities |
||||
|
||||
|
||||
class SeleniumHook(BaseHook): |
||||
""" |
||||
Creates a Selenium Docker container on the host and controls the |
||||
browser by sending commands to the remote server. |
||||
""" |
||||
|
||||
def __init__(self): |
||||
logging.info("initialised hook") |
||||
pass |
||||
|
||||
def create_container(self): |
||||
""" |
||||
Creates the selenium docker container |
||||
""" |
||||
logging.info("creating_container") |
||||
cwd = os.getcwd() |
||||
self.local_downloads = os.path.join(cwd, "downloads") |
||||
self.sel_downloads = "/home/seluser/downloads" |
||||
volumes = [ |
||||
"{}:{}".format(self.local_downloads, self.sel_downloads), |
||||
"/dev/shm:/dev/shm", |
||||
] |
||||
client = docker.from_env() |
||||
container = client.containers.run( |
||||
"selenium/standalone-chrome", |
||||
volumes=volumes, |
||||
network="container_bridge", |
||||
detach=True, |
||||
) |
||||
self.container = container |
||||
cli = docker.APIClient() |
||||
self.container_ip = cli.inspect_container(container.id)["NetworkSettings"][ |
||||
"Networks" |
||||
]["container_bridge"]["IPAddress"] |
||||
|
||||
def create_driver(self): |
||||
""" |
||||
Creates and configures the remote Selenium webdriver. |
||||
""" |
||||
logging.info("creating driver") |
||||
options = Options() |
||||
options.add_argument("--headless") |
||||
options.add_argument("--window-size=1920x1080") |
||||
chrome_driver = "{}:4444/wd/hub".format(self.container_ip) |
||||
# chrome_driver = '{}:4444/wd/hub'.format('http://127.0.0.1') # local |
||||
# wait for remote, unless timeout. |
||||
while True: |
||||
try: |
||||
driver = webdriver.Remote( |
||||
command_executor=chrome_driver, |
||||
desired_capabilities=DesiredCapabilities.CHROME, |
||||
options=options, |
||||
) |
||||
print("remote ready") |
||||
break |
||||
except Exception: |
||||
print("remote not ready, sleeping for ten seconds.") |
||||
time.sleep(10) |
||||
# Enable downloads in headless chrome. |
||||
driver.command_executor._commands["send_command"] = ( |
||||
"POST", |
||||
"/session/$sessionId/chromium/send_command", |
||||
) |
||||
params = { |
||||
"cmd": "Page.setDownloadBehavior", |
||||
"params": {"behavior": "allow", "downloadPath": self.sel_downloads}, |
||||
} |
||||
driver.execute("send_command", params) |
||||
self.driver = driver |
||||
|
||||
def remove_container(self): |
||||
""" |
||||
This removes the Selenium container. |
||||
""" |
||||
self.container.remove(force=True) |
||||
print("Removed container: {}".format(self.container.id)) |
||||
|
||||
def run_script(self, script, args): |
||||
""" |
||||
This is a wrapper around the python script which sends commands to |
||||
the docker container. The first variable of the script must be the web driver. |
||||
""" |
||||
script(self.driver, *args) |
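# --- Hedged usage sketch (not part of the original file) ---
# Spin up the Selenium container, run a function whose first parameter is the
# remote webdriver, then tear the container down. The URL is a placeholder and
# the Docker network "container_bridge" must already exist.
#
# def fetch_title(driver, url):
#     driver.get(url)
#     print(driver.title)
#
# hook = SeleniumHook()
# hook.create_container()
# hook.create_driver()
# hook.run_script(fetch_title, ["https://example.com"])
# hook.remove_container()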
||||
@ -0,0 +1,2 @@ |
||||
from operators.singer import SingerOperator |
||||
from operators.sudo_bash import SudoBashOperator |
||||
@ -0,0 +1,95 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
# |
||||
from typing import Dict, Optional |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook |
||||
from airflow.providers.http.operators.http import SimpleHttpOperator |
||||
from airflow.utils.decorators import apply_defaults |
||||
|
||||
|
||||
class DiscordWebhookOperator(SimpleHttpOperator): |
||||
""" |
||||
This operator allows you to post messages to Discord using incoming webhooks. |
||||
Takes a Discord connection ID with a default relative webhook endpoint. The |
||||
default endpoint can be overridden using the webhook_endpoint parameter |
||||
(https://discordapp.com/developers/docs/resources/webhook). |
||||
Each Discord webhook can be pre-configured to use a specific username and |
||||
avatar_url. You can override these defaults in this operator. |
||||
:param http_conn_id: Http connection ID with host as "https://discord.com/api/" and |
||||
default webhook endpoint in the extra field in the form of |
||||
{"webhook_endpoint": "webhooks/{webhook.id}/{webhook.token}"} |
||||
:type http_conn_id: str |
||||
:param webhook_endpoint: Discord webhook endpoint in the form of |
||||
"webhooks/{webhook.id}/{webhook.token}" |
||||
:type webhook_endpoint: str |
||||
:param message: The message you want to send to your Discord channel |
||||
(max 2000 characters). (templated) |
||||
:type message: str |
||||
:param username: Override the default username of the webhook. (templated) |
||||
:type username: str |
||||
:param avatar_url: Override the default avatar of the webhook |
||||
:type avatar_url: str |
||||
:param tts: Is a text-to-speech message |
||||
:type tts: bool |
||||
:param proxy: Proxy to use to make the Discord webhook call |
||||
:type proxy: str |
||||
""" |
||||
|
||||
template_fields = ["username", "message"] |
||||
|
||||
@apply_defaults |
||||
def __init__( |
||||
self, |
||||
*, |
||||
http_conn_id: Optional[str] = None, |
||||
webhook_endpoint: Optional[str] = None, |
||||
message: str = "", |
||||
username: Optional[str] = None, |
||||
avatar_url: Optional[str] = None, |
||||
tts: bool = False, |
||||
proxy: Optional[str] = None, |
||||
**kwargs, |
||||
) -> None: |
||||
super().__init__(endpoint=webhook_endpoint, **kwargs) |
||||
|
||||
if not http_conn_id: |
||||
raise AirflowException("No valid Discord http_conn_id supplied.") |
||||
|
||||
self.http_conn_id = http_conn_id |
||||
self.webhook_endpoint = webhook_endpoint |
||||
self.message = message |
||||
self.username = username |
||||
self.avatar_url = avatar_url |
||||
self.tts = tts |
||||
self.proxy = proxy |
||||
self.hook: Optional[DiscordWebhookHook] = None |
||||
|
||||
def execute(self, context: Dict) -> None: |
||||
"""Call the DiscordWebhookHook to post message""" |
||||
self.hook = DiscordWebhookHook( |
||||
self.http_conn_id, |
||||
self.webhook_endpoint, |
||||
self.message, |
||||
self.username, |
||||
self.avatar_url, |
||||
self.tts, |
||||
self.proxy, |
||||
) |
||||
self.hook.execute() |
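# --- Hedged usage sketch (not part of the original file) ---
# How the operator might be wired into a DAG; "discord_default" is an assumed
# connection id whose extra field holds the webhook_endpoint.
#
# notify_discord = DiscordWebhookOperator(
#     task_id="notify_discord",
#     http_conn_id="discord_default",
#     message="DAG {{ dag.dag_id }} finished on {{ ds }}",
#     username="airflow-bot",
# )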
||||
@ -0,0 +1,254 @@ |
||||
import gzip |
||||
import json |
||||
import logging |
||||
import os |
||||
import time |
||||
|
||||
import boa |
||||
from airflow.hooks.base_hook import BaseHook |
||||
from airflow.hooks.S3_hook import S3Hook |
||||
from airflow.models import BaseOperator, Variable |
||||
from hooks.google import GoogleHook |
||||
from six import BytesIO |
||||
|
||||
|
||||
class GoogleSheetsToS3Operator(BaseOperator): |
||||
""" |
||||
Google Sheets To S3 Operator |
||||
|
||||
:param google_conn_id: The Google connection id. |
||||
:type google_conn_id: string |
||||
:param sheet_id: The id for associated report. |
||||
:type sheet_id: string |
||||
:param sheet_names: The names of the relevant sheets in the report. |
||||
:type sheet_names: string/array |
||||
:param range: The range of cells containing the relevant data. |
||||
This must be the same for all sheets if multiple |
||||
are being pulled together. |
||||
Example: Sheet1!A2:E80 |
||||
:type range: string |
||||
:param include_schema: If set to true, infer the schema of the data and |
||||
output to S3 as a separate file |
||||
:type include_schema: boolean |
||||
:param s3_conn_id: The s3 connection id. |
||||
:type s3_conn_id: string |
||||
:param s3_key: The S3 key to be used to store the |
||||
retrieved data. |
||||
:type s3_key: string |
||||
""" |
||||
|
||||
template_fields = ("s3_key",) |
||||
|
||||
def __init__( |
||||
self, |
||||
google_conn_id, |
||||
sheet_id, |
||||
s3_conn_id, |
||||
s3_key, |
||||
compression_bound, |
||||
include_schema=False, |
||||
sheet_names=[], |
||||
range=None, |
||||
output_format="json", |
||||
*args, |
||||
**kwargs, |
||||
): |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
self.google_conn_id = google_conn_id |
||||
self.sheet_id = sheet_id |
||||
self.sheet_names = sheet_names |
||||
self.s3_conn_id = s3_conn_id |
||||
self.s3_key = s3_key |
||||
self.include_schema = include_schema |
||||
self.range = range |
||||
self.output_format = output_format.lower() |
||||
self.compression_bound = compression_bound |
||||
if self.output_format not in ("json",): |
||||
raise Exception("Acceptable output formats are: json.") |
||||
|
||||
if self.sheet_names and not isinstance(self.sheet_names, (str, list, tuple)): |
||||
raise Exception("Please specify the sheet names as a string or list.") |
||||
|
||||
def execute(self, context): |
||||
g_conn = GoogleHook(self.google_conn_id) |
||||
|
||||
if isinstance(self.sheet_names, str) and "," in self.sheet_names: |
||||
sheet_names = self.sheet_names.split(",") |
||||
else: |
||||
sheet_names = self.sheet_names |
||||
|
||||
sheets_object = g_conn.get_service_object("sheets", "v4") |
||||
logging.info("Retrieved Sheets Object") |
||||
|
||||
response = ( |
||||
sheets_object.spreadsheets() |
||||
.get(spreadsheetId=self.sheet_id, includeGridData=True) |
||||
.execute() |
||||
) |
||||
|
||||
title = response.get("properties").get("title") |
||||
sheets = response.get("sheets") |
||||
|
||||
final_output = dict() |
||||
|
||||
total_sheets = [] |
||||
for sheet in sheets: |
||||
name = sheet.get("properties").get("title") |
||||
name = boa.constrict(name) |
||||
total_sheets.append(name) |
||||
|
||||
if self.sheet_names: |
||||
if name not in sheet_names: |
||||
logging.info( |
||||
"{} is not found in available sheet names.".format(name) |
||||
) |
||||
continue |
||||
|
||||
table_name = name |
||||
data = sheet.get("data")[0].get("rowData") |
||||
output = [] |
||||
|
||||
for row in data: |
||||
row_data = [] |
||||
values = row.get("values") |
||||
for value in values: |
||||
ev = value.get("effectiveValue") |
||||
if ev is None: |
||||
row_data.append(None) |
||||
else: |
||||
for v in ev.values(): |
||||
row_data.append(v) |
||||
|
||||
output.append(row_data) |
||||
|
||||
if self.output_format == "json": |
||||
headers = output.pop(0) |
||||
output = [dict(zip(headers, row)) for row in output] |
||||
|
||||
final_output[table_name] = output |
||||
|
||||
s3 = S3Hook(self.s3_conn_id) |
||||
|
||||
for sheet in final_output: |
||||
output_data = final_output.get(sheet) |
||||
|
||||
file_name, file_extension = os.path.splitext(self.s3_key) |
||||
|
||||
output_name = "".join([file_name, "_", sheet, file_extension]) |
||||
|
||||
if self.include_schema is True: |
||||
schema_name = "".join( |
||||
[file_name, "_", sheet, "_schema", file_extension] |
||||
) |
||||
|
||||
self.output_manager( |
||||
s3, output_name, output_data, context, sheet, schema_name |
||||
) |
||||
|
||||
dag_id = context["ti"].dag_id |
||||
|
||||
var_key = "_".join([dag_id, self.sheet_id]) |
||||
Variable.set(key=var_key, value=json.dumps(total_sheets)) |
||||
time.sleep(10) |
||||
|
||||
return boa.constrict(title) |
||||
|
||||
def output_manager( |
||||
self, s3, output_name, output_data, context, sheet_name, schema_name=None |
||||
): |
||||
self.s3_bucket = BaseHook.get_connection(self.s3_conn_id).host |
||||
if self.output_format == "json": |
||||
output = "\n".join( |
||||
[ |
||||
json.dumps({boa.constrict(str(k)): v for k, v in record.items()}) |
||||
for record in output_data |
||||
] |
||||
) |
||||
|
||||
enc_output = str.encode(output, "utf-8") |
||||
|
||||
# if file is more than bound then apply gzip compression |
||||
if len(enc_output) / 1024 / 1024 >= self.compression_bound: |
||||
logging.info( |
||||
"File is more than {}MB, gzip compression will be applied".format( |
||||
self.compression_bound |
||||
) |
||||
) |
||||
output = gzip.compress(enc_output, compresslevel=5) |
||||
self.xcom_push( |
||||
context, |
||||
key="is_compressed_{}".format(sheet_name), |
||||
value="compressed", |
||||
) |
||||
self.load_bytes( |
||||
s3, |
||||
bytes_data=output, |
||||
key=output_name, |
||||
bucket_name=self.s3_bucket, |
||||
replace=True, |
||||
) |
||||
else: |
||||
logging.info( |
||||
"File is less than {}MB, compression will not be applied".format( |
||||
self.compression_bound |
||||
) |
||||
) |
||||
self.xcom_push( |
||||
context, |
||||
key="is_compressed_{}".format(sheet_name), |
||||
value="non-compressed", |
||||
) |
||||
s3.load_string( |
||||
string_data=output, |
||||
key=output_name, |
||||
bucket_name=self.s3_bucket, |
||||
replace=True, |
||||
) |
||||
|
||||
if self.include_schema is True: |
||||
output_keys = output_data[0].keys() |
||||
schema = [ |
||||
{"name": boa.constrict(a), "type": "varchar(512)"} |
||||
for a in output_keys |
||||
if a is not None |
||||
] |
||||
schema = {"columns": schema} |
||||
|
||||
s3.load_string( |
||||
string_data=json.dumps(schema), |
||||
key=schema_name, |
||||
bucket_name=self.s3_bucket, |
||||
replace=True, |
||||
) |
||||
|
||||
logging.info('Successfully output of "{}" to S3.'.format(output_name)) |
||||
|
||||
# TODO -- Add support for csv output |
||||
|
||||
# elif self.output_format == 'csv': |
||||
# with NamedTemporaryFile("w") as f: |
||||
# writer = csv.writer(f) |
||||
# writer.writerows(output_data) |
||||
# s3.load_file( |
||||
# filename=f.name, |
||||
# key=output_name, |
||||
# bucket_name=self.s3_bucket, |
||||
# replace=True |
||||
# ) |
||||
# |
||||
# if self.include_schema is True: |
||||
# pass |
||||
|
||||
# TODO: remove when airflow version is upgraded to 1.10 |
||||
def load_bytes(self, s3, bytes_data, key, bucket_name=None, replace=False): |
||||
if not bucket_name: |
||||
(bucket_name, key) = s3.parse_s3_url(key) |
||||
|
||||
if not replace and s3.check_for_key(key, bucket_name): |
||||
raise ValueError("The key {key} already exists.".format(key=key)) |
||||
|
||||
filelike_buffer = BytesIO(bytes_data) |
||||
|
||||
client = s3.get_conn() |
||||
client.upload_fileobj(filelike_buffer, bucket_name, key, ExtraArgs={}) |
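# --- Hedged usage sketch (not part of the original file) ---
# Exporting one sheet to S3 with gzip applied above a 10 MB threshold. The
# connection ids, spreadsheet id, sheet name and key are placeholders.
#
# sheets_to_s3 = GoogleSheetsToS3Operator(
#     task_id="sheets_to_s3",
#     google_conn_id="google_default",
#     sheet_id="1aBcD-placeholder-sheet-id",
#     sheet_names=["listings"],
#     s3_conn_id="s3_default",
#     s3_key="exports/listings_{{ ds }}.json",
#     compression_bound=10,
#     include_schema=True,
# )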
||||
@ -0,0 +1,24 @@ |
||||
from airflow.models import BaseOperator |
||||
from airflow.utils.decorators import apply_defaults |
||||
from hooks.selenium_hook import SeleniumHook |
||||
|
||||
|
||||
class SeleniumOperator(BaseOperator): |
||||
""" |
||||
Selenium Operator |
||||
""" |
||||
|
||||
template_fields = ["script_args"] |
||||
|
||||
@apply_defaults |
||||
def __init__(self, script, script_args, *args, **kwargs): |
||||
super().__init__(*args, **kwargs) |
||||
self.script = script |
||||
self.script_args = script_args |
||||
|
||||
def execute(self, context): |
||||
hook = SeleniumHook() |
||||
hook.create_container() |
||||
hook.create_driver() |
||||
hook.run_script(self.script, self.script_args) |
||||
hook.remove_container() |
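# --- Hedged usage sketch (not part of the original file) ---
# Running a scraping function inside a DAG; the script's first parameter must
# be the webdriver, and the URL is a placeholder.
#
# def download_report(driver, url):
#     driver.get(url)
#
# selenium_task = SeleniumOperator(
#     task_id="download_report",
#     script=download_report,
#     script_args=["https://example.com/report"],
# )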
||||
@ -0,0 +1,45 @@ |
||||
from airflow.operators.bash_operator import BashOperator |
||||
from airflow.utils.decorators import apply_defaults |
||||
|
||||
|
||||
class SingerOperator(BashOperator): |
||||
""" |
||||
Singer Plugin |
||||
|
||||
:param tap: The relevant Singer Tap. |
||||
:type tap: string |
||||
:param target: The relevant Singer Target |
||||
:type target: string |
||||
:param tap_config: The config path for the Singer Tap. |
||||
:type tap_config: string |
||||
:param target_config: The config path for the Singer Target. |
||||
:type target_config: string |
||||
""" |
||||
|
||||
@apply_defaults |
||||
def __init__( |
||||
self, tap, target, tap_config=None, target_config=None, *args, **kwargs |
||||
): |
||||
|
||||
self.tap = "tap-{}".format(tap) |
||||
self.target = "target-{}".format(target) |
||||
self.tap_config = tap_config |
||||
self.target_config = target_config |
||||
|
||||
if self.tap_config: |
||||
if self.target_config: |
||||
self.bash_command = "{} -c {} | {} -c {}".format( |
||||
self.tap, self.tap_config, self.target, self.target_config |
||||
) |
||||
else: |
||||
self.bash_command = "{} -c {} | {}".format( |
||||
self.tap, self.tap_config, self.target |
||||
) |
||||
elif self.target_config: |
||||
self.bash_command = "{} | {} -c {}".format( |
||||
self.tap, self.target, self.target_config |
||||
) |
||||
else: |
||||
self.bash_command = "{} | {}".format(self.tap, self.target) |
||||
|
||||
super().__init__(bash_command=self.bash_command, *args, **kwargs) |
||||
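# A minimal usage sketch (an assumption, not part of the original file): the operator
# composes "tap-<tap> -c <tap_config> | target-<target> -c <target_config>" and runs it
# via BashOperator. The tap/target names and config paths are hypothetical placeholders.
from datetime import datetime

from airflow import DAG

with DAG("example_singer", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    sync = SingerOperator(
        task_id="sync_exchange_rates",
        tap="exchangeratesapi",  # becomes tap-exchangeratesapi
        target="csv",  # becomes target-csv
        tap_config="/opt/airflow/config/tap.json",  # hypothetical path
        target_config="/opt/airflow/config/target.json",  # hypothetical path
    )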
@ -0,0 +1,110 @@ |
||||
import getpass |
||||
import logging |
||||
import os |
||||
from builtins import bytes |
||||
from subprocess import PIPE, STDOUT, Popen |
||||
from tempfile import NamedTemporaryFile, gettempdir |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.models import BaseOperator |
||||
from airflow.utils.decorators import apply_defaults |
||||
from airflow.utils.file import TemporaryDirectory |
||||
|
||||
|
||||
class SudoBashOperator(BaseOperator): |
||||
""" |
||||
Execute a Bash script, command, or set of commands, using sudo to run them as another user. |
||||
|
||||
:param bash_command: The command, set of commands or reference to a |
||||
bash script (must be '.sh') to be executed. |
||||
:type bash_command: string |
||||
:param user: The user to run the command as. The Airflow worker user |
||||
must have permission to sudo as that user |
||||
:type user: string |
||||
:param env: If env is not None, it must be a mapping that defines the |
||||
environment variables for the new process; these are used instead |
||||
of inheriting the current process environment, which is the default |
||||
behavior. |
||||
:type env: dict |
||||
:param output_encoding: Output encoding of the bash command. |
||||
:type output_encoding: str |
||||
""" |
||||
|
||||
template_fields = ("bash_command", "user", "env", "output_encoding") |
||||
template_ext = ( |
||||
".sh", |
||||
".bash", |
||||
) |
||||
ui_color = "#f0ede4" |
||||
|
||||
@apply_defaults |
||||
def __init__( |
||||
self, |
||||
bash_command, |
||||
user, |
||||
xcom_push=False, |
||||
env=None, |
||||
output_encoding="utf-8", |
||||
*args, |
||||
**kwargs, |
||||
): |
||||
|
||||
super(SudoBashOperator, self).__init__(*args, **kwargs) |
||||
self.bash_command = bash_command |
||||
self.user = user |
||||
self.env = env |
||||
self.xcom_push_flag = xcom_push |
||||
self.output_encoding = output_encoding |
||||
|
||||
def execute(self, context): |
||||
""" |
||||
Execute the bash command in a temporary directory which will be cleaned afterwards. |
||||
""" |
||||
logging.info("tmp dir root location: \n" + gettempdir()) |
||||
with TemporaryDirectory(prefix="airflowtmp") as tmp_dir: |
||||
os.chmod(tmp_dir, 0o777) |
||||
# Ensure the sudo user has perms to their current working directory for making tempfiles |
||||
# This is not really a security flaw because the only thing in that dir is the |
||||
# temp script, owned by the airflow user and any temp files made by the sudo user |
||||
# and all of those will be created with the owning user's umask |
||||
# If a process needs finer control over the tempfiles it creates, that process can chmod |
||||
# them as they are created. |
||||
with NamedTemporaryFile(dir=tmp_dir, prefix=self.task_id) as f: |
||||
|
||||
if self.user == getpass.getuser(): # don't try to sudo as yourself |
||||
f.write(bytes(self.bash_command, "utf_8")) |
||||
else: |
||||
sudo_cmd = "sudo -u {} sh -c '{}'".format( |
||||
self.user, self.bash_command |
||||
) |
||||
f.write(bytes(sudo_cmd, "utf_8")) |
||||
f.flush() |
||||
|
||||
logging.info("Temporary script location: {0}".format(f.name)) |
||||
logging.info("Running command: {}".format(self.bash_command)) |
||||
self.sp = Popen( |
||||
["bash", f.name], |
||||
stdout=PIPE, |
||||
stderr=STDOUT, |
||||
cwd=tmp_dir, |
||||
env=self.env, |
||||
) |
||||
|
||||
logging.info("Output:") |
||||
line = "" |
||||
for line in iter(self.sp.stdout.readline, b""): |
||||
line = line.decode(self.output_encoding).strip() |
||||
logging.info(line) |
||||
self.sp.wait() |
||||
logging.info( |
||||
"Command exited with return code {0}".format(self.sp.returncode) |
||||
) |
||||
|
||||
if self.sp.returncode: |
||||
raise AirflowException("Bash command failed") |
||||
|
||||
if self.xcom_push_flag: |
||||
return line |
||||
|
||||
def on_kill(self): |
||||
logging.warning("Sending SIGTERM signal to bash subprocess") |
||||
self.sp.terminate() |
||||
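# A minimal usage sketch (an assumption, not part of the original file): running a command
# as another user. The user name and command are hypothetical placeholders; the Airflow
# worker user must be able to sudo as that user.
from datetime import datetime

from airflow import DAG

with DAG("example_sudo_bash", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    cleanup = SudoBashOperator(
        task_id="cleanup_scratch",
        user="etl",  # hypothetical target user
        bash_command="rm -rf /data/scratch/*",  # hypothetical command
        xcom_push=True,  # return the last line of output as the task's XCom value
    )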
@ -0,0 +1,15 @@ |
||||
from airflow.plugins_manager import AirflowPlugin |
||||
from hooks.google import GoogleHook |
||||
from operators.google import GoogleSheetsToS3Operator |
||||
|
||||
|
||||
class google_sheets_plugin(AirflowPlugin): |
||||
name = "GoogleSheetsPlugin" |
||||
operators = [GoogleSheetsToS3Operator] |
||||
# Leave in for explicitness |
||||
hooks = [GoogleHook] |
||||
executors = [] |
||||
macros = [] |
||||
admin_views = [] |
||||
flask_blueprints = [] |
||||
menu_links = [] |
||||
@ -0,0 +1,13 @@ |
||||
from airflow.plugins_manager import AirflowPlugin |
||||
from operators.singer import SingerOperator |
||||
|
||||
|
||||
class SingerPlugin(AirflowPlugin): |
||||
name = "singer_plugin" |
||||
hooks = [] |
||||
operators = [SingerOperator] |
||||
executors = [] |
||||
macros = [] |
||||
admin_views = [] |
||||
flask_blueprints = [] |
||||
menu_links = [] |
||||
@ -0,0 +1,25 @@ |
||||
.. Licensed to the Apache Software Foundation (ASF) under one |
||||
or more contributor license agreements. See the NOTICE file |
||||
distributed with this work for additional information |
||||
regarding copyright ownership. The ASF licenses this file |
||||
to you under the Apache License, Version 2.0 (the |
||||
"License"); you may not use this file except in compliance |
||||
with the License. You may obtain a copy of the License at |
||||
|
||||
.. http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
||||
.. Unless required by applicable law or agreed to in writing, |
||||
software distributed under the License is distributed on an |
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
KIND, either express or implied. See the License for the |
||||
specific language governing permissions and limitations |
||||
under the License. |
||||
|
||||
|
||||
Changelog |
||||
--------- |
||||
|
||||
1.0.0 |
||||
..... |
||||
|
||||
Initial version of the provider. |
||||
@ -0,0 +1,17 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
@ -0,0 +1,16 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
@ -0,0 +1,64 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""Example DAG demonstrating the usage of the AirbyteTriggerSyncOperator.""" |
||||
|
||||
from datetime import timedelta |
||||
|
||||
from airflow import DAG |
||||
from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator |
||||
from airflow.providers.airbyte.sensors.airbyte import AirbyteJobSensor |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
args = { |
||||
"owner": "airflow", |
||||
} |
||||
|
||||
with DAG( |
||||
dag_id="example_airbyte_operator", |
||||
default_args=args, |
||||
schedule_interval=None, |
||||
start_date=days_ago(1), |
||||
dagrun_timeout=timedelta(minutes=60), |
||||
tags=["example"], |
||||
) as dag: |
||||
|
||||
# [START howto_operator_airbyte_synchronous] |
||||
sync_source_destination = AirbyteTriggerSyncOperator( |
||||
task_id="airbyte_sync_source_dest_example", |
||||
airbyte_conn_id="airbyte_default", |
||||
connection_id="15bc3800-82e4-48c3-a32d-620661273f28", |
||||
) |
||||
# [END howto_operator_airbyte_synchronous] |
||||
|
||||
# [START howto_operator_airbyte_asynchronous] |
||||
async_source_destination = AirbyteTriggerSyncOperator( |
||||
task_id="airbyte_async_source_dest_example", |
||||
airbyte_conn_id="airbyte_default", |
||||
connection_id="15bc3800-82e4-48c3-a32d-620661273f28", |
||||
asynchronous=True, |
||||
) |
||||
|
||||
airbyte_sensor = AirbyteJobSensor( |
||||
task_id="airbyte_sensor_source_dest_example", |
||||
airbyte_job_id="{{task_instance.xcom_pull(task_ids='airbyte_async_source_dest_example')}}", |
||||
airbyte_conn_id="airbyte_default", |
||||
) |
||||
# [END howto_operator_airbyte_asynchronous] |
||||
|
||||
async_source_destination >> airbyte_sensor |
||||
@ -0,0 +1,17 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
@ -0,0 +1,123 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
import time |
||||
from typing import Any, Optional |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.http.hooks.http import HttpHook |
||||
|
||||
|
||||
class AirbyteHook(HttpHook): |
||||
""" |
||||
Hook for Airbyte API |
||||
|
||||
:param airbyte_conn_id: Required. The name of the Airflow connection to get |
||||
connection information for Airbyte. |
||||
:type airbyte_conn_id: str |
||||
:param api_version: Optional. Airbyte API version. |
||||
:type api_version: str |
||||
""" |
||||
|
||||
RUNNING = "running" |
||||
SUCCEEDED = "succeeded" |
||||
CANCELLED = "cancelled" |
||||
PENDING = "pending" |
||||
FAILED = "failed" |
||||
ERROR = "error" |
||||
|
||||
def __init__( |
||||
self, |
||||
airbyte_conn_id: str = "airbyte_default", |
||||
api_version: Optional[str] = "v1", |
||||
) -> None: |
||||
super().__init__(http_conn_id=airbyte_conn_id) |
||||
self.api_version: str = api_version |
||||
|
||||
def wait_for_job( |
||||
self, |
||||
job_id: str, |
||||
wait_seconds: Optional[float] = 3, |
||||
timeout: Optional[float] = 3600, |
||||
) -> None: |
||||
""" |
||||
Helper method which polls a job to check if it finishes. |
||||
|
||||
:param job_id: Required. Id of the Airbyte job |
||||
:type job_id: str |
||||
:param wait_seconds: Optional. Number of seconds between checks. |
||||
:type wait_seconds: float |
||||
:param timeout: Optional. How many seconds to wait for the job to be ready. |
||||
Used only if ``asynchronous`` is False. |
||||
:type timeout: float |
||||
""" |
||||
state = None |
||||
start = time.monotonic() |
||||
while True: |
||||
if timeout and start + timeout < time.monotonic(): |
||||
raise AirflowException( |
||||
f"Timeout: Airbyte job {job_id} is not ready after {timeout}s" |
||||
) |
||||
time.sleep(wait_seconds) |
||||
try: |
||||
job = self.get_job(job_id=job_id) |
||||
state = job.json()["job"]["status"] |
||||
except AirflowException as err: |
||||
self.log.info( |
||||
"Retrying. Airbyte API returned server error when waiting for job: %s", |
||||
err, |
||||
) |
||||
continue |
||||
|
||||
if state in (self.RUNNING, self.PENDING): |
||||
continue |
||||
if state == self.SUCCEEDED: |
||||
break |
||||
if state == self.ERROR: |
||||
raise AirflowException(f"Job failed:\n{job}") |
||||
elif state == self.CANCELLED: |
||||
raise AirflowException(f"Job was cancelled:\n{job}") |
||||
else: |
||||
raise Exception( |
||||
f"Encountered unexpected state `{state}` for job_id `{job_id}`" |
||||
) |
||||
|
||||
def submit_sync_connection(self, connection_id: str) -> Any: |
||||
""" |
||||
Submits a job to an Airbyte server. |
||||
|
||||
:param connection_id: Required. The ConnectionId of the Airbyte Connection. |
||||
:type connection_id: str |
||||
""" |
||||
return self.run( |
||||
endpoint=f"api/{self.api_version}/connections/sync", |
||||
json={"connectionId": connection_id}, |
||||
headers={"accept": "application/json"}, |
||||
) |
||||
|
||||
def get_job(self, job_id: int) -> Any: |
||||
""" |
||||
Gets the resource representation for a job in Airbyte. |
||||
|
||||
:param job_id: Required. Id of the Airbyte job |
||||
:type job_id: int |
||||
""" |
||||
return self.run( |
||||
endpoint=f"api/{self.api_version}/jobs/get", |
||||
json={"id": job_id}, |
||||
headers={"accept": "application/json"}, |
||||
) |
||||
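# A minimal usage sketch (an assumption, not part of the provider package): calling the hook
# directly to trigger a sync and block until it finishes. The connection UUID reuses the
# placeholder value from the example DAG above.
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook

hook = AirbyteHook(airbyte_conn_id="airbyte_default")
response = hook.submit_sync_connection(connection_id="15bc3800-82e4-48c3-a32d-620661273f28")
job_id = response.json()["job"]["id"]  # the API wraps the job record under "job"
hook.wait_for_job(job_id=job_id, wait_seconds=3, timeout=3600)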
@ -0,0 +1,17 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
@ -0,0 +1,89 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
from typing import Optional |
||||
|
||||
from airflow.models import BaseOperator |
||||
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook |
||||
from airflow.utils.decorators import apply_defaults |
||||
|
||||
|
||||
class AirbyteTriggerSyncOperator(BaseOperator): |
||||
""" |
||||
This operator allows you to submit a job to an Airbyte server to run an integration |
||||
process between your source and destination. |
||||
|
||||
.. seealso:: |
||||
For more information on how to use this operator, take a look at the guide: |
||||
:ref:`howto/operator:AirbyteTriggerSyncOperator` |
||||
|
||||
:param airbyte_conn_id: Required. The name of the Airflow connection to get connection |
||||
information for Airbyte. |
||||
:type airbyte_conn_id: str |
||||
:param connection_id: Required. The Airbyte ConnectionId UUID between a source and destination. |
||||
:type connection_id: str |
||||
:param asynchronous: Optional. Flag to get job_id after submitting the job to the Airbyte API. |
||||
This is useful for submitting long running jobs and |
||||
waiting on them asynchronously using the AirbyteJobSensor. |
||||
:type asynchronous: bool |
||||
:param api_version: Optional. Airbyte API version. |
||||
:type api_version: str |
||||
:param wait_seconds: Optional. Number of seconds between checks. Only used when ``asynchronous`` is False. |
||||
:type wait_seconds: float |
||||
:param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. |
||||
Only used when ``asynchronous`` is False. |
||||
:type timeout: float |
||||
""" |
||||
|
||||
template_fields = ("connection_id",) |
||||
|
||||
@apply_defaults |
||||
def __init__( |
||||
self, |
||||
connection_id: str, |
||||
airbyte_conn_id: str = "airbyte_default", |
||||
asynchronous: Optional[bool] = False, |
||||
api_version: Optional[str] = "v1", |
||||
wait_seconds: Optional[float] = 3, |
||||
timeout: Optional[float] = 3600, |
||||
**kwargs, |
||||
) -> None: |
||||
super().__init__(**kwargs) |
||||
self.airbyte_conn_id = airbyte_conn_id |
||||
self.connection_id = connection_id |
||||
self.timeout = timeout |
||||
self.api_version = api_version |
||||
self.wait_seconds = wait_seconds |
||||
self.asynchronous = asynchronous |
||||
|
||||
def execute(self, context) -> None: |
||||
"""Create Airbyte Job and wait to finish""" |
||||
hook = AirbyteHook( |
||||
airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version |
||||
) |
||||
job_object = hook.submit_sync_connection(connection_id=self.connection_id) |
||||
job_id = job_object.json()["job"]["id"] |
||||
|
||||
self.log.info("Job %s was submitted to Airbyte Server", job_id) |
||||
if not self.asynchronous: |
||||
self.log.info("Waiting for job %s to complete", job_id) |
||||
hook.wait_for_job( |
||||
job_id=job_id, wait_seconds=self.wait_seconds, timeout=self.timeout |
||||
) |
||||
self.log.info("Job %s completed successfully", job_id) |
||||
|
||||
return job_id |
||||
@ -0,0 +1,51 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
--- |
||||
package-name: apache-airflow-providers-airbyte |
||||
name: Airbyte |
||||
description: | |
||||
`Airbyte <https://airbyte.io/>`__ |
||||
|
||||
versions: |
||||
- 1.0.0 |
||||
|
||||
integrations: |
||||
- integration-name: Airbyte |
||||
external-doc-url: https://www.airbyte.io/ |
||||
logo: /integration-logos/airbyte/Airbyte.png |
||||
how-to-guide: |
||||
- /docs/apache-airflow-providers-airbyte/operators/airbyte.rst |
||||
tags: [service] |
||||
|
||||
operators: |
||||
- integration-name: Airbyte |
||||
python-modules: |
||||
- airflow.providers.airbyte.operators.airbyte |
||||
|
||||
hooks: |
||||
- integration-name: Airbyte |
||||
python-modules: |
||||
- airflow.providers.airbyte.hooks.airbyte |
||||
|
||||
sensors: |
||||
- integration-name: Airbyte |
||||
python-modules: |
||||
- airflow.providers.airbyte.sensors.airbyte |
||||
|
||||
hook-class-names: |
||||
- airflow.providers.airbyte.hooks.airbyte.AirbyteHook |
||||
@ -0,0 +1,16 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
@ -0,0 +1,75 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
"""This module contains a Airbyte Job sensor.""" |
||||
from typing import Optional |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook |
||||
from airflow.sensors.base import BaseSensorOperator |
||||
from airflow.utils.decorators import apply_defaults |
||||
|
||||
|
||||
class AirbyteJobSensor(BaseSensorOperator): |
||||
""" |
||||
Check the state of a previously submitted Airbyte job. |
||||
|
||||
:param airbyte_job_id: Required. Id of the Airbyte job |
||||
:type airbyte_job_id: str |
||||
:param airbyte_conn_id: Required. The name of the Airflow connection to get |
||||
connection information for Airbyte. |
||||
:type airbyte_conn_id: str |
||||
:param api_version: Optional. Airbyte API version. |
||||
:type api_version: str |
||||
""" |
||||
|
||||
template_fields = ("airbyte_job_id",) |
||||
ui_color = "#6C51FD" |
||||
|
||||
@apply_defaults |
||||
def __init__( |
||||
self, |
||||
*, |
||||
airbyte_job_id: str, |
||||
airbyte_conn_id: str = "airbyte_default", |
||||
api_version: Optional[str] = "v1", |
||||
**kwargs, |
||||
) -> None: |
||||
super().__init__(**kwargs) |
||||
self.airbyte_conn_id = airbyte_conn_id |
||||
self.airbyte_job_id = airbyte_job_id |
||||
self.api_version = api_version |
||||
|
||||
def poke(self, context: dict) -> bool: |
||||
hook = AirbyteHook( |
||||
airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version |
||||
) |
||||
job = hook.get_job(job_id=self.airbyte_job_id) |
||||
status = job.json()["job"]["status"] |
||||
|
||||
if status == hook.FAILED: |
||||
raise AirflowException(f"Job failed: \n{job}") |
||||
elif status == hook.CANCELLED: |
||||
raise AirflowException(f"Job was cancelled: \n{job}") |
||||
elif status == hook.SUCCEEDED: |
||||
self.log.info("Job %s completed successfully.", self.airbyte_job_id) |
||||
return True |
||||
elif status == hook.ERROR: |
||||
self.log.info("Job %s attempt has failed.", self.airbyte_job_id) |
||||
|
||||
self.log.info("Waiting for job %s to complete.", self.airbyte_job_id) |
||||
return False |
||||
@ -0,0 +1,63 @@ |
||||
.. Licensed to the Apache Software Foundation (ASF) under one |
||||
or more contributor license agreements. See the NOTICE file |
||||
distributed with this work for additional information |
||||
regarding copyright ownership. The ASF licenses this file |
||||
to you under the Apache License, Version 2.0 (the |
||||
"License"); you may not use this file except in compliance |
||||
with the License. You may obtain a copy of the License at |
||||
|
||||
.. http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
||||
.. Unless required by applicable law or agreed to in writing, |
||||
software distributed under the License is distributed on an |
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
KIND, either express or implied. See the License for the |
||||
specific language governing permissions and limitations |
||||
under the License. |
||||
|
||||
|
||||
Changelog |
||||
--------- |
||||
|
||||
1.2.0 |
||||
..... |
||||
|
||||
Features |
||||
~~~~~~~~ |
||||
|
||||
* ``Avoid using threads in S3 remote logging upload (#14414)`` |
||||
* ``Allow AWS Operator RedshiftToS3Transfer To Run a Custom Query (#14177)`` |
||||
* ``includes the STS token if STS credentials are used (#11227)`` |
||||
|
||||
1.1.0 |
||||
..... |
||||
|
||||
Features |
||||
~~~~~~~~ |
||||
|
||||
* ``Adding support to put extra arguments for Glue Job. (#14027)`` |
||||
* ``Add aws ses email backend for use with EmailOperator. (#13986)`` |
||||
* ``Add bucket_name to template fileds in S3 operators (#13973)`` |
||||
* ``Add ExasolToS3Operator (#13847)`` |
||||
* ``AWS Glue Crawler Integration (#13072)`` |
||||
* ``Add acl_policy to S3CopyObjectOperator (#13773)`` |
||||
* ``AllowDiskUse parameter and docs in MongotoS3Operator (#12033)`` |
||||
* ``Add S3ToFTPOperator (#11747)`` |
||||
* ``add xcom push for ECSOperator (#12096)`` |
||||
* ``[AIRFLOW-3723] Add Gzip capability to mongo_to_S3 operator (#13187)`` |
||||
* ``Add S3KeySizeSensor (#13049)`` |
||||
* ``Add 'mongo_collection' to template_fields in MongoToS3Operator (#13361)`` |
||||
* ``Allow Tags on AWS Batch Job Submission (#13396)`` |
||||
|
||||
Bug fixes |
||||
~~~~~~~~~ |
||||
|
||||
* ``Fix bug in GCSToS3Operator (#13718)`` |
||||
* ``Fix S3KeysUnchangedSensor so that template_fields work (#13490)`` |
||||
|
||||
|
||||
1.0.0 |
||||
..... |
||||
|
||||
|
||||
Initial version of the provider. |
||||
@ -0,0 +1,16 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
@ -0,0 +1,16 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
@ -0,0 +1,16 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
@ -0,0 +1,71 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
""" |
||||
This is an example dag for using `AWSDataSyncOperator` in a straightforward manner. |
||||
|
||||
This DAG gets an AWS TaskArn for a specified source and destination, and then attempts to execute it. |
||||
It assumes a single task is returned and does not do error checking (e.g. if multiple tasks were found). |
||||
|
||||
This DAG relies on the following environment variables: |
||||
|
||||
* SOURCE_LOCATION_URI - Source location URI, usually an on-premises SMB or NFS share |
||||
* DESTINATION_LOCATION_URI - Destination location URI, usually S3 |
||||
""" |
||||
|
||||
from os import getenv |
||||
|
||||
from airflow import models |
||||
from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
# [START howto_operator_datasync_1_args_1] |
||||
TASK_ARN = getenv("TASK_ARN", "my_aws_datasync_task_arn") |
||||
# [END howto_operator_datasync_1_args_1] |
||||
|
||||
# [START howto_operator_datasync_1_args_2] |
||||
SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/") |
||||
|
||||
DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix") |
||||
# [END howto_operator_datasync_1_args_2] |
||||
|
||||
|
||||
with models.DAG( |
||||
"example_datasync_1_1", |
||||
schedule_interval=None, # Override to match your needs |
||||
start_date=days_ago(1), |
||||
tags=["example"], |
||||
) as dag: |
||||
|
||||
# [START howto_operator_datasync_1_1] |
||||
datasync_task_1 = AWSDataSyncOperator( |
||||
aws_conn_id="aws_default", task_id="datasync_task_1", task_arn=TASK_ARN |
||||
) |
||||
# [END howto_operator_datasync_1_1] |
||||
|
||||
with models.DAG( |
||||
"example_datasync_1_2", |
||||
start_date=days_ago(1), |
||||
schedule_interval=None, # Override to match your needs |
||||
) as dag: |
||||
# [START howto_operator_datasync_1_2] |
||||
datasync_task_2 = AWSDataSyncOperator( |
||||
aws_conn_id="aws_default", |
||||
task_id="datasync_task_2", |
||||
source_location_uri=SOURCE_LOCATION_URI, |
||||
destination_location_uri=DESTINATION_LOCATION_URI, |
||||
) |
||||
# [END howto_operator_datasync_1_2] |
||||
@ -0,0 +1,100 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
""" |
||||
This is an example dag for using `AWSDataSyncOperator` in a more complex manner. |
||||
|
||||
- Try to get a TaskArn. If one exists, update it. |
||||
- If no tasks exist, try to create a new DataSync Task. |
||||
- If source and destination locations don't exist for the new task, create them first |
||||
- If many tasks exist, raise an Exception |
||||
- After getting or creating a DataSync Task, run it |
||||
|
||||
This DAG relies on the following environment variables: |
||||
|
||||
* SOURCE_LOCATION_URI - Source location URI, usually an on-premises SMB or NFS share |
||||
* DESTINATION_LOCATION_URI - Destination location URI, usually S3 |
||||
* CREATE_TASK_KWARGS - Passed to boto3.create_task(**kwargs) |
||||
* CREATE_SOURCE_LOCATION_KWARGS - Passed to boto3.create_location(**kwargs) |
||||
* CREATE_DESTINATION_LOCATION_KWARGS - Passed to boto3.create_location(**kwargs) |
||||
* UPDATE_TASK_KWARGS - Passed to boto3.update_task(**kwargs) |
||||
""" |
||||
|
||||
import json |
||||
import re |
||||
from os import getenv |
||||
|
||||
from airflow import models |
||||
from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
# [START howto_operator_datasync_2_args] |
||||
SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/") |
||||
|
||||
DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix") |
||||
|
||||
default_create_task_kwargs = '{"Name": "Created by Airflow"}' |
||||
CREATE_TASK_KWARGS = json.loads( |
||||
getenv("CREATE_TASK_KWARGS", default_create_task_kwargs) |
||||
) |
||||
|
||||
default_create_source_location_kwargs = "{}" |
||||
CREATE_SOURCE_LOCATION_KWARGS = json.loads( |
||||
getenv("CREATE_SOURCE_LOCATION_KWARGS", default_create_source_location_kwargs) |
||||
) |
||||
|
||||
bucket_access_role_arn = ( |
||||
"arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role" |
||||
) |
||||
default_destination_location_kwargs = """\ |
||||
{"S3BucketArn": "arn:aws:s3:::mybucket", |
||||
"S3Config": {"BucketAccessRoleArn": |
||||
"arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"} |
||||
}""" |
||||
CREATE_DESTINATION_LOCATION_KWARGS = json.loads( |
||||
getenv( |
||||
"CREATE_DESTINATION_LOCATION_KWARGS", |
||||
re.sub(r"[\s+]", "", default_destination_location_kwargs), |
||||
) |
||||
) |
||||
|
||||
default_update_task_kwargs = '{"Name": "Updated by Airflow"}' |
||||
UPDATE_TASK_KWARGS = json.loads( |
||||
getenv("UPDATE_TASK_KWARGS", default_update_task_kwargs) |
||||
) |
||||
|
||||
# [END howto_operator_datasync_2_args] |
||||
|
||||
with models.DAG( |
||||
"example_datasync_2", |
||||
schedule_interval=None, # Override to match your needs |
||||
start_date=days_ago(1), |
||||
tags=["example"], |
||||
) as dag: |
||||
|
||||
# [START howto_operator_datasync_2] |
||||
datasync_task = AWSDataSyncOperator( |
||||
aws_conn_id="aws_default", |
||||
task_id="datasync_task", |
||||
source_location_uri=SOURCE_LOCATION_URI, |
||||
destination_location_uri=DESTINATION_LOCATION_URI, |
||||
create_task_kwargs=CREATE_TASK_KWARGS, |
||||
create_source_location_kwargs=CREATE_SOURCE_LOCATION_KWARGS, |
||||
create_destination_location_kwargs=CREATE_DESTINATION_LOCATION_KWARGS, |
||||
update_task_kwargs=UPDATE_TASK_KWARGS, |
||||
delete_task_after_execution=True, |
||||
) |
||||
# [END howto_operator_datasync_2] |
||||
@ -0,0 +1,82 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
""" |
||||
This is an example dag for ECSOperator. |
||||
|
||||
The task "hello_world" runs `hello-world` task in `c` cluster. |
||||
It overrides the command in the `hello-world-container` container. |
||||
""" |
||||
|
||||
import datetime |
||||
import os |
||||
|
||||
from airflow import DAG |
||||
from airflow.providers.amazon.aws.operators.ecs import ECSOperator |
||||
|
||||
DEFAULT_ARGS = { |
||||
"owner": "airflow", |
||||
"depends_on_past": False, |
||||
"email": ["airflow@example.com"], |
||||
"email_on_failure": False, |
||||
"email_on_retry": False, |
||||
} |
||||
|
||||
dag = DAG( |
||||
dag_id="ecs_fargate_dag", |
||||
default_args=DEFAULT_ARGS, |
||||
default_view="graph", |
||||
schedule_interval=None, |
||||
start_date=datetime.datetime(2020, 1, 1), |
||||
tags=["example"], |
||||
) |
||||
# generate dag documentation |
||||
dag.doc_md = __doc__ |
||||
|
||||
# [START howto_operator_ecs] |
||||
hello_world = ECSOperator( |
||||
task_id="hello_world", |
||||
dag=dag, |
||||
aws_conn_id="aws_ecs", |
||||
cluster="c", |
||||
task_definition="hello-world", |
||||
launch_type="FARGATE", |
||||
overrides={ |
||||
"containerOverrides": [ |
||||
{ |
||||
"name": "hello-world-container", |
||||
"command": ["echo", "hello", "world"], |
||||
}, |
||||
], |
||||
}, |
||||
network_configuration={ |
||||
"awsvpcConfiguration": { |
||||
"securityGroups": [os.environ.get("SECURITY_GROUP_ID", "sg-123abc")], |
||||
"subnets": [os.environ.get("SUBNET_ID", "subnet-123456ab")], |
||||
}, |
||||
}, |
||||
tags={ |
||||
"Customer": "X", |
||||
"Project": "Y", |
||||
"Application": "Z", |
||||
"Version": "0.0.1", |
||||
"Environment": "Development", |
||||
}, |
||||
awslogs_group="/ecs/hello-world", |
||||
awslogs_stream_prefix="prefix_b/hello-world-container", # prefix with container name |
||||
) |
||||
# [END howto_operator_ecs] |
||||
@ -0,0 +1,96 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
""" |
||||
This is an example dag for an AWS EMR pipeline with automatic steps. |
||||
""" |
||||
from datetime import timedelta |
||||
|
||||
from airflow import DAG |
||||
from airflow.providers.amazon.aws.operators.emr_create_job_flow import ( |
||||
EmrCreateJobFlowOperator, |
||||
) |
||||
from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
DEFAULT_ARGS = { |
||||
"owner": "airflow", |
||||
"depends_on_past": False, |
||||
"email": ["airflow@example.com"], |
||||
"email_on_failure": False, |
||||
"email_on_retry": False, |
||||
} |
||||
|
||||
# [START howto_operator_emr_automatic_steps_config] |
||||
SPARK_STEPS = [ |
||||
{ |
||||
"Name": "calculate_pi", |
||||
"ActionOnFailure": "CONTINUE", |
||||
"HadoopJarStep": { |
||||
"Jar": "command-runner.jar", |
||||
"Args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"], |
||||
}, |
||||
} |
||||
] |
||||
|
||||
JOB_FLOW_OVERRIDES = { |
||||
"Name": "PiCalc", |
||||
"ReleaseLabel": "emr-5.29.0", |
||||
"Instances": { |
||||
"InstanceGroups": [ |
||||
{ |
||||
"Name": "Master node", |
||||
"Market": "SPOT", |
||||
"InstanceRole": "MASTER", |
||||
"InstanceType": "m1.medium", |
||||
"InstanceCount": 1, |
||||
} |
||||
], |
||||
"KeepJobFlowAliveWhenNoSteps": False, |
||||
"TerminationProtected": False, |
||||
}, |
||||
"Steps": SPARK_STEPS, |
||||
"JobFlowRole": "EMR_EC2_DefaultRole", |
||||
"ServiceRole": "EMR_DefaultRole", |
||||
} |
||||
# [END howto_operator_emr_automatic_steps_config] |
||||
|
||||
with DAG( |
||||
dag_id="emr_job_flow_automatic_steps_dag", |
||||
default_args=DEFAULT_ARGS, |
||||
dagrun_timeout=timedelta(hours=2), |
||||
start_date=days_ago(2), |
||||
schedule_interval="0 3 * * *", |
||||
tags=["example"], |
||||
) as dag: |
||||
|
||||
# [START howto_operator_emr_automatic_steps_tasks] |
||||
job_flow_creator = EmrCreateJobFlowOperator( |
||||
task_id="create_job_flow", |
||||
job_flow_overrides=JOB_FLOW_OVERRIDES, |
||||
aws_conn_id="aws_default", |
||||
emr_conn_id="emr_default", |
||||
) |
||||
|
||||
job_sensor = EmrJobFlowSensor( |
||||
task_id="check_job_flow", |
||||
job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}", |
||||
aws_conn_id="aws_default", |
||||
) |
||||
|
||||
job_flow_creator >> job_sensor |
||||
# [END howto_operator_emr_automatic_steps_tasks] |
||||
@ -0,0 +1,114 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
""" |
||||
This is an example dag for an AWS EMR pipeline. |
||||
|
||||
It starts by creating a cluster, adds steps/operations, checks the steps, and finally |
||||
terminates the cluster when finished. |
||||
""" |
||||
from datetime import timedelta |
||||
|
||||
from airflow import DAG |
||||
from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator |
||||
from airflow.providers.amazon.aws.operators.emr_create_job_flow import ( |
||||
EmrCreateJobFlowOperator, |
||||
) |
||||
from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import ( |
||||
EmrTerminateJobFlowOperator, |
||||
) |
||||
from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
DEFAULT_ARGS = { |
||||
"owner": "airflow", |
||||
"depends_on_past": False, |
||||
"email": ["airflow@example.com"], |
||||
"email_on_failure": False, |
||||
"email_on_retry": False, |
||||
} |
||||
|
||||
SPARK_STEPS = [ |
||||
{ |
||||
"Name": "calculate_pi", |
||||
"ActionOnFailure": "CONTINUE", |
||||
"HadoopJarStep": { |
||||
"Jar": "command-runner.jar", |
||||
"Args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"], |
||||
}, |
||||
} |
||||
] |
||||
|
||||
JOB_FLOW_OVERRIDES = { |
||||
"Name": "PiCalc", |
||||
"ReleaseLabel": "emr-5.29.0", |
||||
"Instances": { |
||||
"InstanceGroups": [ |
||||
{ |
||||
"Name": "Master node", |
||||
"Market": "SPOT", |
||||
"InstanceRole": "MASTER", |
||||
"InstanceType": "m1.medium", |
||||
"InstanceCount": 1, |
||||
} |
||||
], |
||||
"KeepJobFlowAliveWhenNoSteps": True, |
||||
"TerminationProtected": False, |
||||
}, |
||||
"JobFlowRole": "EMR_EC2_DefaultRole", |
||||
"ServiceRole": "EMR_DefaultRole", |
||||
} |
||||
|
||||
with DAG( |
||||
dag_id="emr_job_flow_manual_steps_dag", |
||||
default_args=DEFAULT_ARGS, |
||||
dagrun_timeout=timedelta(hours=2), |
||||
start_date=days_ago(2), |
||||
schedule_interval="0 3 * * *", |
||||
tags=["example"], |
||||
) as dag: |
||||
|
||||
# [START howto_operator_emr_manual_steps_tasks] |
||||
cluster_creator = EmrCreateJobFlowOperator( |
||||
task_id="create_job_flow", |
||||
job_flow_overrides=JOB_FLOW_OVERRIDES, |
||||
aws_conn_id="aws_default", |
||||
emr_conn_id="emr_default", |
||||
) |
||||
|
||||
step_adder = EmrAddStepsOperator( |
||||
task_id="add_steps", |
||||
job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}", |
||||
aws_conn_id="aws_default", |
||||
steps=SPARK_STEPS, |
||||
) |
||||
|
||||
step_checker = EmrStepSensor( |
||||
task_id="watch_step", |
||||
job_flow_id="{{ task_instance.xcom_pull('create_job_flow', key='return_value') }}", |
||||
step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}", |
||||
aws_conn_id="aws_default", |
||||
) |
||||
|
||||
cluster_remover = EmrTerminateJobFlowOperator( |
||||
task_id="remove_cluster", |
||||
job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}", |
||||
aws_conn_id="aws_default", |
||||
) |
||||
|
||||
cluster_creator >> step_adder >> step_checker >> cluster_remover |
||||
# [END howto_operator_emr_manual_steps_tasks] |
||||
@ -0,0 +1,70 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
import os |
||||
|
||||
from airflow import models |
||||
from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator |
||||
from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor |
||||
from airflow.providers.amazon.aws.transfers.glacier_to_gcs import GlacierToGCSOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
VAULT_NAME = "airflow" |
||||
BUCKET_NAME = os.environ.get("GLACIER_GCS_BUCKET_NAME", "gs://glacier_bucket") |
||||
OBJECT_NAME = os.environ.get("GLACIER_OBJECT", "example-text.txt") |
||||
|
||||
with models.DAG( |
||||
"example_glacier_to_gcs", |
||||
schedule_interval=None, |
||||
start_date=days_ago(1), # Override to match your needs |
||||
) as dag: |
||||
# [START howto_glacier_create_job_operator] |
||||
create_glacier_job = GlacierCreateJobOperator( |
||||
task_id="create_glacier_job", |
||||
aws_conn_id="aws_default", |
||||
vault_name=VAULT_NAME, |
||||
) |
||||
JOB_ID = '{{ task_instance.xcom_pull("create_glacier_job")["jobId"] }}' |
||||
# [END howto_glacier_create_job_operator] |
||||
|
||||
# [START howto_glacier_job_operation_sensor] |
||||
wait_for_operation_complete = GlacierJobOperationSensor( |
||||
aws_conn_id="aws_default", |
||||
vault_name=VAULT_NAME, |
||||
job_id=JOB_ID, |
||||
task_id="wait_for_operation_complete", |
||||
) |
||||
# [END howto_glacier_job_operation_sensor] |
||||
|
||||
# [START howto_glacier_transfer_data_to_gcs] |
||||
transfer_archive_to_gcs = GlacierToGCSOperator( |
||||
task_id="transfer_archive_to_gcs", |
||||
aws_conn_id="aws_default", |
||||
gcp_conn_id="google_cloud_default", |
||||
vault_name=VAULT_NAME, |
||||
bucket_name=BUCKET_NAME, |
||||
object_name=OBJECT_NAME, |
||||
gzip=False, |
||||
# Override to match your needs |
||||
# If chunk size is bigger than actual file size |
||||
# then whole file will be downloaded |
||||
chunk_size=1024, |
||||
delegate_to=None, |
||||
google_impersonation_chain=None, |
||||
) |
||||
# [END howto_glacier_transfer_data_to_gcs] |
||||
|
||||
create_glacier_job >> wait_for_operation_complete >> transfer_archive_to_gcs |
||||
@ -0,0 +1,141 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
""" |
||||
This is a more advanced example dag for using `GoogleApiToS3Transfer` which uses xcom to pass data between |
||||
tasks to retrieve specific information about YouTube videos: |
||||
|
||||
First it searches for up to 50 videos (due to pagination) in a given time range |
||||
(YOUTUBE_VIDEO_PUBLISHED_AFTER, YOUTUBE_VIDEO_PUBLISHED_BEFORE) on a YouTube channel (YOUTUBE_CHANNEL_ID), |
||||
saves the response in S3 and passes the YouTube IDs to the next request, which then gets the information |
||||
(YOUTUBE_VIDEO_FIELDS) for the requested videos and saves them in S3 (S3_DESTINATION_KEY). |
||||
|
||||
Further information: |
||||
|
||||
YOUTUBE_VIDEO_PUBLISHED_AFTER and YOUTUBE_VIDEO_PUBLISHED_BEFORE needs to be formatted |
||||
"YYYY-MM-DDThh:mm:ss.sZ". See https://developers.google.com/youtube/v3/docs/search/list for more information. |
||||
YOUTUBE_VIDEO_PARTS depends on the fields you pass via YOUTUBE_VIDEO_FIELDS. See |
||||
https://developers.google.com/youtube/v3/docs/videos/list#parameters for more information. |
||||
YOUTUBE_CONN_ID is optional for public videos. It is only needed to authenticate when there are private videos |
||||
on a YouTube channel you want to retrieve. |
||||
""" |
||||
|
||||
from os import getenv |
||||
|
||||
from airflow import DAG |
||||
from airflow.operators.dummy import DummyOperator |
||||
from airflow.operators.python import BranchPythonOperator |
||||
from airflow.providers.amazon.aws.transfers.google_api_to_s3 import ( |
||||
GoogleApiToS3Operator, |
||||
) |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
# [START howto_operator_google_api_to_s3_transfer_advanced_env_variables] |
||||
YOUTUBE_CONN_ID = getenv("YOUTUBE_CONN_ID", "google_cloud_default") |
||||
YOUTUBE_CHANNEL_ID = getenv( |
||||
"YOUTUBE_CHANNEL_ID", "UCSXwxpWZQ7XZ1WL3wqevChA" |
||||
) # "Apache Airflow" |
||||
YOUTUBE_VIDEO_PUBLISHED_AFTER = getenv( |
||||
"YOUTUBE_VIDEO_PUBLISHED_AFTER", "2019-09-25T00:00:00Z" |
||||
) |
||||
YOUTUBE_VIDEO_PUBLISHED_BEFORE = getenv( |
||||
"YOUTUBE_VIDEO_PUBLISHED_BEFORE", "2019-10-18T00:00:00Z" |
||||
) |
||||
S3_DESTINATION_KEY = getenv("S3_DESTINATION_KEY", "s3://bucket/key.json") |
||||
YOUTUBE_VIDEO_PARTS = getenv("YOUTUBE_VIDEO_PARTS", "snippet") |
||||
YOUTUBE_VIDEO_FIELDS = getenv( |
||||
"YOUTUBE_VIDEO_FIELDS", "items(id,snippet(description,publishedAt,tags,title))" |
||||
) |
||||
# [END howto_operator_google_api_to_s3_transfer_advanced_env_variables] |
||||
|
||||
|
||||
# pylint: disable=unused-argument |
||||
# [START howto_operator_google_api_to_s3_transfer_advanced_task_1_2] |
||||
def _check_and_transform_video_ids(xcom_key, task_ids, task_instance, **kwargs): |
||||
video_ids_response = task_instance.xcom_pull(task_ids=task_ids, key=xcom_key) |
||||
video_ids = [item["id"]["videoId"] for item in video_ids_response["items"]] |
||||
|
||||
if video_ids: |
||||
task_instance.xcom_push(key="video_ids", value={"id": ",".join(video_ids)}) |
||||
return "video_data_to_s3" |
||||
return "no_video_ids" |
||||
|
||||
|
||||
# [END howto_operator_google_api_to_s3_transfer_advanced_task_1_2] |
||||
# pylint: enable=unused-argument |
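# For orientation, a hedged sketch of the XCom payload shape that _check_and_transform_video_ids
# expects from the search task. Only "items/id/videoId" is requested via the "fields" parameter
# below; the concrete IDs here are made up:
#
#   video_ids_response = {"items": [{"id": {"videoId": "abc123"}}, {"id": {"videoId": "def456"}}]}
#
# which the callable would push back to XCom as {"id": "abc123,def456"} under the key "video_ids".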
||||
|
||||
s3_directory, s3_file = S3_DESTINATION_KEY.rsplit("/", 1) |
||||
s3_file_name, _ = s3_file.rsplit(".", 1) |
||||
|
||||
with DAG( |
||||
dag_id="example_google_api_to_s3_transfer_advanced", |
||||
schedule_interval=None, |
||||
start_date=days_ago(1), |
||||
tags=["example"], |
||||
) as dag: |
||||
# [START howto_operator_google_api_to_s3_transfer_advanced_task_1] |
||||
task_video_ids_to_s3 = GoogleApiToS3Operator( |
||||
gcp_conn_id=YOUTUBE_CONN_ID, |
||||
google_api_service_name="youtube", |
||||
google_api_service_version="v3", |
||||
google_api_endpoint_path="youtube.search.list", |
||||
google_api_endpoint_params={ |
||||
"part": "snippet", |
||||
"channelId": YOUTUBE_CHANNEL_ID, |
||||
"maxResults": 50, |
||||
"publishedAfter": YOUTUBE_VIDEO_PUBLISHED_AFTER, |
||||
"publishedBefore": YOUTUBE_VIDEO_PUBLISHED_BEFORE, |
||||
"type": "video", |
||||
"fields": "items/id/videoId", |
||||
}, |
||||
google_api_response_via_xcom="video_ids_response", |
||||
s3_destination_key=f"{s3_directory}/youtube_search_{s3_file_name}.json", |
||||
task_id="video_ids_to_s3", |
||||
) |
||||
# [END howto_operator_google_api_to_s3_transfer_advanced_task_1] |
||||
# [START howto_operator_google_api_to_s3_transfer_advanced_task_1_1] |
||||
task_check_and_transform_video_ids = BranchPythonOperator( |
||||
python_callable=_check_and_transform_video_ids, |
||||
op_args=[ |
||||
task_video_ids_to_s3.google_api_response_via_xcom, |
||||
task_video_ids_to_s3.task_id, |
||||
], |
||||
task_id="check_and_transform_video_ids", |
||||
) |
||||
# [END howto_operator_google_api_to_s3_transfer_advanced_task_1_1] |
||||
# [START howto_operator_google_api_to_s3_transfer_advanced_task_2] |
||||
task_video_data_to_s3 = GoogleApiToS3Operator( |
||||
gcp_conn_id=YOUTUBE_CONN_ID, |
||||
google_api_service_name="youtube", |
||||
google_api_service_version="v3", |
||||
google_api_endpoint_path="youtube.videos.list", |
||||
google_api_endpoint_params={ |
||||
"part": YOUTUBE_VIDEO_PARTS, |
||||
"maxResults": 50, |
||||
"fields": YOUTUBE_VIDEO_FIELDS, |
||||
}, |
||||
google_api_endpoint_params_via_xcom="video_ids", |
||||
s3_destination_key=f"{s3_directory}/youtube_videos_{s3_file_name}.json", |
||||
task_id="video_data_to_s3", |
||||
) |
||||
# [END howto_operator_google_api_to_s3_transfer_advanced_task_2] |
||||
# [START howto_operator_google_api_to_s3_transfer_advanced_task_2_1] |
||||
task_no_video_ids = DummyOperator(task_id="no_video_ids") |
||||
# [END howto_operator_google_api_to_s3_transfer_advanced_task_2_1] |
||||
task_video_ids_to_s3 >> task_check_and_transform_video_ids >> [ |
||||
task_video_data_to_s3, |
||||
task_no_video_ids, |
||||
] |
||||
@ -0,0 +1,57 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
""" |
||||
This is a basic example dag for using `GoogleApiToS3Transfer` to retrieve Google Sheets data: |
||||
|
||||
You need to set the GOOGLE_SHEET_ID and GOOGLE_SHEET_RANGE environment variables to request the data.
||||
""" |
||||
|
||||
from os import getenv |
||||
|
||||
from airflow import DAG |
||||
from airflow.providers.amazon.aws.transfers.google_api_to_s3 import ( |
||||
GoogleApiToS3Operator, |
||||
) |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
# [START howto_operator_google_api_to_s3_transfer_basic_env_variables] |
||||
GOOGLE_SHEET_ID = getenv("GOOGLE_SHEET_ID") |
||||
GOOGLE_SHEET_RANGE = getenv("GOOGLE_SHEET_RANGE") |
||||
S3_DESTINATION_KEY = getenv("S3_DESTINATION_KEY", "s3://bucket/key.json") |
||||
# [END howto_operator_google_api_to_s3_transfer_basic_env_variables] |
||||
|
||||
|
||||
with DAG( |
||||
dag_id="example_google_api_to_s3_transfer_basic", |
||||
schedule_interval=None, |
||||
start_date=days_ago(1), |
||||
tags=["example"], |
||||
) as dag: |
||||
# [START howto_operator_google_api_to_s3_transfer_basic_task_1] |
||||
task_google_sheets_values_to_s3 = GoogleApiToS3Operator( |
||||
google_api_service_name="sheets", |
||||
google_api_service_version="v4", |
||||
google_api_endpoint_path="sheets.spreadsheets.values.get", |
||||
google_api_endpoint_params={ |
||||
"spreadsheetId": GOOGLE_SHEET_ID, |
||||
"range": GOOGLE_SHEET_RANGE, |
||||
}, |
||||
s3_destination_key=S3_DESTINATION_KEY, |
||||
task_id="google_sheets_values_to_s3", |
||||
dag=dag, |
||||
) |
||||
# [END howto_operator_google_api_to_s3_transfer_basic_task_1] |
||||
@ -0,0 +1,53 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
""" |
||||
This is an example DAG for using `ImapAttachmentToS3Operator` to transfer an email attachment via the
IMAP protocol from a mail server to an S3 bucket.
||||
""" |
||||
|
||||
from os import getenv |
||||
|
||||
from airflow import DAG |
||||
from airflow.providers.amazon.aws.transfers.imap_attachment_to_s3 import ( |
||||
ImapAttachmentToS3Operator, |
||||
) |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
# [START howto_operator_imap_attachment_to_s3_env_variables] |
||||
IMAP_ATTACHMENT_NAME = getenv("IMAP_ATTACHMENT_NAME", "test.txt") |
||||
IMAP_MAIL_FOLDER = getenv("IMAP_MAIL_FOLDER", "INBOX") |
||||
IMAP_MAIL_FILTER = getenv("IMAP_MAIL_FILTER", "All") |
||||
S3_DESTINATION_KEY = getenv("S3_DESTINATION_KEY", "s3://bucket/key.json") |
||||
# [END howto_operator_imap_attachment_to_s3_env_variables] |
||||
|
||||
with DAG( |
||||
dag_id="example_imap_attachment_to_s3", |
||||
start_date=days_ago(1), |
||||
schedule_interval=None, |
||||
tags=["example"], |
||||
) as dag: |
||||
# [START howto_operator_imap_attachment_to_s3_task_1] |
||||
task_transfer_imap_attachment_to_s3 = ImapAttachmentToS3Operator( |
||||
imap_attachment_name=IMAP_ATTACHMENT_NAME, |
||||
s3_key=S3_DESTINATION_KEY, |
||||
imap_mail_folder=IMAP_MAIL_FOLDER, |
||||
imap_mail_filter=IMAP_MAIL_FILTER, |
||||
task_id="transfer_imap_attachment_to_s3", |
||||
dag=dag, |
||||
) |
||||
# [END howto_operator_imap_attachment_to_s3_task_1] |
||||
@ -0,0 +1,69 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
import os |
||||
|
||||
from airflow.models.dag import DAG |
||||
from airflow.operators.python import PythonOperator |
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook |
||||
from airflow.providers.amazon.aws.operators.s3_bucket import ( |
||||
S3CreateBucketOperator, |
||||
S3DeleteBucketOperator, |
||||
) |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-airflow-12345") |
||||
|
||||
|
||||
def upload_keys(): |
||||
"""This is a python callback to add keys into the s3 bucket""" |
||||
# add keys to bucket |
||||
s3_hook = S3Hook() |
||||
for i in range(0, 3): |
||||
s3_hook.load_string( |
||||
string_data="input", |
||||
key=f"path/data{i}", |
||||
bucket_name=BUCKET_NAME, |
||||
) |
||||
|
||||
|
||||
with DAG( |
||||
dag_id="s3_bucket_dag", |
||||
schedule_interval=None, |
||||
start_date=days_ago(2), |
||||
max_active_runs=1, |
||||
tags=["example"], |
||||
) as dag: |
||||
|
||||
# [START howto_operator_s3_bucket] |
||||
create_bucket = S3CreateBucketOperator( |
||||
task_id="s3_bucket_dag_create", |
||||
bucket_name=BUCKET_NAME, |
||||
region_name="us-east-1", |
||||
) |
||||
|
||||
add_keys_to_bucket = PythonOperator( |
||||
task_id="s3_bucket_dag_add_keys_to_bucket", python_callable=upload_keys |
||||
) |
||||
|
||||
delete_bucket = S3DeleteBucketOperator( |
||||
task_id="s3_bucket_dag_delete", |
||||
bucket_name=BUCKET_NAME, |
||||
force_delete=True, |
||||
) |
||||
# [END howto_operator_s3_bucket] |
||||
|
||||
create_bucket >> add_keys_to_bucket >> delete_bucket |
||||
@ -0,0 +1,73 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
import os |
||||
|
||||
from airflow.models.dag import DAG |
||||
from airflow.providers.amazon.aws.operators.s3_bucket import ( |
||||
S3CreateBucketOperator, |
||||
S3DeleteBucketOperator, |
||||
) |
||||
from airflow.providers.amazon.aws.operators.s3_bucket_tagging import ( |
||||
S3DeleteBucketTaggingOperator, |
||||
S3GetBucketTaggingOperator, |
||||
S3PutBucketTaggingOperator, |
||||
) |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-s3-bucket-tagging") |
||||
TAG_KEY = os.environ.get("TAG_KEY", "test-s3-bucket-tagging-key") |
||||
TAG_VALUE = os.environ.get("TAG_VALUE", "test-s3-bucket-tagging-value") |
||||
|
||||
|
||||
with DAG( |
||||
dag_id="s3_bucket_tagging_dag", |
||||
schedule_interval=None, |
||||
start_date=days_ago(2), |
||||
max_active_runs=1, |
||||
tags=["example"], |
||||
) as dag: |
||||
|
||||
create_bucket = S3CreateBucketOperator( |
||||
task_id="s3_bucket_tagging_dag_create", |
||||
bucket_name=BUCKET_NAME, |
||||
region_name="us-east-1", |
||||
) |
||||
|
||||
delete_bucket = S3DeleteBucketOperator( |
||||
task_id="s3_bucket_tagging_dag_delete", |
||||
bucket_name=BUCKET_NAME, |
||||
force_delete=True, |
||||
) |
||||
|
||||
# [START howto_operator_s3_bucket_tagging] |
||||
get_tagging = S3GetBucketTaggingOperator( |
||||
task_id="s3_bucket_tagging_dag_get_tagging", bucket_name=BUCKET_NAME |
||||
) |
||||
|
||||
put_tagging = S3PutBucketTaggingOperator( |
||||
task_id="s3_bucket_tagging_dag_put_tagging", |
||||
bucket_name=BUCKET_NAME, |
||||
key=TAG_KEY, |
||||
value=TAG_VALUE, |
||||
) |
||||
|
||||
delete_tagging = S3DeleteBucketTaggingOperator( |
||||
task_id="s3_bucket_tagging_dag_delete_tagging", bucket_name=BUCKET_NAME |
||||
) |
||||
# [END howto_operator_s3_bucket_tagging] |
||||
|
||||
create_bucket >> put_tagging >> get_tagging >> delete_tagging >> delete_bucket |
||||
@ -0,0 +1,90 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
""" |
||||
This is an example DAG for using `S3ToRedshiftOperator` to copy an S3 key into a Redshift table.
||||
""" |
||||
|
||||
from os import getenv |
||||
|
||||
from airflow import DAG |
||||
from airflow.operators.python import PythonOperator |
||||
from airflow.providers.amazon.aws.hooks.s3 import S3Hook |
||||
from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator |
||||
from airflow.providers.postgres.operators.postgres import PostgresOperator |
||||
from airflow.utils.dates import days_ago |
||||
|
||||
# [START howto_operator_s3_to_redshift_env_variables] |
||||
S3_BUCKET = getenv("S3_BUCKET", "test-bucket") |
||||
S3_KEY = getenv("S3_KEY", "key") |
||||
REDSHIFT_TABLE = getenv("REDSHIFT_TABLE", "test_table") |
||||
# [END howto_operator_s3_to_redshift_env_variables] |
||||
|
||||
|
||||
def _add_sample_data_to_s3(): |
||||
s3_hook = S3Hook() |
||||
s3_hook.load_string( |
||||
"0,Airflow", f"{S3_KEY}/{REDSHIFT_TABLE}", S3_BUCKET, replace=True |
||||
) |
||||
|
||||
|
||||
def _remove_sample_data_from_s3(): |
||||
s3_hook = S3Hook() |
||||
if s3_hook.check_for_key(f"{S3_KEY}/{REDSHIFT_TABLE}", S3_BUCKET): |
||||
s3_hook.delete_objects(S3_BUCKET, f"{S3_KEY}/{REDSHIFT_TABLE}") |
||||
|
||||
|
||||
with DAG( |
||||
dag_id="example_s3_to_redshift", |
||||
start_date=days_ago(1), |
||||
schedule_interval=None, |
||||
tags=["example"], |
||||
) as dag: |
||||
setup__task_add_sample_data_to_s3 = PythonOperator( |
||||
python_callable=_add_sample_data_to_s3, task_id="setup__add_sample_data_to_s3" |
||||
) |
||||
setup__task_create_table = PostgresOperator( |
||||
sql=f"CREATE TABLE IF NOT EXISTS {REDSHIFT_TABLE}(Id int, Name varchar)", |
||||
postgres_conn_id="redshift_default", |
||||
task_id="setup__create_table", |
||||
) |
||||
# [START howto_operator_s3_to_redshift_task_1] |
||||
task_transfer_s3_to_redshift = S3ToRedshiftOperator( |
||||
s3_bucket=S3_BUCKET, |
||||
s3_key=S3_KEY, |
||||
schema="PUBLIC", |
||||
table=REDSHIFT_TABLE, |
||||
copy_options=["csv"], |
||||
task_id="transfer_s3_to_redshift", |
||||
) |
||||
# [END howto_operator_s3_to_redshift_task_1] |
||||
teardown__task_drop_table = PostgresOperator( |
||||
sql=f"DROP TABLE IF EXISTS {REDSHIFT_TABLE}", |
||||
postgres_conn_id="redshift_default", |
||||
task_id="teardown__drop_table", |
||||
) |
||||
teardown__task_remove_sample_data_from_s3 = PythonOperator( |
||||
python_callable=_remove_sample_data_from_s3, |
||||
task_id="teardown__remove_sample_data_from_s3", |
||||
) |
||||
[ |
||||
setup__task_add_sample_data_to_s3, |
||||
setup__task_create_table, |
||||
] >> task_transfer_s3_to_redshift >> [ |
||||
teardown__task_drop_table, |
||||
teardown__task_remove_sample_data_from_s3, |
||||
] |
||||
@ -0,0 +1,16 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
@ -0,0 +1,263 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""This module contains AWS Athena hook""" |
||||
from time import sleep |
||||
from typing import Any, Dict, Optional |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
from botocore.paginate import PageIterator |
||||
|
||||
|
||||
class AWSAthenaHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS Athena to run queries, poll their status and return query results.
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
|
||||
:param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena |
||||
:type sleep_time: int |
||||
""" |
||||
|
||||
INTERMEDIATE_STATES = ( |
||||
"QUEUED", |
||||
"RUNNING", |
||||
) |
||||
FAILURE_STATES = ( |
||||
"FAILED", |
||||
"CANCELLED", |
||||
) |
||||
SUCCESS_STATES = ("SUCCEEDED",) |
||||
|
||||
def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None: |
||||
super().__init__(client_type="athena", *args, **kwargs) # type: ignore |
||||
self.sleep_time = sleep_time |
||||
|
||||
def run_query( |
||||
self, |
||||
query: str, |
||||
query_context: Dict[str, str], |
||||
result_configuration: Dict[str, Any], |
||||
client_request_token: Optional[str] = None, |
||||
workgroup: str = "primary", |
||||
) -> str: |
||||
""" |
||||
Run a Presto query on Athena with the provided config and return the submitted query_execution_id
||||
|
||||
:param query: Presto query to run |
||||
:type query: str |
||||
:param query_context: Context in which query need to be run |
||||
:type query_context: dict |
||||
:param result_configuration: Dict with path to store results in and config related to encryption |
||||
:type result_configuration: dict |
||||
:param client_request_token: Unique token created by user to avoid multiple executions of same query |
||||
:type client_request_token: str |
||||
:param workgroup: Athena workgroup name; defaults to 'primary' when not specified
||||
:type workgroup: str |
||||
:return: str |
||||
""" |
||||
params = { |
||||
"QueryString": query, |
||||
"QueryExecutionContext": query_context, |
||||
"ResultConfiguration": result_configuration, |
||||
"WorkGroup": workgroup, |
||||
} |
||||
if client_request_token: |
||||
params["ClientRequestToken"] = client_request_token |
||||
response = self.get_conn().start_query_execution(**params) |
||||
query_execution_id = response["QueryExecutionId"] |
||||
return query_execution_id |
||||
|
||||
def check_query_status(self, query_execution_id: str) -> Optional[str]: |
||||
""" |
||||
Fetch the status of a submitted Athena query. Returns None or one of the valid query states.
||||
|
||||
:param query_execution_id: Id of submitted athena query |
||||
:type query_execution_id: str |
||||
:return: str |
||||
""" |
||||
response = self.get_conn().get_query_execution( |
||||
QueryExecutionId=query_execution_id |
||||
) |
||||
state = None |
||||
try: |
||||
state = response["QueryExecution"]["Status"]["State"] |
||||
except Exception as ex: # pylint: disable=broad-except |
||||
self.log.error("Exception while getting query state %s", ex) |
||||
finally: |
||||
# The error is being absorbed here and is being handled by the caller. |
||||
# The error is being absorbed to implement retries. |
||||
return state # pylint: disable=lost-exception |
||||
|
||||
def get_state_change_reason(self, query_execution_id: str) -> Optional[str]: |
||||
""" |
||||
Fetch the reason for a state change (e.g. error message). Returns None or reason string. |
||||
|
||||
:param query_execution_id: Id of submitted athena query |
||||
:type query_execution_id: str |
||||
:return: str |
||||
""" |
||||
response = self.get_conn().get_query_execution( |
||||
QueryExecutionId=query_execution_id |
||||
) |
||||
reason = None |
||||
try: |
||||
reason = response["QueryExecution"]["Status"]["StateChangeReason"] |
||||
except Exception as ex: # pylint: disable=broad-except |
||||
self.log.error("Exception while getting query state change reason: %s", ex) |
||||
finally: |
||||
# The error is being absorbed here and is being handled by the caller. |
||||
# The error is being absorbed to implement retries. |
||||
return reason # pylint: disable=lost-exception |
||||
|
||||
def get_query_results( |
||||
self, |
||||
query_execution_id: str, |
||||
next_token_id: Optional[str] = None, |
||||
max_results: int = 1000, |
||||
) -> Optional[dict]: |
||||
""" |
||||
Fetch submitted Athena query results. Returns None if the query is in an intermediate or
failed/cancelled state, otherwise a dict of the query output.
||||
|
||||
:param query_execution_id: Id of submitted athena query |
||||
:type query_execution_id: str |
||||
:param next_token_id: The token that specifies where to start pagination. |
||||
:type next_token_id: str |
||||
:param max_results: The maximum number of results (rows) to return in this request. |
||||
:type max_results: int |
||||
:return: dict |
||||
""" |
||||
query_state = self.check_query_status(query_execution_id) |
||||
if query_state is None: |
||||
self.log.error("Invalid Query state") |
||||
return None |
||||
elif ( |
||||
query_state in self.INTERMEDIATE_STATES |
||||
or query_state in self.FAILURE_STATES |
||||
): |
||||
self.log.error('Query is in "%s" state. Cannot fetch results', query_state) |
||||
return None |
||||
result_params = { |
||||
"QueryExecutionId": query_execution_id, |
||||
"MaxResults": max_results, |
||||
} |
||||
if next_token_id: |
||||
result_params["NextToken"] = next_token_id |
||||
return self.get_conn().get_query_results(**result_params) |
||||
|
||||
def get_query_results_paginator( |
||||
self, |
||||
query_execution_id: str, |
||||
max_items: Optional[int] = None, |
||||
page_size: Optional[int] = None, |
||||
starting_token: Optional[str] = None, |
||||
) -> Optional[PageIterator]: |
||||
""" |
||||
Fetch submitted Athena query results. Returns None if the query is in an intermediate or
failed/cancelled state, otherwise a paginator to iterate through pages of results. If you
wish to get all results at once, call build_full_result() on the returned PageIterator.
||||
|
||||
:param query_execution_id: Id of submitted athena query |
||||
:type query_execution_id: str |
||||
:param max_items: The total number of items to return. |
||||
:type max_items: int |
||||
:param page_size: The size of each page. |
||||
:type page_size: int |
||||
:param starting_token: A token to specify where to start paginating. |
||||
:type starting_token: str |
||||
:return: PageIterator |
||||
""" |
||||
query_state = self.check_query_status(query_execution_id) |
||||
if query_state is None: |
||||
self.log.error("Invalid Query state (null)") |
||||
return None |
||||
if ( |
||||
query_state in self.INTERMEDIATE_STATES |
||||
or query_state in self.FAILURE_STATES |
||||
): |
||||
self.log.error('Query is in "%s" state. Cannot fetch results', query_state) |
||||
return None |
||||
result_params = { |
||||
"QueryExecutionId": query_execution_id, |
||||
"PaginationConfig": { |
||||
"MaxItems": max_items, |
||||
"PageSize": page_size, |
||||
"StartingToken": starting_token, |
||||
}, |
||||
} |
||||
paginator = self.get_conn().get_paginator("get_query_results") |
||||
return paginator.paginate(**result_params) |
||||
|
||||
def poll_query_status( |
||||
self, query_execution_id: str, max_tries: Optional[int] = None |
||||
) -> Optional[str]: |
||||
""" |
||||
Poll the status of a submitted Athena query until it reaches a final state.
Returns one of the final states.
||||
|
||||
:param query_execution_id: Id of submitted athena query |
||||
:type query_execution_id: str |
||||
:param max_tries: Number of times to poll for query state before function exits |
||||
:type max_tries: int |
||||
:return: str |
||||
""" |
||||
try_number = 1 |
||||
final_query_state = ( |
||||
None # Query state when query reaches final state or max_tries reached |
||||
) |
||||
while True: |
||||
query_state = self.check_query_status(query_execution_id) |
||||
if query_state is None: |
||||
self.log.info( |
||||
"Trial %s: Invalid query state. Retrying again", try_number |
||||
) |
||||
elif query_state in self.INTERMEDIATE_STATES: |
||||
self.log.info( |
||||
"Trial %s: Query is still in an intermediate state - %s", |
||||
try_number, |
||||
query_state, |
||||
) |
||||
else: |
||||
self.log.info( |
||||
"Trial %s: Query execution completed. Final state is %s}", |
||||
try_number, |
||||
query_state, |
||||
) |
||||
final_query_state = query_state |
||||
break |
||||
if max_tries and try_number >= max_tries: # Break loop if max_tries reached |
||||
final_query_state = query_state |
||||
break |
||||
try_number += 1 |
||||
sleep(self.sleep_time) |
||||
return final_query_state |
||||
|
||||
def stop_query(self, query_execution_id: str) -> Dict: |
||||
""" |
||||
Cancel the submitted athena query |
||||
|
||||
:param query_execution_id: Id of submitted athena query |
||||
:type query_execution_id: str |
||||
:return: dict |
||||
""" |
||||
return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id) |
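# A hedged usage sketch of this hook; the database name, output location and query are
# illustrative placeholders, not values the hook defines:
#
#   hook = AWSAthenaHook(aws_conn_id="aws_default", sleep_time=10)
#   query_execution_id = hook.run_query(
#       query="SELECT 1",
#       query_context={"Database": "my_database"},
#       result_configuration={"OutputLocation": "s3://my-bucket/athena-results/"},
#   )
#   final_state = hook.poll_query_status(query_execution_id, max_tries=3)
#   if final_state in AWSAthenaHook.SUCCESS_STATES:
#       results = hook.get_query_results(query_execution_id)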
||||
@ -0,0 +1,29 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
"""This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.""" |
||||
|
||||
import warnings |
||||
|
||||
# pylint: disable=unused-import |
||||
from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook # noqa |
||||
|
||||
warnings.warn( |
||||
"This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.", |
||||
DeprecationWarning, |
||||
stacklevel=2, |
||||
) |
||||
@ -0,0 +1,596 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
""" |
||||
This module contains Base AWS Hook. |
||||
|
||||
.. seealso:: |
||||
For more information on how to use this hook, take a look at the guide: |
||||
:ref:`howto/connection:AWSHook` |
||||
""" |
||||
|
||||
import configparser |
||||
import datetime |
||||
import logging |
||||
from typing import Any, Dict, Optional, Tuple, Union |
||||
|
||||
import boto3 |
||||
import botocore |
||||
import botocore.session |
||||
from botocore.config import Config |
||||
from botocore.credentials import ReadOnlyCredentials |
||||
|
||||
try: |
||||
from functools import cached_property |
||||
except ImportError: |
||||
from cached_property import cached_property |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.hooks.base import BaseHook |
||||
from airflow.models.connection import Connection |
||||
from airflow.utils.log.logging_mixin import LoggingMixin |
||||
from dateutil.tz import tzlocal |
||||
|
||||
|
||||
class _SessionFactory(LoggingMixin): |
||||
def __init__( |
||||
self, conn: Connection, region_name: Optional[str], config: Config |
||||
) -> None: |
||||
super().__init__() |
||||
self.conn = conn |
||||
self.region_name = region_name |
||||
self.config = config |
||||
self.extra_config = self.conn.extra_dejson |
||||
|
||||
def create_session(self) -> boto3.session.Session: |
||||
"""Create AWS session.""" |
||||
session_kwargs = {} |
||||
if "session_kwargs" in self.extra_config: |
||||
self.log.info( |
||||
"Retrieving session_kwargs from Connection.extra_config['session_kwargs']: %s", |
||||
self.extra_config["session_kwargs"], |
||||
) |
||||
session_kwargs = self.extra_config["session_kwargs"] |
||||
session = self._create_basic_session(session_kwargs=session_kwargs) |
||||
role_arn = self._read_role_arn_from_extra_config() |
||||
# If role_arn was specified then STS + assume_role |
||||
if role_arn is None: |
||||
return session |
||||
|
||||
return self._impersonate_to_role( |
||||
role_arn=role_arn, session=session, session_kwargs=session_kwargs |
||||
) |
||||
|
||||
def _create_basic_session( |
||||
self, session_kwargs: Dict[str, Any] |
||||
) -> boto3.session.Session: |
||||
( |
||||
aws_access_key_id, |
||||
aws_secret_access_key, |
||||
) = self._read_credentials_from_connection() |
||||
aws_session_token = self.extra_config.get("aws_session_token") |
||||
region_name = self.region_name |
||||
if self.region_name is None and "region_name" in self.extra_config: |
||||
self.log.info( |
||||
"Retrieving region_name from Connection.extra_config['region_name']" |
||||
) |
||||
region_name = self.extra_config["region_name"] |
||||
self.log.info( |
||||
"Creating session with aws_access_key_id=%s region_name=%s", |
||||
aws_access_key_id, |
||||
region_name, |
||||
) |
||||
|
||||
return boto3.session.Session( |
||||
aws_access_key_id=aws_access_key_id, |
||||
aws_secret_access_key=aws_secret_access_key, |
||||
region_name=region_name, |
||||
aws_session_token=aws_session_token, |
||||
**session_kwargs, |
||||
) |
||||
|
||||
def _impersonate_to_role( |
||||
self, |
||||
role_arn: str, |
||||
session: boto3.session.Session, |
||||
session_kwargs: Dict[str, Any], |
||||
) -> boto3.session.Session: |
||||
assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {}) |
||||
assume_role_method = self.extra_config.get("assume_role_method") |
||||
self.log.info("assume_role_method=%s", assume_role_method) |
||||
if not assume_role_method or assume_role_method == "assume_role": |
||||
sts_client = session.client("sts", config=self.config) |
||||
sts_response = self._assume_role( |
||||
sts_client=sts_client, |
||||
role_arn=role_arn, |
||||
assume_role_kwargs=assume_role_kwargs, |
||||
) |
||||
elif assume_role_method == "assume_role_with_saml": |
||||
sts_client = session.client("sts", config=self.config) |
||||
sts_response = self._assume_role_with_saml( |
||||
sts_client=sts_client, |
||||
role_arn=role_arn, |
||||
assume_role_kwargs=assume_role_kwargs, |
||||
) |
||||
elif assume_role_method == "assume_role_with_web_identity": |
||||
botocore_session = self._assume_role_with_web_identity( |
||||
role_arn=role_arn, |
||||
assume_role_kwargs=assume_role_kwargs, |
||||
base_session=session._session, # pylint: disable=protected-access |
||||
) |
||||
return boto3.session.Session( |
||||
region_name=session.region_name, |
||||
botocore_session=botocore_session, |
||||
**session_kwargs, |
||||
) |
||||
else: |
||||
raise NotImplementedError( |
||||
f"assume_role_method={assume_role_method} in Connection {self.conn.conn_id} Extra." |
||||
'Currently "assume_role" or "assume_role_with_saml" are supported.' |
||||
'(Exclude this setting will default to "assume_role").' |
||||
) |
||||
# Use credentials retrieved from STS |
||||
credentials = sts_response["Credentials"] |
||||
aws_access_key_id = credentials["AccessKeyId"] |
||||
aws_secret_access_key = credentials["SecretAccessKey"] |
||||
aws_session_token = credentials["SessionToken"] |
||||
self.log.info( |
||||
"Creating session with aws_access_key_id=%s region_name=%s", |
||||
aws_access_key_id, |
||||
session.region_name, |
||||
) |
||||
|
||||
return boto3.session.Session( |
||||
aws_access_key_id=aws_access_key_id, |
||||
aws_secret_access_key=aws_secret_access_key, |
||||
region_name=session.region_name, |
||||
aws_session_token=aws_session_token, |
||||
**session_kwargs, |
||||
) |
||||
|
||||
def _read_role_arn_from_extra_config(self) -> Optional[str]: |
||||
aws_account_id = self.extra_config.get("aws_account_id") |
||||
aws_iam_role = self.extra_config.get("aws_iam_role") |
||||
role_arn = self.extra_config.get("role_arn") |
||||
if role_arn is None and aws_account_id is not None and aws_iam_role is not None: |
||||
self.log.info("Constructing role_arn from aws_account_id and aws_iam_role") |
||||
role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}" |
||||
self.log.info("role_arn is %s", role_arn) |
||||
return role_arn |
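# A hedged sketch of Connection "Extra" JSON that this factory understands; the account ID,
# role name and region are placeholders:
#
#   {"role_arn": "arn:aws:iam::123456789012:role/my-role", "region_name": "eu-west-1"}
#
# or, equivalently for the role ARN:
#
#   {"aws_account_id": "123456789012", "aws_iam_role": "my-role"}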
||||
|
||||
def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]: |
||||
aws_access_key_id = None |
||||
aws_secret_access_key = None |
||||
if self.conn.login: |
||||
aws_access_key_id = self.conn.login |
||||
aws_secret_access_key = self.conn.password |
||||
self.log.info("Credentials retrieved from login") |
||||
elif ( |
||||
"aws_access_key_id" in self.extra_config |
||||
and "aws_secret_access_key" in self.extra_config |
||||
): |
||||
aws_access_key_id = self.extra_config["aws_access_key_id"] |
||||
aws_secret_access_key = self.extra_config["aws_secret_access_key"] |
||||
self.log.info("Credentials retrieved from extra_config") |
||||
elif "s3_config_file" in self.extra_config: |
||||
aws_access_key_id, aws_secret_access_key = _parse_s3_config( |
||||
self.extra_config["s3_config_file"], |
||||
self.extra_config.get("s3_config_format"), |
||||
self.extra_config.get("profile"), |
||||
) |
||||
self.log.info("Credentials retrieved from extra_config['s3_config_file']") |
||||
else: |
||||
self.log.info("No credentials retrieved from Connection") |
||||
return aws_access_key_id, aws_secret_access_key |
||||
|
||||
def _assume_role( |
||||
self, |
||||
sts_client: boto3.client, |
||||
role_arn: str, |
||||
assume_role_kwargs: Dict[str, Any], |
||||
) -> Dict: |
||||
if "external_id" in self.extra_config: # Backwards compatibility |
||||
assume_role_kwargs["ExternalId"] = self.extra_config.get("external_id") |
||||
role_session_name = f"Airflow_{self.conn.conn_id}" |
||||
self.log.info( |
||||
"Doing sts_client.assume_role to role_arn=%s (role_session_name=%s)", |
||||
role_arn, |
||||
role_session_name, |
||||
) |
||||
return sts_client.assume_role( |
||||
RoleArn=role_arn, RoleSessionName=role_session_name, **assume_role_kwargs |
||||
) |
||||
|
||||
def _assume_role_with_saml( |
||||
self, |
||||
sts_client: boto3.client, |
||||
role_arn: str, |
||||
assume_role_kwargs: Dict[str, Any], |
||||
) -> Dict[str, Any]: |
||||
saml_config = self.extra_config["assume_role_with_saml"] |
||||
principal_arn = saml_config["principal_arn"] |
||||
|
||||
idp_auth_method = saml_config["idp_auth_method"] |
||||
if idp_auth_method == "http_spegno_auth": |
||||
saml_assertion = self._fetch_saml_assertion_using_http_spegno_auth( |
||||
saml_config |
||||
) |
||||
else: |
||||
raise NotImplementedError( |
||||
f"idp_auth_method={idp_auth_method} in Connection {self.conn.conn_id} Extra." |
||||
'Currently only "http_spegno_auth" is supported, and must be specified.' |
||||
) |
||||
|
||||
self.log.info("Doing sts_client.assume_role_with_saml to role_arn=%s", role_arn) |
||||
return sts_client.assume_role_with_saml( |
||||
RoleArn=role_arn, |
||||
PrincipalArn=principal_arn, |
||||
SAMLAssertion=saml_assertion, |
||||
**assume_role_kwargs, |
||||
) |
||||
|
||||
def _fetch_saml_assertion_using_http_spegno_auth( |
||||
self, saml_config: Dict[str, Any] |
||||
) -> str: |
||||
import requests |
||||
|
||||
# requests_gssapi will need paramiko > 2.6 since you'll need |
||||
# 'gssapi' not 'python-gssapi' from PyPi. |
||||
# https://github.com/paramiko/paramiko/pull/1311 |
||||
import requests_gssapi |
||||
from lxml import etree |
||||
|
||||
idp_url = saml_config["idp_url"] |
||||
self.log.info("idp_url= %s", idp_url) |
||||
idp_request_kwargs = saml_config["idp_request_kwargs"] |
||||
auth = requests_gssapi.HTTPSPNEGOAuth() |
||||
if "mutual_authentication" in saml_config: |
||||
mutual_auth = saml_config["mutual_authentication"] |
||||
if mutual_auth == "REQUIRED": |
||||
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.REQUIRED) |
||||
elif mutual_auth == "OPTIONAL": |
||||
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.OPTIONAL) |
||||
elif mutual_auth == "DISABLED": |
||||
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.DISABLED) |
||||
else: |
||||
raise NotImplementedError( |
||||
f"mutual_authentication={mutual_auth} in Connection {self.conn.conn_id} Extra." |
||||
'Currently "REQUIRED", "OPTIONAL" and "DISABLED" are supported.' |
||||
"(Exclude this setting will default to HTTPSPNEGOAuth() )." |
||||
) |
||||
# Query the IDP |
||||
idp_response = requests.get(idp_url, auth=auth, **idp_request_kwargs) |
||||
idp_response.raise_for_status() |
||||
# Assist with debugging. Note: contains sensitive info! |
||||
xpath = saml_config["saml_response_xpath"] |
||||
log_idp_response = ( |
||||
"log_idp_response" in saml_config and saml_config["log_idp_response"] |
||||
) |
||||
if log_idp_response: |
||||
self.log.warning( |
||||
"The IDP response contains sensitive information, but log_idp_response is ON (%s).", |
||||
log_idp_response, |
||||
) |
||||
self.log.info("idp_response.content= %s", idp_response.content) |
||||
self.log.info("xpath= %s", xpath) |
||||
# Extract SAML Assertion from the returned HTML / XML |
||||
xml = etree.fromstring(idp_response.content) |
||||
saml_assertion = xml.xpath(xpath) |
||||
if isinstance(saml_assertion, list): |
||||
if len(saml_assertion) == 1: |
||||
saml_assertion = saml_assertion[0] |
||||
if not saml_assertion: |
||||
raise ValueError("Invalid SAML Assertion") |
||||
return saml_assertion |
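# A hedged sketch of the "assume_role_with_saml" block these methods read from Connection Extra.
# The ARN, URL and xpath are placeholders; only "http_spegno_auth" is handled above:
#
#   {
#       "assume_role_with_saml": {
#           "principal_arn": "arn:aws:iam::123456789012:saml-provider/my-idp",
#           "idp_auth_method": "http_spegno_auth",
#           "idp_url": "https://idp.example.com/sso/saml",
#           "idp_request_kwargs": {"headers": {"Accept": "text/html"}},
#           "mutual_authentication": "OPTIONAL",
#           "saml_response_xpath": "//INPUT[@NAME='SAMLResponse']/@VALUE",
#           "log_idp_response": false
#       }
#   }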
||||
|
||||
def _assume_role_with_web_identity( |
||||
self, role_arn, assume_role_kwargs, base_session |
||||
): |
||||
base_session = base_session or botocore.session.get_session() |
||||
client_creator = base_session.create_client |
||||
federation = self.extra_config.get("assume_role_with_web_identity_federation") |
||||
if federation == "google": |
||||
web_identity_token_loader = self._get_google_identity_token_loader() |
||||
else: |
||||
raise AirflowException( |
||||
f'Unsupported federation: {federation}. Currently only "google" is supported.'
||||
) |
||||
fetcher = botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher( |
||||
client_creator=client_creator, |
||||
web_identity_token_loader=web_identity_token_loader, |
||||
role_arn=role_arn, |
||||
extra_args=assume_role_kwargs or {}, |
||||
) |
||||
aws_creds = botocore.credentials.DeferredRefreshableCredentials( |
||||
method="assume-role-with-web-identity", |
||||
refresh_using=fetcher.fetch_credentials, |
||||
time_fetcher=lambda: datetime.datetime.now(tz=tzlocal()), |
||||
) |
||||
botocore_session = botocore.session.Session() |
||||
botocore_session._credentials = aws_creds # pylint: disable=protected-access |
||||
return botocore_session |
||||
|
||||
def _get_google_identity_token_loader(self): |
||||
from airflow.providers.google.common.utils.id_token_credentials import ( |
||||
get_default_id_token_credentials, |
||||
) |
||||
from google.auth.transport import requests as requests_transport |
||||
|
||||
audience = self.extra_config.get( |
||||
"assume_role_with_web_identity_federation_audience" |
||||
) |
||||
|
||||
google_id_token_credentials = get_default_id_token_credentials( |
||||
target_audience=audience |
||||
) |
||||
|
||||
def web_identity_token_loader(): |
||||
if not google_id_token_credentials.valid: |
||||
request_adapter = requests_transport.Request() |
||||
google_id_token_credentials.refresh(request=request_adapter) |
||||
return google_id_token_credentials.token |
||||
|
||||
return web_identity_token_loader |
||||
|
||||
|
||||
class AwsBaseHook(BaseHook): |
||||
""" |
||||
Interact with AWS. |
||||
This class is a thin wrapper around the boto3 python library. |
||||
|
||||
:param aws_conn_id: The Airflow connection used for AWS credentials. |
||||
If this is None or empty then the default boto3 behaviour is used. If |
||||
running Airflow in a distributed manner and aws_conn_id is None or |
||||
empty, then default boto3 configuration would be used (and must be |
||||
maintained on each worker node). |
||||
:type aws_conn_id: str |
||||
:param verify: Whether or not to verify SSL certificates. |
||||
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html |
||||
:type verify: Union[bool, str, None] |
||||
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. |
||||
:type region_name: Optional[str] |
||||
:param client_type: boto3.client client_type. Eg 's3', 'emr' etc |
||||
:type client_type: Optional[str] |
||||
:param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc |
||||
:type resource_type: Optional[str] |
||||
:param config: Configuration for botocore client. |
||||
(https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html) |
||||
:type config: Optional[botocore.client.Config] |
||||
""" |
||||
|
||||
conn_name_attr = "aws_conn_id" |
||||
default_conn_name = "aws_default" |
||||
conn_type = "aws" |
||||
hook_name = "Amazon Web Services" |
||||
|
||||
def __init__( |
||||
self, |
||||
aws_conn_id: Optional[str] = default_conn_name, |
||||
verify: Union[bool, str, None] = None, |
||||
region_name: Optional[str] = None, |
||||
client_type: Optional[str] = None, |
||||
resource_type: Optional[str] = None, |
||||
config: Optional[Config] = None, |
||||
) -> None: |
||||
super().__init__() |
||||
self.aws_conn_id = aws_conn_id |
||||
self.verify = verify |
||||
self.client_type = client_type |
||||
self.resource_type = resource_type |
||||
self.region_name = region_name |
||||
self.config = config |
||||
|
||||
if not (self.client_type or self.resource_type): |
||||
raise AirflowException( |
||||
"Either client_type or resource_type must be provided." |
||||
) |
||||
|
||||
def _get_credentials( |
||||
self, region_name: Optional[str] |
||||
) -> Tuple[boto3.session.Session, Optional[str]]: |
||||
|
||||
if not self.aws_conn_id: |
||||
session = boto3.session.Session(region_name=region_name) |
||||
return session, None |
||||
|
||||
self.log.info("Airflow Connection: aws_conn_id=%s", self.aws_conn_id) |
||||
|
||||
try: |
||||
# Fetch the Airflow connection object |
||||
connection_object = self.get_connection(self.aws_conn_id) |
||||
extra_config = connection_object.extra_dejson |
||||
endpoint_url = extra_config.get("host") |
||||
|
||||
# https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html#botocore.config.Config |
||||
if "config_kwargs" in extra_config: |
||||
self.log.info( |
||||
"Retrieving config_kwargs from Connection.extra_config['config_kwargs']: %s", |
||||
extra_config["config_kwargs"], |
||||
) |
||||
self.config = Config(**extra_config["config_kwargs"]) |
||||
|
||||
session = _SessionFactory( |
||||
conn=connection_object, region_name=region_name, config=self.config |
||||
).create_session() |
||||
|
||||
return session, endpoint_url |
||||
|
||||
except AirflowException: |
||||
self.log.warning("Unable to use Airflow Connection for credentials.") |
||||
self.log.info("Fallback on boto3 credential strategy") |
||||
# http://boto3.readthedocs.io/en/latest/guide/configuration.html |
||||
|
||||
self.log.info( |
||||
"Creating session using boto3 credential strategy region_name=%s", |
||||
region_name, |
||||
) |
||||
session = boto3.session.Session(region_name=region_name) |
||||
return session, None |
||||
|
||||
def get_client_type( |
||||
self, |
||||
client_type: str, |
||||
region_name: Optional[str] = None, |
||||
config: Optional[Config] = None, |
||||
) -> boto3.client: |
||||
"""Get the underlying boto3 client using boto3 session""" |
||||
session, endpoint_url = self._get_credentials(region_name) |
||||
|
||||
# No AWS Operators use the config argument to this method. |
||||
# Keep backward compatibility with other users who might use it |
||||
if config is None: |
||||
config = self.config |
||||
|
||||
return session.client( |
||||
client_type, endpoint_url=endpoint_url, config=config, verify=self.verify |
||||
) |
||||
|
||||
def get_resource_type( |
||||
self, |
||||
resource_type: str, |
||||
region_name: Optional[str] = None, |
||||
config: Optional[Config] = None, |
||||
) -> boto3.resource:
||||
"""Get the underlying boto3 resource using boto3 session""" |
||||
session, endpoint_url = self._get_credentials(region_name) |
||||
|
||||
# No AWS Operators use the config argument to this method. |
||||
# Keep backward compatibility with other users who might use it |
||||
if config is None: |
||||
config = self.config |
||||
|
||||
return session.resource( |
||||
resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify |
||||
) |
||||
|
||||
@cached_property |
||||
def conn(self) -> Union[boto3.client, boto3.resource]: |
||||
""" |
||||
Get the underlying boto3 client/resource (cached) |
||||
|
||||
:return: boto3.client or boto3.resource |
||||
:rtype: Union[boto3.client, boto3.resource] |
||||
""" |
||||
if self.client_type: |
||||
return self.get_client_type(self.client_type, region_name=self.region_name) |
||||
elif self.resource_type: |
||||
return self.get_resource_type( |
||||
self.resource_type, region_name=self.region_name |
||||
) |
||||
else: |
||||
# Rare possibility - subclasses have not specified a client_type or resource_type |
||||
raise NotImplementedError("Could not get boto3 connection!") |
||||
|
||||
def get_conn(self) -> Union[boto3.client, boto3.resource]: |
||||
""" |
||||
Get the underlying boto3 client/resource (cached) |
||||
|
||||
Implemented so that caching works as intended. It exists for compatibility |
||||
with subclasses that rely on a super().get_conn() method. |
||||
|
||||
:return: boto3.client or boto3.resource |
||||
:rtype: Union[boto3.client, boto3.resource] |
||||
""" |
||||
# Compat shim |
||||
return self.conn |
||||
|
||||
def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session: |
||||
"""Get the underlying boto3.session.""" |
||||
session, _ = self._get_credentials(region_name) |
||||
return session |
||||
|
||||
def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials: |
||||
""" |
||||
Get the underlying `botocore.Credentials` object. |
||||
|
||||
This contains the following authentication attributes: access_key, secret_key and token. |
||||
""" |
||||
session, _ = self._get_credentials(region_name) |
||||
# Credentials are refreshable, so accessing your access key and |
||||
# secret key separately can lead to a race condition. |
||||
# See https://stackoverflow.com/a/36291428/8283373 |
||||
return session.get_credentials().get_frozen_credentials() |
||||
|
||||
def expand_role(self, role: str) -> str: |
||||
""" |
||||
If the IAM role is a role name, get the Amazon Resource Name (ARN) for the role. |
||||
If IAM role is already an IAM role ARN, no change is made. |
||||
|
||||
:param role: IAM role name or ARN |
||||
:return: IAM role ARN |
||||
""" |
||||
if "/" in role: |
||||
return role |
||||
else: |
||||
return self.get_client_type("iam").get_role(RoleName=role)["Role"]["Arn"] |
||||
|
||||
|
||||
def _parse_s3_config( |
||||
config_file_name: str, |
||||
config_format: Optional[str] = "boto", |
||||
profile: Optional[str] = None, |
||||
) -> Tuple[Optional[str], Optional[str]]: |
||||
""" |
||||
Parses a config file for s3 credentials. Can currently |
||||
parse boto, s3cmd.conf and AWS SDK config formats |
||||
|
||||
:param config_file_name: path to the config file |
||||
:type config_file_name: str |
||||
:param config_format: config type. One of "boto", "s3cmd" or "aws". |
||||
Defaults to "boto" |
||||
:type config_format: str |
||||
:param profile: profile name in AWS type config file |
||||
:type profile: str |
||||
""" |
||||
config = configparser.ConfigParser() |
||||
if config.read(config_file_name): # pragma: no cover |
||||
sections = config.sections() |
||||
else: |
||||
raise AirflowException(f"Couldn't read {config_file_name}") |
||||
# Setting option names depending on file format |
||||
if config_format is None: |
||||
config_format = "boto" |
||||
conf_format = config_format.lower() |
||||
if conf_format == "boto": # pragma: no cover |
||||
if profile is not None and "profile " + profile in sections: |
||||
cred_section = "profile " + profile |
||||
else: |
||||
cred_section = "Credentials" |
||||
elif conf_format == "aws" and profile is not None: |
||||
cred_section = profile |
||||
else: |
||||
cred_section = "default" |
||||
# Option names |
||||
if conf_format in ("boto", "aws"): # pragma: no cover |
||||
key_id_option = "aws_access_key_id" |
||||
secret_key_option = "aws_secret_access_key" |
||||
# security_token_option = 'aws_security_token' |
||||
else: |
||||
key_id_option = "access_key" |
||||
secret_key_option = "secret_key" |
||||
# Actual Parsing |
||||
if cred_section not in sections: |
||||
raise AirflowException("This config file format is not recognized") |
||||
else: |
||||
try: |
||||
access_key = config.get(cred_section, key_id_option) |
||||
secret_key = config.get(cred_section, secret_key_option) |
||||
except Exception: |
||||
logging.warning("Option Error in parsing s3 config file") |
||||
raise |
||||
return access_key, secret_key |
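# A hedged sketch of a "boto"-format credentials file this parser accepts (the key values are
# placeholders); the "s3cmd" format uses the access_key / secret_key option names instead:
#
#   [Credentials]
#   aws_access_key_id = AKIAEXAMPLE
#   aws_secret_access_key = wJalrEXAMPLEKEY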
||||
@ -0,0 +1,556 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
""" |
||||
A client for AWS batch services |
||||
|
||||
.. seealso:: |
||||
|
||||
- http://boto3.readthedocs.io/en/latest/guide/configuration.html |
||||
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html |
||||
- https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html |
||||
""" |
||||
|
||||
from random import uniform |
||||
from time import sleep |
||||
from typing import Dict, List, Optional, Union |
||||
|
||||
import botocore.client |
||||
import botocore.exceptions |
||||
import botocore.waiter |
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
from airflow.typing_compat import Protocol, runtime_checkable |
||||
|
||||
# Add exceptions to pylint for the boto3 protocol only; ideally the boto3 library |
||||
# could provide |
||||
# protocols for all their dynamically generated classes (try to migrate this to a PR on botocore). |
||||
# Note that the use of invalid-name parameters should be restricted to the boto3 mappings only; |
||||
# all the Airflow wrappers of boto3 clients should not adopt invalid-names to match boto3. |
||||
# pylint: disable=invalid-name, unused-argument |
||||
|
||||
|
||||
@runtime_checkable |
||||
class AwsBatchProtocol(Protocol): |
||||
""" |
||||
A structured Protocol for ``boto3.client('batch') -> botocore.client.Batch``. |
||||
This is used for type hints on :py:meth:`.AwsBatchClient.client`; it covers |
||||
only the subset of client methods required. |
||||
|
||||
.. seealso:: |
||||
|
||||
- https://mypy.readthedocs.io/en/latest/protocols.html |
||||
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html |
||||
""" |
||||
|
||||
def describe_jobs(self, jobs: List[str]) -> Dict: |
||||
""" |
||||
Get job descriptions from AWS batch |
||||
|
||||
:param jobs: a list of JobId to describe |
||||
:type jobs: List[str] |
||||
|
||||
:return: an API response to describe jobs |
||||
:rtype: Dict |
||||
""" |
||||
... |
||||
|
||||
def get_waiter(self, waiterName: str) -> botocore.waiter.Waiter: |
||||
""" |
||||
Get an AWS Batch service waiter |
||||
|
||||
:param waiterName: The name of the waiter. The name should match |
||||
the name (including the casing) of the key name in the waiter |
||||
model file (typically this is CamelCasing). |
||||
:type waiterName: str |
||||
|
||||
:return: a waiter object for the named AWS batch service |
||||
:rtype: botocore.waiter.Waiter |
||||
|
||||
.. note:: |
||||
AWS batch might not have any waiters (until botocore PR-1307 is released). |
||||
|
||||
.. code-block:: python |
||||
|
||||
import boto3 |
||||
boto3.client('batch').waiter_names == [] |
||||
|
||||
.. seealso:: |
||||
|
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/guide/clients.html#waiters |
||||
- https://github.com/boto/botocore/pull/1307 |
||||
""" |
||||
... |
||||
|
||||
def submit_job( |
||||
self, |
||||
jobName: str, |
||||
jobQueue: str, |
||||
jobDefinition: str, |
||||
arrayProperties: Dict, |
||||
parameters: Dict, |
||||
containerOverrides: Dict, |
||||
tags: Dict, |
||||
) -> Dict: |
||||
""" |
||||
Submit a batch job |
||||
|
||||
:param jobName: the name for the AWS batch job |
||||
:type jobName: str |
||||
|
||||
:param jobQueue: the queue name on AWS Batch |
||||
:type jobQueue: str |
||||
|
||||
:param jobDefinition: the job definition name on AWS Batch |
||||
:type jobDefinition: str |
||||
|
||||
:param arrayProperties: the same parameter that boto3 will receive |
||||
:type arrayProperties: Dict |
||||
|
||||
:param parameters: the same parameter that boto3 will receive |
||||
:type parameters: Dict |
||||
|
||||
:param containerOverrides: the same parameter that boto3 will receive |
||||
:type containerOverrides: Dict |
||||
|
||||
:param tags: the same parameter that boto3 will receive |
||||
:type tags: Dict |
||||
|
||||
:return: an API response |
||||
:rtype: Dict |
||||
""" |
||||
... |
||||
|
||||
def terminate_job(self, jobId: str, reason: str) -> Dict: |
||||
""" |
||||
Terminate a batch job |
||||
|
||||
:param jobId: a job ID to terminate |
||||
:type jobId: str |
||||
|
||||
:param reason: a reason to terminate job ID |
||||
:type reason: str |
||||
|
||||
:return: an API response |
||||
:rtype: Dict |
||||
""" |
||||
... |
||||
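# Illustrative sketch, not part of the original module: because ``AwsBatchProtocol``
# is decorated with ``@runtime_checkable``, ``isinstance()`` only checks that the
# required method names exist on an object (signatures are not verified). The stub
# class and all names below are hypothetical and exist only to show that check.
class _StubBatchClient:
    """A hypothetical duck-typed stand-in for ``boto3.client('batch')``."""

    def describe_jobs(self, jobs):
        return {"jobs": []}

    def get_waiter(self, waiterName):
        raise NotImplementedError("AWS Batch exposes no service waiters")

    def submit_job(
        self, jobName, jobQueue, jobDefinition, arrayProperties, parameters, containerOverrides, tags
    ):
        return {"jobName": jobName, "jobId": "hypothetical-job-id"}

    def terminate_job(self, jobId, reason):
        return {}


if __name__ == "__main__":
    # Structural (duck-typing) check provided by typing.runtime_checkable
    assert isinstance(_StubBatchClient(), AwsBatchProtocol)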
|
||||
|
||||
# Note that the use of invalid-name parameters should be restricted to the boto3 mappings only; |
||||
# all the Airflow wrappers of boto3 clients should not adopt invalid-names to match boto3. |
||||
# pylint: enable=invalid-name, unused-argument |
||||
|
||||
|
||||
class AwsBatchClientHook(AwsBaseHook): |
||||
""" |
||||
A client for AWS batch services. |
||||
|
||||
:param max_retries: exponential back-off retries, 4200 = 48 hours; |
||||
polling is only used when waiters is None |
||||
:type max_retries: Optional[int] |
||||
|
||||
:param status_retries: number of HTTP retries to get job status, 10; |
||||
polling is only used when waiters is None |
||||
:type status_retries: Optional[int] |
||||
|
||||
.. note:: |
||||
Several methods use a default random delay to check or poll for job status, i.e. |
||||
``random.uniform(DEFAULT_DELAY_MIN, DEFAULT_DELAY_MAX)`` |
||||
Using a random interval helps to avoid AWS API throttle limits |
||||
when many concurrent tasks request job-descriptions. |
||||
|
||||
To modify the global defaults for the range of jitter allowed when a |
||||
random delay is used to check batch job status, modify these defaults, e.g.: |
||||
.. code-block:: |
||||
|
||||
AwsBatchClient.DEFAULT_DELAY_MIN = 0 |
||||
AwsBatchClient.DEFAULT_DELAY_MAX = 5 |
||||
|
||||
When explicit delay values are used, a 1 second random jitter is applied to the |
||||
delay (e.g. a delay of 0 sec will be a ``random.uniform(0, 1)`` delay). It is |
||||
generally recommended that random jitter is added to API requests. A |
||||
convenience method is provided for this, e.g. to get a random delay of |
||||
10 sec +/- 5 sec: ``delay = AwsBatchClient.add_jitter(10, width=5, minima=0)`` |
||||
|
||||
.. seealso:: |
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html |
||||
- https://docs.aws.amazon.com/general/latest/gr/api-retries.html |
||||
- https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ |
||||
""" |
||||
|
||||
MAX_RETRIES = 4200 |
||||
STATUS_RETRIES = 10 |
||||
|
||||
# delays are in seconds |
||||
DEFAULT_DELAY_MIN = 1 |
||||
DEFAULT_DELAY_MAX = 10 |
||||
|
||||
def __init__( |
||||
self, |
||||
*args, |
||||
max_retries: Optional[int] = None, |
||||
status_retries: Optional[int] = None, |
||||
**kwargs, |
||||
) -> None: |
||||
# https://github.com/python/mypy/issues/6799 hence type: ignore |
||||
super().__init__(client_type="batch", *args, **kwargs) # type: ignore |
||||
self.max_retries = max_retries or self.MAX_RETRIES |
||||
self.status_retries = status_retries or self.STATUS_RETRIES |
||||
|
||||
@property |
||||
def client( |
||||
self, |
||||
) -> Union[AwsBatchProtocol, botocore.client.BaseClient]: # noqa: D402 |
||||
""" |
||||
An AWS API client for batch services, like ``boto3.client('batch')`` |
||||
|
||||
:return: a boto3 'batch' client for the ``.region_name`` |
||||
:rtype: Union[AwsBatchProtocol, botocore.client.BaseClient] |
||||
""" |
||||
return self.conn |
||||
|
||||
def terminate_job(self, job_id: str, reason: str) -> Dict: |
||||
""" |
||||
Terminate a batch job |
||||
|
||||
:param job_id: a job ID to terminate |
||||
:type job_id: str |
||||
|
||||
:param reason: a reason to terminate job ID |
||||
:type reason: str |
||||
|
||||
:return: an API response |
||||
:rtype: Dict |
||||
""" |
||||
response = self.get_conn().terminate_job(jobId=job_id, reason=reason) |
||||
self.log.info(response) |
||||
return response |
||||
|
||||
def check_job_success(self, job_id: str) -> bool: |
||||
""" |
||||
Check the final status of the batch job; return True if the job |
||||
'SUCCEEDED', else raise an AirflowException |
||||
|
||||
:param job_id: a batch job ID |
||||
:type job_id: str |
||||
|
||||
:rtype: bool |
||||
|
||||
:raises: AirflowException |
||||
""" |
||||
job = self.get_job_description(job_id) |
||||
job_status = job.get("status") |
||||
|
||||
if job_status == "SUCCEEDED": |
||||
self.log.info("AWS batch job (%s) succeeded: %s", job_id, job) |
||||
return True |
||||
|
||||
if job_status == "FAILED": |
||||
raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}") |
||||
|
||||
if job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]: |
||||
raise AirflowException(f"AWS Batch job ({job_id}) is not complete: {job}") |
||||
|
||||
raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}") |
||||
|
||||
def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None: |
||||
""" |
||||
Wait for batch job to complete |
||||
|
||||
:param job_id: a batch job ID |
||||
:type job_id: str |
||||
|
||||
:param delay: a delay before polling for job status |
||||
:type delay: Optional[Union[int, float]] |
||||
|
||||
:raises: AirflowException |
||||
""" |
||||
self.delay(delay) |
||||
self.poll_for_job_running(job_id, delay) |
||||
self.poll_for_job_complete(job_id, delay) |
||||
self.log.info("AWS Batch job (%s) has completed", job_id) |
||||
|
||||
def poll_for_job_running( |
||||
self, job_id: str, delay: Union[int, float, None] = None |
||||
) -> None: |
||||
""" |
||||
Poll for job running. The statuses that indicate a job is running or |
||||
already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'. |
||||
|
||||
So the status options that this will wait for are the transitions from: |
||||
'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'|'SUCCEEDED'|'FAILED' |
||||
|
||||
The completed status options are included for cases where the status |
||||
changes too quickly for polling to detect a RUNNING status that moves |
||||
quickly from STARTING to RUNNING to completed (often a failure). |
||||
|
||||
:param job_id: a batch job ID |
||||
:type job_id: str |
||||
|
||||
:param delay: a delay before polling for job status |
||||
:type delay: Optional[Union[int, float]] |
||||
|
||||
:raises: AirflowException |
||||
""" |
||||
self.delay(delay) |
||||
running_status = ["RUNNING", "SUCCEEDED", "FAILED"] |
||||
self.poll_job_status(job_id, running_status) |
||||
|
||||
def poll_for_job_complete( |
||||
self, job_id: str, delay: Union[int, float, None] = None |
||||
) -> None: |
||||
""" |
||||
Poll for job completion. The statuses that indicate job completion |
||||
are: 'SUCCEEDED'|'FAILED'. |
||||
|
||||
So the status options that this will wait for are the transitions from: |
||||
'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED' |
||||
|
||||
:param job_id: a batch job ID |
||||
:type job_id: str |
||||
|
||||
:param delay: a delay before polling for job status |
||||
:type delay: Optional[Union[int, float]] |
||||
|
||||
:raises: AirflowException |
||||
""" |
||||
self.delay(delay) |
||||
complete_status = ["SUCCEEDED", "FAILED"] |
||||
self.poll_job_status(job_id, complete_status) |
||||
|
||||
def poll_job_status(self, job_id: str, match_status: List[str]) -> bool: |
||||
""" |
||||
Poll for job status using an exponential back-off strategy (with max_retries). |
||||
|
||||
:param job_id: a batch job ID |
||||
:type job_id: str |
||||
|
||||
:param match_status: a list of job statuses to match; the Batch job statuses are: |
||||
'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED' |
||||
:type match_status: List[str] |
||||
|
||||
:rtype: bool |
||||
|
||||
:raises: AirflowException |
||||
""" |
||||
retries = 0 |
||||
while True: |
||||
|
||||
job = self.get_job_description(job_id) |
||||
job_status = job.get("status") |
||||
self.log.info( |
||||
"AWS Batch job (%s) check status (%s) in %s", |
||||
job_id, |
||||
job_status, |
||||
match_status, |
||||
) |
||||
|
||||
if job_status in match_status: |
||||
return True |
||||
|
||||
if retries >= self.max_retries: |
||||
raise AirflowException( |
||||
f"AWS Batch job ({job_id}) status checks exceed max_retries" |
||||
) |
||||
|
||||
retries += 1 |
||||
pause = self.exponential_delay(retries) |
||||
self.log.info( |
||||
"AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds", |
||||
job_id, |
||||
retries, |
||||
self.max_retries, |
||||
pause, |
||||
) |
||||
self.delay(pause) |
||||
|
||||
def get_job_description(self, job_id: str) -> Dict: |
||||
""" |
||||
Get job description (using status_retries). |
||||
|
||||
:param job_id: a batch job ID |
||||
:type job_id: str |
||||
|
||||
:return: an API response for describe jobs |
||||
:rtype: Dict |
||||
|
||||
:raises: AirflowException |
||||
""" |
||||
retries = 0 |
||||
while True: |
||||
try: |
||||
response = self.get_conn().describe_jobs(jobs=[job_id]) |
||||
return self.parse_job_description(job_id, response) |
||||
|
||||
except botocore.exceptions.ClientError as err: |
||||
error = err.response.get("Error", {}) |
||||
if error.get("Code") == "TooManyRequestsException": |
||||
pass # allow it to retry, if possible |
||||
else: |
||||
raise AirflowException( |
||||
f"AWS Batch job ({job_id}) description error: {err}" |
||||
) |
||||
|
||||
retries += 1 |
||||
if retries >= self.status_retries: |
||||
raise AirflowException( |
||||
"AWS Batch job ({}) description error: exceeded " |
||||
"status_retries ({})".format(job_id, self.status_retries) |
||||
) |
||||
|
||||
pause = self.exponential_delay(retries) |
||||
self.log.info( |
||||
"AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds", |
||||
job_id, |
||||
retries, |
||||
self.status_retries, |
||||
pause, |
||||
) |
||||
self.delay(pause) |
||||
|
||||
@staticmethod |
||||
def parse_job_description(job_id: str, response: Dict) -> Dict: |
||||
""" |
||||
Parse job description to extract description for job_id |
||||
|
||||
:param job_id: a batch job ID |
||||
:type job_id: str |
||||
|
||||
:param response: an API response for describe jobs |
||||
:type response: Dict |
||||
|
||||
:return: an API response to describe job_id |
||||
:rtype: Dict |
||||
|
||||
:raises: AirflowException |
||||
""" |
||||
jobs = response.get("jobs", []) |
||||
matching_jobs = [job for job in jobs if job.get("jobId") == job_id] |
||||
if len(matching_jobs) != 1: |
||||
raise AirflowException( |
||||
f"AWS Batch job ({job_id}) description error: response: {response}" |
||||
) |
||||
|
||||
return matching_jobs[0] |
||||
|
||||
@staticmethod |
||||
def add_jitter( |
||||
delay: Union[int, float], |
||||
width: Union[int, float] = 1, |
||||
minima: Union[int, float] = 0, |
||||
) -> float: |
||||
""" |
||||
Use delay +/- width for random jitter |
||||
|
||||
Adding jitter to status polling can help to avoid |
||||
AWS batch API limits for monitoring batch jobs with |
||||
a high concurrency in Airflow tasks. |
||||
|
||||
:param delay: number of seconds to pause; |
||||
delay is assumed to be a positive number |
||||
:type delay: Union[int, float] |
||||
|
||||
:param width: delay +/- width for random jitter; |
||||
width is assumed to be a positive number |
||||
:type width: Union[int, float] |
||||
|
||||
:param minima: minimum delay allowed; |
||||
minima is assumed to be a non-negative number |
||||
:type minima: Union[int, float] |
||||
|
||||
:return: uniform(delay - width, delay + width) jitter |
||||
and it is a non-negative number |
||||
:rtype: float |
||||
""" |
||||
delay = abs(delay) |
||||
width = abs(width) |
||||
minima = abs(minima) |
||||
lower = max(minima, delay - width) |
||||
upper = delay + width |
||||
return uniform(lower, upper) |
||||
|
||||
@staticmethod |
||||
def delay(delay: Union[int, float, None] = None) -> None: |
||||
""" |
||||
Pause execution for ``delay`` seconds. |
||||
|
||||
:param delay: a delay to pause execution using ``time.sleep(delay)``; |
||||
a small 1 second jitter is applied to the delay. |
||||
:type delay: Optional[Union[int, float]] |
||||
|
||||
.. note:: |
||||
This method uses a default random delay, i.e. |
||||
``random.uniform(DEFAULT_DELAY_MIN, DEFAULT_DELAY_MAX)``; |
||||
using a random interval helps to avoid AWS API throttle limits |
||||
when many concurrent tasks request job-descriptions. |
||||
""" |
||||
if delay is None: |
||||
delay = uniform( |
||||
AwsBatchClientHook.DEFAULT_DELAY_MIN, |
||||
AwsBatchClientHook.DEFAULT_DELAY_MAX, |
||||
) |
||||
else: |
||||
delay = AwsBatchClientHook.add_jitter(delay) |
||||
sleep(delay) |
||||
|
||||
@staticmethod |
||||
def exponential_delay(tries: int) -> float: |
||||
""" |
||||
An exponential back-off delay, with random jitter. There is a maximum |
||||
interval of 10 minutes (with random jitter between 3 and 10 minutes). |
||||
This is used in the :py:meth:`.poll_job_status` method. |
||||
|
||||
:param tries: Number of tries |
||||
:type tries: int |
||||
|
||||
:rtype: float |
||||
|
||||
Examples of behavior: |
||||
|
||||
.. code-block:: python |
||||
|
||||
def exp(tries): |
||||
max_interval = 600.0 # 10 minutes in seconds |
||||
delay = 1 + pow(tries * 0.6, 2) |
||||
delay = min(max_interval, delay) |
||||
print(delay / 3, delay) |
||||
|
||||
for tries in range(10): |
||||
exp(tries) |
||||
|
||||
# 0.33 1.0 |
||||
# 0.45 1.35 |
||||
# 0.81 2.44 |
||||
# 1.41 4.23 |
||||
# 2.25 6.76 |
||||
# 3.33 10.00 |
||||
# 4.65 13.95 |
||||
# 6.21 18.64 |
||||
# 8.01 24.04 |
||||
# 10.05 30.15 |
||||
|
||||
.. seealso:: |
||||
|
||||
- https://docs.aws.amazon.com/general/latest/gr/api-retries.html |
||||
- https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ |
||||
""" |
||||
max_interval = 600.0 # results in 3 to 10 minute delay |
||||
delay = 1 + pow(tries * 0.6, 2) |
||||
delay = min(max_interval, delay) |
||||
return uniform(delay / 3, delay) |
||||
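# Illustrative usage sketch, not part of the original module. The connection id,
# region and job id below are hypothetical; the job is assumed to have been
# submitted elsewhere (e.g. by an operator), so this hook only polls its status.
if __name__ == "__main__":
    hook = AwsBatchClientHook(
        aws_conn_id="aws_default",  # assumed Airflow connection
        region_name="us-east-1",    # assumed region
        max_retries=10,             # cap the exponential back-off polling loop
        status_retries=5,
    )

    # Optionally tighten the global jitter window used by hook.delay()
    AwsBatchClientHook.DEFAULT_DELAY_MIN = 0
    AwsBatchClientHook.DEFAULT_DELAY_MAX = 5

    job_id = "00000000-0000-0000-0000-000000000000"  # hypothetical Batch job id
    hook.wait_for_job(job_id)              # poll for RUNNING, then SUCCEEDED/FAILED
    assert hook.check_job_success(job_id)  # raises AirflowException on failure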
@ -0,0 +1,105 @@ |
||||
{ |
||||
"version": 2, |
||||
"waiters": { |
||||
"JobExists": { |
||||
"delay": 2, |
||||
"operation": "DescribeJobs", |
||||
"maxAttempts": 100, |
||||
"acceptors": [ |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "SUBMITTED", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
}, |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "PENDING", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
}, |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "RUNNABLE", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
}, |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "STARTING", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
}, |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "RUNNING", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
}, |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "FAILED", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
}, |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "SUCCEEDED", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
}, |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "", |
||||
"matcher": "error", |
||||
"state": "failure" |
||||
} |
||||
] |
||||
}, |
||||
"JobRunning": { |
||||
"delay": 5, |
||||
"operation": "DescribeJobs", |
||||
"maxAttempts": 100, |
||||
"acceptors": [ |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "RUNNING", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
}, |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "FAILED", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
}, |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "SUCCEEDED", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
} |
||||
] |
||||
}, |
||||
"JobComplete": { |
||||
"delay": 300, |
||||
"operation": "DescribeJobs", |
||||
"maxAttempts": 288, |
||||
"description": "Wait until job status is SUCCEEDED or FAILED", |
||||
"acceptors": [ |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "SUCCEEDED", |
||||
"matcher": "pathAll", |
||||
"state": "success" |
||||
}, |
||||
{ |
||||
"argument": "jobs[].status", |
||||
"expected": "FAILED", |
||||
"matcher": "pathAny", |
||||
"state": "failure" |
||||
} |
||||
] |
||||
} |
||||
} |
||||
} |
||||
@ -0,0 +1,242 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
# |
||||
|
||||
""" |
||||
AWS batch service waiters |
||||
|
||||
.. seealso:: |
||||
|
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/guide/clients.html#waiters |
||||
- https://github.com/boto/botocore/blob/develop/botocore/waiter.py |
||||
""" |
||||
|
||||
import json |
||||
import sys |
||||
from copy import deepcopy |
||||
from pathlib import Path |
||||
from typing import Dict, List, Optional, Union |
||||
|
||||
import botocore.client |
||||
import botocore.exceptions |
||||
import botocore.waiter |
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchClientHook |
||||
|
||||
|
||||
class AwsBatchWaitersHook(AwsBatchClientHook): |
||||
""" |
||||
A utility to manage waiters for AWS batch services. |
||||
Examples: |
||||
|
||||
.. code-block:: python |
||||
|
||||
import random |
||||
from airflow.providers.amazon.aws.hooks.batch_waiters import AwsBatchWaitersHook |
||||
|
||||
# to inspect default waiters |
||||
waiters = AwsBatchWaitersHook() |
||||
config = waiters.default_config # type: Dict |
||||
waiter_names = waiters.list_waiters() # -> ["JobComplete", "JobExists", "JobRunning"] |
||||
|
||||
# The default_config is a useful stepping stone to creating custom waiters, e.g. |
||||
custom_config = waiters.default_config # this is a deepcopy |
||||
# modify custom_config['waiters'] as necessary and get a new instance: |
||||
waiters = AwsBatchWaitersHook(waiter_config=custom_config) |
||||
waiters.waiter_config # check the custom configuration (this is a deepcopy) |
||||
waiters.list_waiters() # names of custom waiters |
||||
|
||||
# During the init for AwsBatchWaitersHook, the waiter_config is used to build a waiter_model; |
||||
# and note that this only occurs during the class init, to avoid any accidental mutations |
||||
# of waiter_config leaking into the waiter_model. |
||||
waiters.waiter_model # -> botocore.waiter.WaiterModel object |
||||
|
||||
# The waiter_model is combined with the waiters.client to get a specific waiter |
||||
# and the details of the config on that waiter can be further modified without any |
||||
# accidental impact on the generation of new waiters from the defined waiter_model, e.g. |
||||
waiters.get_waiter("JobExists").config.delay # -> 5 |
||||
waiter = waiters.get_waiter("JobExists") # -> botocore.waiter.Batch.Waiter.JobExists object |
||||
waiter.config.delay = 10 |
||||
waiters.get_waiter("JobExists").config.delay # -> 5 as defined by waiter_model |
||||
|
||||
# To use a specific waiter, update the config and call the `wait()` method for jobId, e.g. |
||||
waiter = waiters.get_waiter("JobExists") # -> botocore.waiter.Batch.Waiter.JobExists object |
||||
waiter.config.delay = random.uniform(1, 10) # seconds |
||||
waiter.config.max_attempts = 10 |
||||
waiter.wait(jobs=[jobId]) |
||||
|
||||
.. seealso:: |
||||
|
||||
- https://www.2ndwatch.com/blog/use-waiters-boto3-write/ |
||||
- https://github.com/boto/botocore/blob/develop/botocore/waiter.py |
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html#waiters |
||||
- https://github.com/boto/botocore/tree/develop/botocore/data/ec2/2016-11-15 |
||||
- https://github.com/boto/botocore/issues/1915 |
||||
|
||||
:param waiter_config: a custom waiter configuration for AWS batch services |
||||
:type waiter_config: Optional[Dict] |
||||
|
||||
:param aws_conn_id: connection id of AWS credentials / region name. If None, |
||||
credential boto3 strategy will be used |
||||
(http://boto3.readthedocs.io/en/latest/guide/configuration.html). |
||||
:type aws_conn_id: Optional[str] |
||||
|
||||
:param region_name: region name to use in AWS client. |
||||
Override the AWS region in connection (if provided) |
||||
:type region_name: Optional[str] |
||||
""" |
||||
|
||||
def __init__(self, *args, waiter_config: Optional[Dict] = None, **kwargs) -> None: |
||||
|
||||
super().__init__(*args, **kwargs) |
||||
|
||||
self._default_config = None # type: Optional[Dict] |
||||
self._waiter_config = waiter_config or self.default_config |
||||
self._waiter_model = botocore.waiter.WaiterModel(self._waiter_config) |
||||
|
||||
@property |
||||
def default_config(self) -> Dict: |
||||
""" |
||||
An immutable default waiter configuration |
||||
|
||||
:return: a waiter configuration for AWS batch services |
||||
:rtype: Dict |
||||
""" |
||||
if self._default_config is None: |
||||
config_path = Path(__file__).with_name("batch_waiters.json").absolute() |
||||
with open(config_path) as config_file: |
||||
self._default_config = json.load(config_file) |
||||
return deepcopy(self._default_config) # avoid accidental mutation |
||||
|
||||
@property |
||||
def waiter_config(self) -> Dict: |
||||
""" |
||||
An immutable waiter configuration for this instance; a ``deepcopy`` is returned by this |
||||
property. During the init for AwsBatchWaitersHook, the waiter_config is used to build a |
||||
waiter_model and this only occurs during the class init, to avoid any accidental |
||||
mutations of waiter_config leaking into the waiter_model. |
||||
|
||||
:return: a waiter configuration for AWS batch services |
||||
:rtype: Dict |
||||
""" |
||||
return deepcopy(self._waiter_config) # avoid accidental mutation |
||||
|
||||
@property |
||||
def waiter_model(self) -> botocore.waiter.WaiterModel: |
||||
""" |
||||
A configured waiter model used to generate waiters on AWS batch services. |
||||
|
||||
:return: a waiter model for AWS batch services |
||||
:rtype: botocore.waiter.WaiterModel |
||||
""" |
||||
return self._waiter_model |
||||
|
||||
def get_waiter(self, waiter_name: str) -> botocore.waiter.Waiter: |
||||
""" |
||||
Get an AWS Batch service waiter, using the configured ``.waiter_model``. |
||||
|
||||
The ``.waiter_model`` is combined with the ``.client`` to get a specific waiter and |
||||
the properties of that waiter can be modified without any accidental impact on the |
||||
generation of new waiters from the ``.waiter_model``, e.g. |
||||
.. code-block:: |
||||
|
||||
waiters.get_waiter("JobExists").config.delay # -> 5 |
||||
waiter = waiters.get_waiter("JobExists") # a new waiter object |
||||
waiter.config.delay = 10 |
||||
waiters.get_waiter("JobExists").config.delay # -> 5 as defined by waiter_model |
||||
|
||||
To use a specific waiter, update the config and call the `wait()` method for jobId, e.g. |
||||
.. code-block:: |
||||
|
||||
import random |
||||
waiter = waiters.get_waiter("JobExists") # a new waiter object |
||||
waiter.config.delay = random.uniform(1, 10) # seconds |
||||
waiter.config.max_attempts = 10 |
||||
waiter.wait(jobs=[jobId]) |
||||
|
||||
:param waiter_name: The name of the waiter. The name should match |
||||
the name (including the casing) of the key name in the waiter |
||||
model file (typically this is CamelCasing); see ``.list_waiters``. |
||||
:type waiter_name: str |
||||
|
||||
:return: a waiter object for the named AWS batch service |
||||
:rtype: botocore.waiter.Waiter |
||||
""" |
||||
return botocore.waiter.create_waiter_with_client( |
||||
waiter_name, self.waiter_model, self.client |
||||
) |
||||
|
||||
def list_waiters(self) -> List[str]: |
||||
""" |
||||
List the waiters in a waiter configuration for AWS Batch services. |
||||
|
||||
:return: waiter names for AWS batch services |
||||
:rtype: List[str] |
||||
""" |
||||
return self.waiter_model.waiter_names |
||||
|
||||
def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None: |
||||
""" |
||||
Wait for batch job to complete. This assumes that the ``.waiter_model`` is configured |
||||
using some variation of the ``.default_config`` so that it can generate waiters with the |
||||
following names: "JobExists", "JobRunning" and "JobComplete". |
||||
|
||||
:param job_id: a batch job ID |
||||
:type job_id: str |
||||
|
||||
:param delay: A delay before polling for job status |
||||
:type delay: Union[int, float, None] |
||||
|
||||
:raises: AirflowException |
||||
|
||||
.. note:: |
||||
This method adds a small random jitter to the ``delay`` (+/- 2 sec, >= 1 sec). |
||||
Using a random interval helps to avoid AWS API throttle limits when many |
||||
concurrent tasks request job-descriptions. |
||||
|
||||
It also modifies the ``max_attempts`` to use the ``sys.maxsize``, |
||||
which allows Airflow to manage the timeout on waiting. |
||||
""" |
||||
self.delay(delay) |
||||
try: |
||||
waiter = self.get_waiter("JobExists") |
||||
waiter.config.delay = self.add_jitter( |
||||
waiter.config.delay, width=2, minima=1 |
||||
) |
||||
waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow |
||||
waiter.wait(jobs=[job_id]) |
||||
|
||||
waiter = self.get_waiter("JobRunning") |
||||
waiter.config.delay = self.add_jitter( |
||||
waiter.config.delay, width=2, minima=1 |
||||
) |
||||
waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow |
||||
waiter.wait(jobs=[job_id]) |
||||
|
||||
waiter = self.get_waiter("JobComplete") |
||||
waiter.config.delay = self.add_jitter( |
||||
waiter.config.delay, width=2, minima=1 |
||||
) |
||||
waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow |
||||
waiter.wait(jobs=[job_id]) |
||||
|
||||
except ( |
||||
botocore.exceptions.ClientError, |
||||
botocore.exceptions.WaiterError, |
||||
) as err: |
||||
raise AirflowException(err) |
||||
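# Illustrative sketch, not part of the original module: derive a custom waiter
# configuration from the packaged default and use it to wait on a hypothetical
# job. The connection id and job id are placeholders for the example only.
if __name__ == "__main__":
    hook = AwsBatchWaitersHook(aws_conn_id="aws_default")

    # default_config returns a deepcopy, so it can be edited safely
    custom_config = hook.default_config
    custom_config["waiters"]["JobComplete"]["delay"] = 30  # poll every 30s, not 300s

    custom_hook = AwsBatchWaitersHook(
        aws_conn_id="aws_default", waiter_config=custom_config
    )
    print(custom_hook.list_waiters())  # e.g. ['JobComplete', 'JobExists', 'JobRunning']

    job_id = "00000000-0000-0000-0000-000000000000"  # hypothetical Batch job id
    custom_hook.wait_for_job(job_id)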
@ -0,0 +1,79 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""This module contains AWS CloudFormation Hook""" |
||||
from typing import Optional |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
from botocore.exceptions import ClientError |
||||
|
||||
|
||||
class AWSCloudFormationHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS CloudFormation. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs): |
||||
super().__init__(client_type="cloudformation", *args, **kwargs) |
||||
|
||||
def get_stack_status(self, stack_name: str) -> Optional[str]: |
||||
"""Get stack status from CloudFormation.""" |
||||
self.log.info("Poking for stack %s", stack_name) |
||||
|
||||
try: |
||||
stacks = self.get_conn().describe_stacks(StackName=stack_name)["Stacks"] |
||||
return stacks[0]["StackStatus"] |
||||
except ClientError as e: |
||||
if "does not exist" in str(e): |
||||
return None |
||||
else: |
||||
raise e |
||||
|
||||
def create_stack(self, stack_name: str, params: dict) -> None: |
||||
""" |
||||
Create stack in CloudFormation. |
||||
|
||||
:param stack_name: stack_name. |
||||
:type stack_name: str |
||||
:param params: parameters to be passed to CloudFormation. |
||||
:type params: dict |
||||
""" |
||||
if "StackName" not in params: |
||||
params["StackName"] = stack_name |
||||
self.get_conn().create_stack(**params) |
||||
|
||||
def delete_stack(self, stack_name: str, params: Optional[dict] = None) -> None: |
||||
""" |
||||
Delete stack in CloudFormation. |
||||
|
||||
:param stack_name: stack_name. |
||||
:type stack_name: str |
||||
:param params: parameters to be passed to CloudFormation (optional). |
||||
:type params: dict |
||||
""" |
||||
params = params or {} |
||||
if "StackName" not in params: |
||||
params["StackName"] = stack_name |
||||
self.get_conn().delete_stack(**params) |
||||
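# Illustrative sketch, not part of the original module: create a stack from an
# inline JSON template and poll its status. The stack name, template and
# connection id are hypothetical placeholders.
if __name__ == "__main__":
    import time

    hook = AWSCloudFormationHook(aws_conn_id="aws_default")

    template_body = '{"Resources": {"ExampleTopic": {"Type": "AWS::SNS::Topic"}}}'
    hook.create_stack(
        stack_name="example-stack",
        params={"TemplateBody": template_body, "OnFailure": "DELETE"},
    )

    # get_stack_status returns None once the stack no longer exists
    status = hook.get_stack_status(stack_name="example-stack")
    while status is not None and status.endswith("IN_PROGRESS"):
        time.sleep(15)
        status = hook.get_stack_status(stack_name="example-stack")
    print(f"final stack status: {status}")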
@ -0,0 +1,334 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""Interact with AWS DataSync, using the AWS ``boto3`` library.""" |
||||
|
||||
import time |
||||
from typing import List, Optional |
||||
|
||||
from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowTaskTimeout |
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class AWSDataSyncHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS DataSync. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
:class:`~airflow.providers.amazon.aws.operators.datasync.AWSDataSyncOperator` |
||||
|
||||
:param int wait_interval_seconds: Time to wait between two |
||||
consecutive calls to check TaskExecution status. |
||||
:raises ValueError: If wait_interval_seconds is not between 0 and 15*60 seconds. |
||||
""" |
||||
|
||||
TASK_EXECUTION_INTERMEDIATE_STATES = ( |
||||
"INITIALIZING", |
||||
"QUEUED", |
||||
"LAUNCHING", |
||||
"PREPARING", |
||||
"TRANSFERRING", |
||||
"VERIFYING", |
||||
) |
||||
TASK_EXECUTION_FAILURE_STATES = ("ERROR",) |
||||
TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",) |
||||
|
||||
def __init__(self, wait_interval_seconds: int = 5, *args, **kwargs) -> None: |
||||
super().__init__(client_type="datasync", *args, **kwargs) # type: ignore[misc] |
||||
self.locations: list = [] |
||||
self.tasks: list = [] |
||||
# wait_interval_seconds = 0 is used during unit tests |
||||
if wait_interval_seconds < 0 or wait_interval_seconds > 15 * 60: |
||||
raise ValueError(f"Invalid wait_interval_seconds {wait_interval_seconds}") |
||||
self.wait_interval_seconds = wait_interval_seconds |
||||
|
||||
def create_location(self, location_uri: str, **create_location_kwargs) -> str: |
||||
r"""Creates a new location. |
||||
|
||||
:param str location_uri: Location URI used to determine the location type (S3, SMB, NFS, EFS). |
||||
:param create_location_kwargs: Passed to ``boto.create_location_xyz()``. |
||||
See AWS boto3 datasync documentation. |
||||
:return str: LocationArn of the created Location. |
||||
:raises AirflowException: If location type (prefix from ``location_uri``) is invalid. |
||||
""" |
||||
typ = location_uri.split(":")[0] |
||||
if typ == "smb": |
||||
location = self.get_conn().create_location_smb(**create_location_kwargs) |
||||
elif typ == "s3": |
||||
location = self.get_conn().create_location_s3(**create_location_kwargs) |
||||
elif typ == "nfs": |
||||
location = self.get_conn().create_location_nfs(**create_location_kwargs) |
||||
elif typ == "efs": |
||||
location = self.get_conn().create_location_efs(**create_location_kwargs) |
||||
else: |
||||
raise AirflowException(f"Invalid location type: {typ}") |
||||
self._refresh_locations() |
||||
return location["LocationArn"] |
||||
|
||||
def get_location_arns( |
||||
self, |
||||
location_uri: str, |
||||
case_sensitive: bool = False, |
||||
ignore_trailing_slash: bool = True, |
||||
) -> List[str]: |
||||
""" |
||||
Return all LocationArns which match a LocationUri. |
||||
|
||||
:param str location_uri: Location URI to search for, eg ``s3://mybucket/mypath`` |
||||
:param bool case_sensitive: Do a case sensitive search for location URI. |
||||
:param bool ignore_trailing_slash: Ignore / at the end of URI when matching. |
||||
:return: List of LocationArns. |
||||
:rtype: list(str) |
||||
:raises AirflowBadRequest: if ``location_uri`` is empty |
||||
""" |
||||
if not location_uri: |
||||
raise AirflowBadRequest("location_uri not specified") |
||||
if not self.locations: |
||||
self._refresh_locations() |
||||
result = [] |
||||
|
||||
if not case_sensitive: |
||||
location_uri = location_uri.lower() |
||||
if ignore_trailing_slash and location_uri.endswith("/"): |
||||
location_uri = location_uri[:-1] |
||||
|
||||
for location_from_aws in self.locations: |
||||
location_uri_from_aws = location_from_aws["LocationUri"] |
||||
if not case_sensitive: |
||||
location_uri_from_aws = location_uri_from_aws.lower() |
||||
if ignore_trailing_slash and location_uri_from_aws.endswith("/"): |
||||
location_uri_from_aws = location_uri_from_aws[:-1] |
||||
if location_uri == location_uri_from_aws: |
||||
result.append(location_from_aws["LocationArn"]) |
||||
return result |
||||
|
||||
def _refresh_locations(self) -> None: |
||||
"""Refresh the local list of Locations.""" |
||||
self.locations = [] |
||||
next_token = None |
||||
while True: |
||||
if next_token: |
||||
locations = self.get_conn().list_locations(NextToken=next_token) |
||||
else: |
||||
locations = self.get_conn().list_locations() |
||||
self.locations.extend(locations["Locations"]) |
||||
if "NextToken" not in locations: |
||||
break |
||||
next_token = locations["NextToken"] |
||||
|
||||
def create_task( |
||||
self, |
||||
source_location_arn: str, |
||||
destination_location_arn: str, |
||||
**create_task_kwargs, |
||||
) -> str: |
||||
r"""Create a Task between the specified source and destination LocationArns. |
||||
|
||||
:param str source_location_arn: Source LocationArn. Must exist already. |
||||
:param str destination_location_arn: Destination LocationArn. Must exist already. |
||||
:param create_task_kwargs: Passed to ``boto.create_task()``. See AWS boto3 datasync documentation. |
||||
:return: TaskArn of the created Task |
||||
:rtype: str |
||||
""" |
||||
task = self.get_conn().create_task( |
||||
SourceLocationArn=source_location_arn, |
||||
DestinationLocationArn=destination_location_arn, |
||||
**create_task_kwargs, |
||||
) |
||||
self._refresh_tasks() |
||||
return task["TaskArn"] |
||||
|
||||
def update_task(self, task_arn: str, **update_task_kwargs) -> None: |
||||
r"""Update a Task. |
||||
|
||||
:param str task_arn: The TaskArn to update. |
||||
:param update_task_kwargs: Passed to ``boto.update_task()``, See AWS boto3 datasync documentation. |
||||
""" |
||||
self.get_conn().update_task(TaskArn=task_arn, **update_task_kwargs) |
||||
|
||||
def delete_task(self, task_arn: str) -> None: |
||||
r"""Delete a Task. |
||||
|
||||
:param str task_arn: The TaskArn to delete. |
||||
""" |
||||
self.get_conn().delete_task(TaskArn=task_arn) |
||||
|
||||
def _refresh_tasks(self) -> None: |
||||
"""Refreshes the local list of Tasks""" |
||||
self.tasks = [] |
||||
next_token = None |
||||
while True: |
||||
if next_token: |
||||
tasks = self.get_conn().list_tasks(NextToken=next_token) |
||||
else: |
||||
tasks = self.get_conn().list_tasks() |
||||
self.tasks.extend(tasks["Tasks"]) |
||||
if "NextToken" not in tasks: |
||||
break |
||||
next_token = tasks["NextToken"] |
||||
|
||||
def get_task_arns_for_location_arns( |
||||
self, |
||||
source_location_arns: list, |
||||
destination_location_arns: list, |
||||
) -> list: |
||||
""" |
||||
Return a list of TaskArns that use any one of the specified |
||||
source LocationArns and any one of the specified destination LocationArns. |
||||
|
||||
:param list source_location_arns: List of source LocationArns. |
||||
:param list destination_location_arns: List of destination LocationArns. |
||||
:return: list |
||||
:rtype: list(TaskArns) |
||||
:raises AirflowBadRequest: if ``source_location_arns`` or ``destination_location_arns`` are empty. |
||||
""" |
||||
if not source_location_arns: |
||||
raise AirflowBadRequest("source_location_arns not specified") |
||||
if not destination_location_arns: |
||||
raise AirflowBadRequest("destination_location_arns not specified") |
||||
if not self.tasks: |
||||
self._refresh_tasks() |
||||
|
||||
result = [] |
||||
for task in self.tasks: |
||||
task_arn = task["TaskArn"] |
||||
task_description = self.get_task_description(task_arn) |
||||
if task_description["SourceLocationArn"] in source_location_arns: |
||||
if ( |
||||
task_description["DestinationLocationArn"] |
||||
in destination_location_arns |
||||
): |
||||
result.append(task_arn) |
||||
return result |
||||
|
||||
def start_task_execution(self, task_arn: str, **kwargs) -> str: |
||||
r""" |
||||
Start a TaskExecution for the specified task_arn. |
||||
Each task can have at most one TaskExecution. |
||||
|
||||
:param str task_arn: TaskArn |
||||
:param kwargs: kwargs sent to ``boto3.start_task_execution()`` |
||||
:return: TaskExecutionArn |
||||
:rtype: str |
||||
:raises ClientError: If a TaskExecution is already busy running for this ``task_arn``. |
||||
:raises AirflowBadRequest: If ``task_arn`` is empty. |
||||
""" |
||||
if not task_arn: |
||||
raise AirflowBadRequest("task_arn not specified") |
||||
task_execution = self.get_conn().start_task_execution( |
||||
TaskArn=task_arn, **kwargs |
||||
) |
||||
return task_execution["TaskExecutionArn"] |
||||
|
||||
def cancel_task_execution(self, task_execution_arn: str) -> None: |
||||
""" |
||||
Cancel a TaskExecution for the specified ``task_execution_arn``. |
||||
|
||||
:param str task_execution_arn: TaskExecutionArn. |
||||
:raises AirflowBadRequest: If ``task_execution_arn`` is empty. |
||||
""" |
||||
if not task_execution_arn: |
||||
raise AirflowBadRequest("task_execution_arn not specified") |
||||
self.get_conn().cancel_task_execution(TaskExecutionArn=task_execution_arn) |
||||
|
||||
def get_task_description(self, task_arn: str) -> dict: |
||||
""" |
||||
Get description for the specified ``task_arn``. |
||||
|
||||
:param str task_arn: TaskArn |
||||
:return: AWS metadata about a task. |
||||
:rtype: dict |
||||
:raises AirflowBadRequest: If ``task_arn`` is empty. |
||||
""" |
||||
if not task_arn: |
||||
raise AirflowBadRequest("task_arn not specified") |
||||
return self.get_conn().describe_task(TaskArn=task_arn) |
||||
|
||||
def describe_task_execution(self, task_execution_arn: str) -> dict: |
||||
""" |
||||
Get description for the specified ``task_execution_arn``. |
||||
|
||||
:param str task_execution_arn: TaskExecutionArn |
||||
:return: AWS metadata about a task execution. |
||||
:rtype: dict |
||||
:raises AirflowBadRequest: If ``task_execution_arn`` is empty. |
||||
""" |
||||
return self.get_conn().describe_task_execution( |
||||
TaskExecutionArn=task_execution_arn |
||||
) |
||||
|
||||
def get_current_task_execution_arn(self, task_arn: str) -> Optional[str]: |
||||
""" |
||||
Get current TaskExecutionArn (if one exists) for the specified ``task_arn``. |
||||
|
||||
:param str task_arn: TaskArn |
||||
:return: CurrentTaskExecutionArn for this ``task_arn`` or None. |
||||
:rtype: str |
||||
:raises AirflowBadRequest: if ``task_arn`` is empty. |
||||
""" |
||||
if not task_arn: |
||||
raise AirflowBadRequest("task_arn not specified") |
||||
task_description = self.get_task_description(task_arn) |
||||
if "CurrentTaskExecutionArn" in task_description: |
||||
return task_description["CurrentTaskExecutionArn"] |
||||
return None |
||||
|
||||
def wait_for_task_execution( |
||||
self, task_execution_arn: str, max_iterations: int = 2 * 180 |
||||
) -> bool: |
||||
""" |
||||
Wait for Task Execution status to be complete (SUCCESS/ERROR). |
||||
The ``task_execution_arn`` must exist, or a boto3 ClientError will be raised. |
||||
|
||||
:param str task_execution_arn: TaskExecutionArn |
||||
:param int max_iterations: Maximum number of iterations before timing out. |
||||
:return: Result of task execution. |
||||
:rtype: bool |
||||
:raises AirflowTaskTimeout: If maximum iterations is exceeded. |
||||
:raises AirflowBadRequest: If ``task_execution_arn`` is empty. |
||||
""" |
||||
if not task_execution_arn: |
||||
raise AirflowBadRequest("task_execution_arn not specified") |
||||
|
||||
status = None |
||||
iterations = max_iterations |
||||
while status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES: |
||||
task_execution = self.get_conn().describe_task_execution( |
||||
TaskExecutionArn=task_execution_arn |
||||
) |
||||
status = task_execution["Status"] |
||||
self.log.info("status=%s", status) |
||||
iterations -= 1 |
||||
if status in self.TASK_EXECUTION_FAILURE_STATES: |
||||
break |
||||
if status in self.TASK_EXECUTION_SUCCESS_STATES: |
||||
break |
||||
if iterations <= 0: |
||||
break |
||||
time.sleep(self.wait_interval_seconds) |
||||
|
||||
if status in self.TASK_EXECUTION_SUCCESS_STATES: |
||||
return True |
||||
if status in self.TASK_EXECUTION_FAILURE_STATES: |
||||
return False |
||||
if iterations <= 0: |
||||
raise AirflowTaskTimeout("Max iterations exceeded!") |
||||
raise AirflowException(f"Unknown status: {status}") # Should never happen |
||||
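# Illustrative sketch, not part of the original module: reuse two existing DataSync
# locations to build a task and wait for one execution. All URIs, names and the
# connection id are hypothetical placeholders.
if __name__ == "__main__":
    hook = AWSDataSyncHook(aws_conn_id="aws_default", wait_interval_seconds=30)

    source_arns = hook.get_location_arns("smb://fileserver/share/")
    dest_arns = hook.get_location_arns("s3://example-bucket/landing/")
    if not (source_arns and dest_arns):
        raise RuntimeError("expected both locations to exist already")

    task_arn = hook.create_task(
        source_location_arn=source_arns[0],
        destination_location_arn=dest_arns[0],
        Name="example-sync-task",  # extra kwargs pass straight through to boto3 create_task
    )
    execution_arn = hook.start_task_execution(task_arn)
    succeeded = hook.wait_for_task_execution(execution_arn, max_iterations=120)
    print(f"task execution {execution_arn} succeeded: {succeeded}")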
@ -0,0 +1,67 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
|
||||
"""This module contains the AWS DynamoDB hook""" |
||||
from typing import Iterable, List, Optional |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class AwsDynamoDBHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS DynamoDB. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
|
||||
:param table_keys: partition key and sort key |
||||
:type table_keys: list |
||||
:param table_name: target DynamoDB table |
||||
:type table_name: str |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
*args, |
||||
table_keys: Optional[List] = None, |
||||
table_name: Optional[str] = None, |
||||
**kwargs, |
||||
) -> None: |
||||
self.table_keys = table_keys |
||||
self.table_name = table_name |
||||
kwargs["resource_type"] = "dynamodb" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def write_batch_data(self, items: Iterable) -> bool: |
||||
"""Write batch items to DynamoDB table with provisioned throughout capacity.""" |
||||
try: |
||||
table = self.get_conn().Table(self.table_name) |
||||
|
||||
with table.batch_writer(overwrite_by_pkeys=self.table_keys) as batch: |
||||
for item in items: |
||||
batch.put_item(Item=item) |
||||
return True |
||||
except Exception as general_error: |
||||
raise AirflowException( |
||||
f"Failed to insert items in dynamodb, error: {str(general_error)}" |
||||
) |
||||
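# Illustrative sketch, not part of the original module: batch-write a few items.
# The table name, key names and connection id are hypothetical; the table must
# already exist with a matching partition/sort key schema.
if __name__ == "__main__":
    hook = AwsDynamoDBHook(
        aws_conn_id="aws_default",
        table_name="example_events",
        table_keys=["event_id", "event_ts"],  # partition key, sort key
    )
    items = [
        {"event_id": "a1", "event_ts": "2021-01-01T00:00:00Z", "payload": "first"},
        {"event_id": "a1", "event_ts": "2021-01-01T00:05:00Z", "payload": "second"},
    ]
    hook.write_batch_data(items)  # raises AirflowException if the write fails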
@ -0,0 +1,82 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
# |
||||
|
||||
import time |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class EC2Hook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS EC2 Service. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs) -> None: |
||||
kwargs["resource_type"] = "ec2" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def get_instance(self, instance_id: str): |
||||
""" |
||||
Get EC2 instance by id and return it. |
||||
|
||||
:param instance_id: id of the AWS EC2 instance |
||||
:type instance_id: str |
||||
:return: Instance object |
||||
:rtype: ec2.Instance |
||||
""" |
||||
return self.conn.Instance(id=instance_id) |
||||
|
||||
def get_instance_state(self, instance_id: str) -> str: |
||||
""" |
||||
Get EC2 instance state by id and return it. |
||||
|
||||
:param instance_id: id of the AWS EC2 instance |
||||
:type instance_id: str |
||||
:return: current state of the instance |
||||
:rtype: str |
||||
""" |
||||
return self.get_instance(instance_id=instance_id).state["Name"] |
||||
|
||||
def wait_for_state( |
||||
self, instance_id: str, target_state: str, check_interval: float |
||||
) -> None: |
||||
""" |
||||
Wait for the EC2 instance until its state is equal to the target_state. |
||||
|
||||
:param instance_id: id of the AWS EC2 instance |
||||
:type instance_id: str |
||||
:param target_state: target state of instance |
||||
:type target_state: str |
||||
:param check_interval: time in seconds that the job should wait in |
||||
between each instance state check until the operation is completed |
||||
:type check_interval: float |
||||
:return: None |
||||
:rtype: None |
||||
""" |
||||
instance_state = self.get_instance_state(instance_id=instance_id) |
||||
while instance_state != target_state: |
||||
self.log.info("instance state: %s", instance_state) |
||||
time.sleep(check_interval) |
||||
instance_state = self.get_instance_state(instance_id=instance_id) |
||||
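# Illustrative sketch, not part of the original module: start a stopped instance
# and block until it reports the "running" state. The instance id, region and
# connection id are hypothetical placeholders.
if __name__ == "__main__":
    hook = EC2Hook(aws_conn_id="aws_default", region_name="us-east-1")

    instance = hook.get_instance(instance_id="i-0123456789abcdef0")
    instance.start()  # ec2.Instance resource action from boto3

    hook.wait_for_state(
        instance_id="i-0123456789abcdef0",
        target_state="running",
        check_interval=10.0,
    )
    print(hook.get_instance_state(instance_id="i-0123456789abcdef0"))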
@ -0,0 +1,325 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
from time import sleep |
||||
from typing import Optional |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class ElastiCacheReplicationGroupHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS ElastiCache |
||||
|
||||
:param max_retries: Max retries for checking availability of and deleting replication group |
||||
If this is not supplied then this is defaulted to 10 |
||||
:type max_retries: int |
||||
:param exponential_back_off_factor: Multiplication factor for deciding next sleep time |
||||
If this is not supplied then this is defaulted to 1 |
||||
:type exponential_back_off_factor: float |
||||
:param initial_poke_interval: Initial sleep time in seconds |
||||
If this is not supplied then this is defaulted to 60 seconds |
||||
:type initial_poke_interval: float |
||||
""" |
||||
|
||||
TERMINAL_STATES = frozenset({"available", "create-failed", "deleting"}) |
||||
|
||||
def __init__( |
||||
self, |
||||
max_retries: int = 10, |
||||
exponential_back_off_factor: float = 1, |
||||
initial_poke_interval: float = 60, |
||||
*args, |
||||
**kwargs, |
||||
): |
||||
self.max_retries = max_retries |
||||
self.exponential_back_off_factor = exponential_back_off_factor |
||||
self.initial_poke_interval = initial_poke_interval |
||||
|
||||
kwargs["client_type"] = "elasticache" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def create_replication_group(self, config: dict) -> dict: |
||||
""" |
||||
Call ElastiCache API for creating a replication group |
||||
|
||||
:param config: Configuration for creating the replication group |
||||
:type config: dict |
||||
:return: Response from ElastiCache create replication group API |
||||
:rtype: dict |
||||
""" |
||||
return self.conn.create_replication_group(**config) |
||||
|
||||
def delete_replication_group(self, replication_group_id: str) -> dict: |
||||
""" |
||||
Call ElastiCache API for deleting a replication group |
||||
|
||||
:param replication_group_id: ID of replication group to delete |
||||
:type replication_group_id: str |
||||
:return: Response from ElastiCache delete replication group API |
||||
:rtype: dict |
||||
""" |
||||
return self.conn.delete_replication_group( |
||||
ReplicationGroupId=replication_group_id |
||||
) |
||||
|
||||
def describe_replication_group(self, replication_group_id: str) -> dict: |
||||
""" |
||||
Call ElastiCache API for describing a replication group |
||||
|
||||
:param replication_group_id: ID of replication group to describe |
||||
:type replication_group_id: str |
||||
:return: Response from ElastiCache describe replication group API |
||||
:rtype: dict |
||||
""" |
||||
return self.conn.describe_replication_groups( |
||||
ReplicationGroupId=replication_group_id |
||||
) |
||||
|
||||
def get_replication_group_status(self, replication_group_id: str) -> str: |
||||
""" |
||||
Get current status of replication group |
||||
|
||||
:param replication_group_id: ID of replication group to check for status |
||||
:type replication_group_id: str |
||||
:return: Current status of replication group |
||||
:rtype: str |
||||
""" |
||||
return self.describe_replication_group(replication_group_id)[ |
||||
"ReplicationGroups" |
||||
][0]["Status"] |
||||
|
||||
def is_replication_group_available(self, replication_group_id: str) -> bool: |
||||
""" |
||||
Helper for checking if replication group is available or not |
||||
|
||||
:param replication_group_id: ID of replication group to check for availability |
||||
:type replication_group_id: str |
||||
:return: True if available else False |
||||
:rtype: bool |
||||
""" |
||||
return self.get_replication_group_status(replication_group_id) == "available" |
||||
|
||||
def wait_for_availability( |
||||
self, |
||||
replication_group_id: str, |
||||
initial_sleep_time: Optional[float] = None, |
||||
exponential_back_off_factor: Optional[float] = None, |
||||
max_retries: Optional[int] = None, |
||||
): |
||||
""" |
||||
Check if replication group is available or not by performing a describe over it |
||||
|
||||
:param replication_group_id: ID of replication group to check for availability |
||||
:type replication_group_id: str |
||||
:param initial_sleep_time: Initial sleep time in seconds |
||||
If this is not supplied then this is defaulted to class level value |
||||
:type initial_sleep_time: float |
||||
:param exponential_back_off_factor: Multiplication factor for deciding next sleep time |
||||
If this is not supplied then this is defaulted to class level value |
||||
:type exponential_back_off_factor: float |
||||
:param max_retries: Max retries for checking availability of replication group |
||||
If this is not supplied then this is defaulted to class level value |
||||
:type max_retries: int |
||||
:return: True if the replication group is available else False |
||||
:rtype: bool |
||||
""" |
||||
sleep_time = initial_sleep_time or self.initial_poke_interval |
||||
exponential_back_off_factor = ( |
||||
exponential_back_off_factor or self.exponential_back_off_factor |
||||
) |
||||
max_retries = max_retries or self.max_retries |
||||
num_tries = 0 |
||||
status = "not-found" |
||||
stop_poking = False |
||||
|
||||
while not stop_poking and num_tries <= max_retries: |
||||
status = self.get_replication_group_status( |
||||
replication_group_id=replication_group_id |
||||
) |
||||
stop_poking = status in self.TERMINAL_STATES |
||||
|
||||
self.log.info( |
||||
"Current status of replication group with ID %s is %s", |
||||
replication_group_id, |
||||
status, |
||||
) |
||||
|
||||
if not stop_poking: |
||||
num_tries += 1 |
||||
|
||||
# No point in sleeping if all tries have been exhausted |
||||
if num_tries > max_retries: |
||||
break |
||||
|
||||
self.log.info( |
||||
"Poke retry %s. Sleep time %s seconds. Sleeping...", |
||||
num_tries, |
||||
sleep_time, |
||||
) |
||||
|
||||
sleep(sleep_time) |
||||
|
||||
sleep_time *= exponential_back_off_factor |
||||
|
||||
if status != "available": |
||||
self.log.warning( |
||||
'Replication group is not available. Current status is "%s"', status |
||||
) |
||||
|
||||
return False |
||||
|
||||
return True |
||||
|
||||
def wait_for_deletion( |
||||
self, |
||||
replication_group_id: str, |
||||
initial_sleep_time: Optional[float] = None, |
||||
exponential_back_off_factor: Optional[float] = None, |
||||
max_retries: Optional[int] = None, |
||||
): |
||||
""" |
||||
Helper for deleting a replication group ensuring it is either deleted or can't be deleted |
||||
|
||||
:param replication_group_id: ID of replication group to delete |
||||
:type replication_group_id: str |
||||
:param initial_sleep_time: Initial sleep time in seconds |
||||
If this is not supplied then this is defaulted to class level value |
||||
:type initial_sleep_time: float |
||||
:param exponential_back_off_factor: Multiplication factor for deciding next sleep time |
||||
If this is not supplied then this is defaulted to class level value |
||||
:type exponential_back_off_factor: float |
||||
:param max_retries: Max retries for checking status of replication group |
||||
If this is not supplied then this is defaulted to class level value |
||||
:type max_retries: int |
||||
:return: Response from ElastiCache delete replication group API and flag to identify if deleted or not |
||||
:rtype: (dict, bool) |
||||
""" |
||||
deleted = False |
||||
sleep_time = initial_sleep_time or self.initial_poke_interval |
||||
exponential_back_off_factor = ( |
||||
exponential_back_off_factor or self.exponential_back_off_factor |
||||
) |
||||
max_retries = max_retries or self.max_retries |
||||
num_tries = 0 |
||||
response = None |
||||
|
||||
while not deleted and num_tries <= max_retries: |
||||
try: |
||||
status = self.get_replication_group_status( |
||||
replication_group_id=replication_group_id |
||||
) |
||||
|
||||
self.log.info( |
||||
"Current status of replication group with ID %s is %s", |
||||
replication_group_id, |
||||
status, |
||||
) |
||||
|
||||
# Can only delete if status is `available` |
||||
# Status becomes `deleting` after this call so this will only run once |
||||
if status == "available": |
||||
self.log.info("Initiating delete and then wait for it to finish") |
||||
|
||||
response = self.delete_replication_group( |
||||
replication_group_id=replication_group_id |
||||
) |
||||
|
||||
except self.conn.exceptions.ReplicationGroupNotFoundFault: |
||||
self.log.info( |
||||
"Replication group with ID '%s' does not exist", |
||||
replication_group_id, |
||||
) |
||||
|
||||
deleted = True |
||||
|
||||
# This should never occur as we only issue a delete request when status is `available` |
||||
# which is a valid status for deletion. Still handling for safety. |
||||
except self.conn.exceptions.InvalidReplicationGroupStateFault as exp: |
||||
# status Error Response |
||||
# creating - Cache cluster <cluster_id> is not in a valid state to be deleted. |
||||
# deleting - Replication group <replication_group_id> has status deleting which is not valid |
||||
# for deletion. |
||||
# modifying - Replication group <replication_group_id> has status modifying which is not valid |
||||
# for deletion. |
||||
|
||||
message = exp.response["Error"]["Message"] |
||||
|
||||
self.log.warning( |
||||
"Received error message from AWS ElastiCache API : %s", message |
||||
) |
||||
|
||||
if not deleted: |
||||
num_tries += 1 |
||||
|
||||
# No point in sleeping if all tries have exhausted |
||||
if num_tries > max_retries: |
||||
break |
||||
|
||||
self.log.info( |
||||
"Poke retry %s. Sleep time %s seconds. Sleeping...", |
||||
num_tries, |
||||
sleep_time, |
||||
) |
||||
|
||||
sleep(sleep_time) |
||||
|
||||
sleep_time *= exponential_back_off_factor |
||||
|
||||
return response, deleted |
||||
|
||||
def ensure_delete_replication_group( |
||||
self, |
||||
replication_group_id: str, |
||||
initial_sleep_time: Optional[float] = None, |
||||
exponential_back_off_factor: Optional[float] = None, |
||||
max_retries: Optional[int] = None, |
||||
): |
||||
""" |
||||
Delete a replication group ensuring it is either deleted or can't be deleted |
||||
|
||||
:param replication_group_id: ID of replication group to delete |
||||
:type replication_group_id: str |
||||
:param initial_sleep_time: Initial sleep time in seconds |
||||
If this is not supplied then this is defaulted to class level value |
||||
:type initial_sleep_time: float |
||||
:param exponential_back_off_factor: Multiplication factor for deciding next sleep time |
||||
If this is not supplied then this is defaulted to class level value |
||||
:type exponential_back_off_factor: float |
||||
:param max_retries: Max retries for checking status of replication group |
||||
If this is not supplied then this is defaulted to class level value |
||||
:type max_retries: int |
||||
:return: Response from ElastiCache delete replication group API |
||||
:rtype: dict |
||||
:raises AirflowException: If replication group is not deleted |
||||
""" |
||||
self.log.info("Deleting replication group with ID %s", replication_group_id) |
||||
|
||||
response, deleted = self.wait_for_deletion( |
||||
replication_group_id=replication_group_id, |
||||
initial_sleep_time=initial_sleep_time, |
||||
exponential_back_off_factor=exponential_back_off_factor, |
||||
max_retries=max_retries, |
||||
) |
||||
|
||||
if not deleted: |
||||
raise AirflowException( |
||||
f'Replication group could not be deleted. Response "{response}"' |
||||
) |
||||
|
||||
return response |
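A minimal usage sketch of the ElastiCache hook above. The class and module names (not shown in this hunk), the ``aws_default`` connection, and the replication group ID are assumptions for illustration only.

# Hypothetical usage; class/module names and identifiers are assumed.
from airflow.providers.amazon.aws.hooks.elasticache_replication_group import (
    ElastiCacheReplicationGroupHook,
)

hook = ElastiCacheReplicationGroupHook(aws_conn_id="aws_default")

# Poll with exponential back-off until the group is usable, then tear it down.
if hook.wait_for_availability("example-replication-group", initial_sleep_time=30.0):
    hook.ensure_delete_replication_group("example-replication-group")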
||||
@ -0,0 +1,101 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
from typing import Any, Dict, List, Optional |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class EmrHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS EMR. emr_conn_id is only necessary for using the |
||||
create_job_flow method. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
conn_name_attr = "emr_conn_id" |
||||
default_conn_name = "emr_default" |
||||
conn_type = "emr" |
||||
hook_name = "Elastic MapReduce" |
||||
|
||||
def __init__( |
||||
self, emr_conn_id: Optional[str] = default_conn_name, *args, **kwargs |
||||
) -> None: |
||||
self.emr_conn_id = emr_conn_id |
||||
kwargs["client_type"] = "emr" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def get_cluster_id_by_name( |
||||
self, emr_cluster_name: str, cluster_states: List[str] |
||||
) -> Optional[str]: |
||||
""" |
||||
Fetch id of EMR cluster with given name and (optional) states. |
||||
Will return only if single id is found. |
||||
|
||||
:param emr_cluster_name: Name of a cluster to find |
||||
:type emr_cluster_name: str |
||||
:param cluster_states: State(s) of cluster to find |
||||
:type cluster_states: list |
||||
:return: id of the EMR cluster |
||||
""" |
||||
response = self.get_conn().list_clusters(ClusterStates=cluster_states) |
||||
|
||||
matching_clusters = list( |
||||
filter( |
||||
lambda cluster: cluster["Name"] == emr_cluster_name, |
||||
response["Clusters"], |
||||
) |
||||
) |
||||
|
||||
if len(matching_clusters) == 1: |
||||
cluster_id = matching_clusters[0]["Id"] |
||||
self.log.info( |
||||
"Found cluster name = %s id = %s", emr_cluster_name, cluster_id |
||||
) |
||||
return cluster_id |
||||
elif len(matching_clusters) > 1: |
||||
raise AirflowException( |
||||
f"More than one cluster found for name {emr_cluster_name}" |
||||
) |
||||
else: |
||||
self.log.info("No cluster found for name %s", emr_cluster_name) |
||||
return None |
||||
|
||||
def create_job_flow(self, job_flow_overrides: Dict[str, Any]) -> Dict[str, Any]: |
||||
""" |
||||
Creates a job flow using the config from the EMR connection. |
||||
Keys of the json extra hash may have the arguments of the boto3 |
||||
run_job_flow method. |
||||
Overrides for this config may be passed as the job_flow_overrides. |
||||
""" |
||||
if not self.emr_conn_id: |
||||
raise AirflowException("emr_conn_id must be present to use create_job_flow") |
||||
|
||||
emr_conn = self.get_connection(self.emr_conn_id) |
||||
|
||||
config = emr_conn.extra_dejson.copy() |
||||
config.update(job_flow_overrides) |
||||
|
||||
response = self.get_conn().run_job_flow(**config) |
||||
|
||||
return response |
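A short, hedged example of how this hook might be used; the connection ids, cluster name, overrides, and module path are placeholders/assumptions.

# Illustrative only: connection ids, cluster name and overrides are placeholders.
from airflow.providers.amazon.aws.hooks.emr import EmrHook

hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")

# Overrides are merged on top of the JSON stored in the EMR connection's extra field.
response = hook.create_job_flow({"Name": "example-cluster", "ReleaseLabel": "emr-6.2.0"})
job_flow_id = response["JobFlowId"]

# Later, resolve the cluster id again by name while it is still active.
cluster_id = hook.get_cluster_id_by_name(
    "example-cluster", ["STARTING", "BOOTSTRAPPING", "RUNNING", "WAITING"]
)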
||||
@ -0,0 +1,80 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
|
||||
from typing import Any, Dict |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class GlacierHook(AwsBaseHook): |
||||
"""Hook for connection with Amazon Glacier""" |
||||
|
||||
def __init__(self, aws_conn_id: str = "aws_default") -> None: |
||||
# pass the connection id through so non-default AWS connections are honoured |
||||
super().__init__(aws_conn_id=aws_conn_id, client_type="glacier") |
||||
self.aws_conn_id = aws_conn_id |
||||
|
||||
def retrieve_inventory(self, vault_name: str) -> Dict[str, Any]: |
||||
""" |
||||
Initiate an Amazon Glacier inventory-retrieval job |
||||
|
||||
:param vault_name: the Glacier vault on which job is executed |
||||
:type vault_name: str |
||||
""" |
||||
job_params = {"Type": "inventory-retrieval"} |
||||
self.log.info("Retrieving inventory for vault: %s", vault_name) |
||||
response = self.get_conn().initiate_job( |
||||
vaultName=vault_name, jobParameters=job_params |
||||
) |
||||
self.log.info("Initiated inventory-retrieval job for: %s", vault_name) |
||||
self.log.info("Retrieval Job ID: %s", response["jobId"]) |
||||
return response |
||||
|
||||
def retrieve_inventory_results( |
||||
self, vault_name: str, job_id: str |
||||
) -> Dict[str, Any]: |
||||
""" |
||||
Retrieve the results of an Amazon Glacier inventory-retrieval job |
||||
|
||||
:param vault_name: the Glacier vault on which job is executed |
||||
:type vault_name: str |
||||
:param job_id: the job ID returned by retrieve_inventory() |
||||
:type job_id: str |
||||
""" |
||||
self.log.info("Retrieving the job results for vault: %s...", vault_name) |
||||
response = self.get_conn().get_job_output(vaultName=vault_name, jobId=job_id) |
||||
return response |
||||
|
||||
def describe_job(self, vault_name: str, job_id: str) -> Dict[str, Any]: |
||||
""" |
||||
Retrieve the status of an Amazon S3 Glacier job, such as an |
||||
inventory-retrieval job |
||||
|
||||
:param vault_name: the Glacier vault on which job is executed |
||||
:type vault_name: str |
||||
:param job_id: the job ID returned by retrieve_inventory() |
||||
:type job_id: str |
||||
""" |
||||
self.log.info("Retrieving status for vault: %s and job %s", vault_name, job_id) |
||||
response = self.get_conn().describe_job(vaultName=vault_name, jobId=job_id) |
||||
self.log.info( |
||||
"Job status: %s, code status: %s", |
||||
response["Action"], |
||||
response["StatusCode"], |
||||
) |
||||
return response |
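A sketch of a typical inventory-retrieval flow with this hook; the vault name and module path are assumptions.

# Illustrative only: vault name is a placeholder.
from airflow.providers.amazon.aws.hooks.glacier import GlacierHook

hook = GlacierHook(aws_conn_id="aws_default")

job = hook.retrieve_inventory(vault_name="example-vault")

# Glacier jobs are asynchronous: check completion before fetching the output.
status = hook.describe_job(vault_name="example-vault", job_id=job["jobId"])
if status["Completed"]:
    inventory = hook.retrieve_inventory_results(
        vault_name="example-vault", job_id=job["jobId"]
    )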
||||
@ -0,0 +1,205 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
import time |
||||
from typing import Dict, List, Optional |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class AwsGlueJobHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS Glue - create, start, and monitor Glue jobs |
||||
|
||||
:param s3_bucket: S3 bucket where logs and local etl script will be uploaded |
||||
:type s3_bucket: Optional[str] |
||||
:param job_name: unique job name per AWS account |
||||
:type job_name: Optional[str] |
||||
:param desc: job description |
||||
:type desc: Optional[str] |
||||
:param concurrent_run_limit: The maximum number of concurrent runs allowed for a job |
||||
:type concurrent_run_limit: int |
||||
:param script_location: path to etl script on s3 |
||||
:type script_location: Optional[str] |
||||
:param retry_limit: Maximum number of times to retry this job if it fails |
||||
:type retry_limit: int |
||||
:param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job |
||||
:type num_of_dpus: int |
||||
:param region_name: aws region name (example: us-east-1) |
||||
:type region_name: Optional[str] |
||||
:param iam_role_name: AWS IAM Role for Glue Job Execution |
||||
:type iam_role_name: Optional[str] |
||||
:param create_job_kwargs: Extra arguments for Glue Job Creation |
||||
:type create_job_kwargs: Optional[dict] |
||||
""" |
||||
|
||||
JOB_POLL_INTERVAL = 6 # polls job status after every JOB_POLL_INTERVAL seconds |
||||
|
||||
def __init__( |
||||
self, |
||||
s3_bucket: Optional[str] = None, |
||||
job_name: Optional[str] = None, |
||||
desc: Optional[str] = None, |
||||
concurrent_run_limit: int = 1, |
||||
script_location: Optional[str] = None, |
||||
retry_limit: int = 0, |
||||
num_of_dpus: int = 10, |
||||
region_name: Optional[str] = None, |
||||
iam_role_name: Optional[str] = None, |
||||
create_job_kwargs: Optional[dict] = None, |
||||
*args, |
||||
**kwargs, |
||||
): # pylint: disable=too-many-arguments |
||||
self.job_name = job_name |
||||
self.desc = desc |
||||
self.concurrent_run_limit = concurrent_run_limit |
||||
self.script_location = script_location |
||||
self.retry_limit = retry_limit |
||||
self.num_of_dpus = num_of_dpus |
||||
self.region_name = region_name |
||||
self.s3_bucket = s3_bucket |
||||
self.role_name = iam_role_name |
||||
self.s3_glue_logs = "logs/glue-logs/" |
||||
self.create_job_kwargs = create_job_kwargs or {} |
||||
kwargs["client_type"] = "glue" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def list_jobs(self) -> List: |
||||
""":return: List of Glue jobs""" |
||||
conn = self.get_conn() |
||||
return conn.get_jobs() |
||||
|
||||
def get_iam_execution_role(self) -> Dict: |
||||
""":return: iam role for job execution""" |
||||
iam_client = self.get_client_type("iam", self.region_name) |
||||
|
||||
try: |
||||
glue_execution_role = iam_client.get_role(RoleName=self.role_name) |
||||
self.log.info("Iam Role Name: %s", self.role_name) |
||||
return glue_execution_role |
||||
except Exception as general_error: |
||||
self.log.error("Failed to get IAM role for Glue job, error: %s", general_error) |
||||
raise |
||||
|
||||
def initialize_job(self, script_arguments: Optional[dict] = None) -> Dict[str, str]: |
||||
""" |
||||
Initializes connection with AWS Glue and starts the job run |
||||
|
||||
:return: Dict containing the ``JobRunId`` of the started job run |
||||
""" |
||||
glue_client = self.get_conn() |
||||
script_arguments = script_arguments or {} |
||||
|
||||
try: |
||||
job_name = self.get_or_create_glue_job() |
||||
job_run = glue_client.start_job_run( |
||||
JobName=job_name, Arguments=script_arguments |
||||
) |
||||
return job_run |
||||
except Exception as general_error: |
||||
self.log.error("Failed to run aws glue job, error: %s", general_error) |
||||
raise |
||||
|
||||
def get_job_state(self, job_name: str, run_id: str) -> str: |
||||
""" |
||||
Get state of the Glue job. The job state can be |
||||
running, finished, failed, stopped or timeout. |
||||
:param job_name: unique job name per AWS account |
||||
:type job_name: str |
||||
:param run_id: The job-run ID of the predecessor job run |
||||
:type run_id: str |
||||
:return: State of the Glue job |
||||
""" |
||||
glue_client = self.get_conn() |
||||
job_run = glue_client.get_job_run( |
||||
JobName=job_name, RunId=run_id, PredecessorsIncluded=True |
||||
) |
||||
job_run_state = job_run["JobRun"]["JobRunState"] |
||||
return job_run_state |
||||
|
||||
def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]: |
||||
""" |
||||
Waits until Glue job with job_name completes or |
||||
fails and returns the final state if finished. |
||||
Raises AirflowException when the job fails. |
||||
:param job_name: unique job name per AWS account |
||||
:type job_name: str |
||||
:param run_id: The job-run ID of the predecessor job run |
||||
:type run_id: str |
||||
:return: Dict of JobRunState and JobRunId |
||||
""" |
||||
failed_states = ["FAILED", "TIMEOUT"] |
||||
finished_states = ["SUCCEEDED", "STOPPED"] |
||||
|
||||
while True: |
||||
job_run_state = self.get_job_state(job_name, run_id) |
||||
if job_run_state in finished_states: |
||||
self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state) |
||||
return {"JobRunState": job_run_state, "JobRunId": run_id} |
||||
if job_run_state in failed_states: |
||||
job_error_message = ( |
||||
"Exiting Job " + run_id + " Run State: " + job_run_state |
||||
) |
||||
self.log.info(job_error_message) |
||||
raise AirflowException(job_error_message) |
||||
else: |
||||
self.log.info( |
||||
"Polling for AWS Glue Job %s current run state with status %s", |
||||
job_name, |
||||
job_run_state, |
||||
) |
||||
time.sleep(self.JOB_POLL_INTERVAL) |
||||
|
||||
def get_or_create_glue_job(self) -> str: |
||||
""" |
||||
Creates the Glue job if it does not already exist and returns the job name |
||||
:return: Name of the job |
||||
""" |
||||
glue_client = self.get_conn() |
||||
try: |
||||
get_job_response = glue_client.get_job(JobName=self.job_name) |
||||
self.log.info("Job already exists. Returning name of the job") |
||||
return get_job_response["Job"]["Name"] |
||||
|
||||
except glue_client.exceptions.EntityNotFoundException: |
||||
self.log.info("Job doesn't exist. Now creating and running AWS Glue Job") |
||||
if self.s3_bucket is None: |
||||
raise AirflowException( |
||||
"Could not initialize glue job, error: Specify Parameter `s3_bucket`" |
||||
) |
||||
s3_log_path = f"s3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}" |
||||
execution_role = self.get_iam_execution_role() |
||||
try: |
||||
create_job_response = glue_client.create_job( |
||||
Name=self.job_name, |
||||
Description=self.desc, |
||||
LogUri=s3_log_path, |
||||
Role=execution_role["Role"]["RoleName"], |
||||
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit}, |
||||
Command={"Name": "glueetl", "ScriptLocation": self.script_location}, |
||||
MaxRetries=self.retry_limit, |
||||
AllocatedCapacity=self.num_of_dpus, |
||||
**self.create_job_kwargs, |
||||
) |
||||
return create_job_response["Name"] |
||||
except Exception as general_error: |
||||
self.log.error( |
||||
"Failed to create aws glue job, error: %s", general_error |
||||
) |
||||
raise |
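A hedged usage sketch for this Glue job hook; the bucket, role, script location, and job arguments below are placeholders.

# Illustrative values only; bucket, role and script paths are placeholders.
from airflow.providers.amazon.aws.hooks.glue import AwsGlueJobHook

hook = AwsGlueJobHook(
    job_name="example-etl-job",
    desc="Example ETL job",
    s3_bucket="example-bucket",
    script_location="s3://example-bucket/scripts/etl.py",
    iam_role_name="ExampleGlueServiceRole",
)

# Creates the job on first use, starts a run, then blocks until it finishes or fails.
job_run = hook.initialize_job({"--input_path": "s3://example-bucket/input/"})
final_state = hook.job_completion(hook.job_name, job_run["JobRunId"])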
||||
@ -0,0 +1,142 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""This module contains AWS Glue Catalog Hook""" |
||||
from typing import Optional, Set |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class AwsGlueCatalogHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS Glue Catalog |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs): |
||||
super().__init__(client_type="glue", *args, **kwargs) |
||||
|
||||
def get_partitions( |
||||
self, |
||||
database_name: str, |
||||
table_name: str, |
||||
expression: str = "", |
||||
page_size: Optional[int] = None, |
||||
max_items: Optional[int] = None, |
||||
) -> Set[tuple]: |
||||
""" |
||||
Retrieves the partition values for a table. |
||||
|
||||
:param database_name: The name of the catalog database where the partitions reside. |
||||
:type database_name: str |
||||
:param table_name: The name of the partitions' table. |
||||
:type table_name: str |
||||
:param expression: An expression filtering the partitions to be returned. |
||||
Please see official AWS documentation for further information. |
||||
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions |
||||
:type expression: str |
||||
:param page_size: pagination size |
||||
:type page_size: int |
||||
:param max_items: maximum items to return |
||||
:type max_items: int |
||||
:return: set of partition values where each value is a tuple since |
||||
a partition may be composed of multiple columns. For example: |
||||
``{('2018-01-01','1'), ('2018-01-01','2')}`` |
||||
""" |
||||
config = { |
||||
"PageSize": page_size, |
||||
"MaxItems": max_items, |
||||
} |
||||
|
||||
paginator = self.get_conn().get_paginator("get_partitions") |
||||
response = paginator.paginate( |
||||
DatabaseName=database_name, |
||||
TableName=table_name, |
||||
Expression=expression, |
||||
PaginationConfig=config, |
||||
) |
||||
|
||||
partitions = set() |
||||
for page in response: |
||||
for partition in page["Partitions"]: |
||||
partitions.add(tuple(partition["Values"])) |
||||
|
||||
return partitions |
||||
|
||||
def check_for_partition( |
||||
self, database_name: str, table_name: str, expression: str |
||||
) -> bool: |
||||
""" |
||||
Checks whether a partition exists |
||||
|
||||
:param database_name: Name of hive database (schema) @table belongs to |
||||
:type database_name: str |
||||
:param table_name: Name of hive table @partition belongs to |
||||
:type table_name: str |
||||
:param expression: Expression that matches the partitions to check for |
||||
(eg `a = 'b' AND c = 'd'`) |
||||
:type expression: str |
||||
:rtype: bool |
||||
|
||||
>>> hook = AwsGlueCatalogHook() |
||||
>>> t = 'static_babynames_partitioned' |
||||
>>> hook.check_for_partition('airflow', t, "ds='2015-01-01'") |
||||
True |
||||
""" |
||||
partitions = self.get_partitions( |
||||
database_name, table_name, expression, max_items=1 |
||||
) |
||||
|
||||
return bool(partitions) |
||||
|
||||
def get_table(self, database_name: str, table_name: str) -> dict: |
||||
""" |
||||
Get the information of the table |
||||
|
||||
:param database_name: Name of hive database (schema) @table belongs to |
||||
:type database_name: str |
||||
:param table_name: Name of hive table |
||||
:type table_name: str |
||||
:rtype: dict |
||||
|
||||
>>> hook = AwsGlueCatalogHook() |
||||
>>> r = hook.get_table('db', 'table_foo') |
||||
>>> r['Name'] = 'table_foo' |
||||
""" |
||||
result = self.get_conn().get_table(DatabaseName=database_name, Name=table_name) |
||||
|
||||
return result["Table"] |
||||
|
||||
def get_table_location(self, database_name: str, table_name: str) -> str: |
||||
""" |
||||
Get the physical location of the table |
||||
|
||||
:param database_name: Name of hive database (schema) @table belongs to |
||||
:type database_name: str |
||||
:param table_name: Name of hive table |
||||
:type table_name: str |
||||
:return: str |
||||
""" |
||||
table = self.get_table(database_name, table_name) |
||||
|
||||
return table["StorageDescriptor"]["Location"] |
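A minimal sketch of querying the catalog with this hook; database, table, and partition expression are placeholders.

# Illustrative only: database, table and partition expression are placeholders.
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook

hook = AwsGlueCatalogHook(aws_conn_id="aws_default")

if hook.check_for_partition("example_db", "example_table", "ds='2021-01-01'"):
    location = hook.get_table_location("example_db", "example_table")
    partitions = hook.get_partitions("example_db", "example_table", "ds='2021-01-01'")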
||||
@ -0,0 +1,184 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
from time import sleep |
||||
|
||||
try: |
||||
from functools import cached_property |
||||
except ImportError: |
||||
from cached_property import cached_property |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class AwsGlueCrawlerHook(AwsBaseHook): |
||||
""" |
||||
Interacts with AWS Glue Crawler. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs): |
||||
kwargs["client_type"] = "glue" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
@cached_property |
||||
def glue_client(self): |
||||
""":return: AWS Glue client""" |
||||
return self.get_conn() |
||||
|
||||
def has_crawler(self, crawler_name) -> bool: |
||||
""" |
||||
Checks if the crawler already exists |
||||
|
||||
:param crawler_name: unique crawler name per AWS account |
||||
:type crawler_name: str |
||||
:return: Returns True if the crawler already exists and False if not. |
||||
""" |
||||
self.log.info("Checking if crawler already exists: %s", crawler_name) |
||||
|
||||
try: |
||||
self.get_crawler(crawler_name) |
||||
return True |
||||
except self.glue_client.exceptions.EntityNotFoundException: |
||||
return False |
||||
|
||||
def get_crawler(self, crawler_name: str) -> dict: |
||||
""" |
||||
Gets crawler configurations |
||||
|
||||
:param crawler_name: unique crawler name per AWS account |
||||
:type crawler_name: str |
||||
:return: Nested dictionary of crawler configurations |
||||
""" |
||||
return self.glue_client.get_crawler(Name=crawler_name)["Crawler"] |
||||
|
||||
def update_crawler(self, **crawler_kwargs) -> bool: |
||||
""" |
||||
Updates crawler configurations |
||||
|
||||
:param crawler_kwargs: Keyword args that define the configurations used for the crawler |
||||
:type crawler_kwargs: any |
||||
:return: True if the crawler was updated and False otherwise |
||||
""" |
||||
crawler_name = crawler_kwargs["Name"] |
||||
current_crawler = self.get_crawler(crawler_name) |
||||
|
||||
update_config = { |
||||
key: value |
||||
for key, value in crawler_kwargs.items() |
||||
# use .get() so keys absent from the existing crawler count as changes |
||||
if current_crawler.get(key) != value |
||||
} |
||||
if update_config != {}: |
||||
self.log.info("Updating crawler: %s", crawler_name) |
||||
self.glue_client.update_crawler(**crawler_kwargs) |
||||
self.log.info("Updated configurations: %s", update_config) |
||||
return True |
||||
else: |
||||
return False |
||||
|
||||
def create_crawler(self, **crawler_kwargs) -> str: |
||||
""" |
||||
Creates an AWS Glue Crawler |
||||
|
||||
:param crawler_kwargs: Keyword args that define the configurations used to create the crawler |
||||
:type crawler_kwargs: any |
||||
:return: Name of the crawler |
||||
""" |
||||
crawler_name = crawler_kwargs["Name"] |
||||
self.log.info("Creating crawler: %s", crawler_name) |
||||
return self.glue_client.create_crawler(**crawler_kwargs)["Crawler"]["Name"] |
||||
|
||||
def start_crawler(self, crawler_name: str) -> dict: |
||||
""" |
||||
Triggers the AWS Glue crawler |
||||
|
||||
:param crawler_name: unique crawler name per AWS account |
||||
:type crawler_name: str |
||||
:return: Empty dictionary |
||||
""" |
||||
self.log.info("Starting crawler %s", crawler_name) |
||||
crawler = self.glue_client.start_crawler(Name=crawler_name) |
||||
return crawler |
||||
|
||||
def wait_for_crawler_completion( |
||||
self, crawler_name: str, poll_interval: int = 5 |
||||
) -> str: |
||||
""" |
||||
Waits until Glue crawler completes and |
||||
returns the status of the latest crawl run. |
||||
Raises AirflowException if the crawler fails or is cancelled. |
||||
|
||||
:param crawler_name: unique crawler name per AWS account |
||||
:type crawler_name: str |
||||
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status |
||||
:type poll_interval: int |
||||
:return: Crawler's status |
||||
""" |
||||
failed_status = ["FAILED", "CANCELLED"] |
||||
|
||||
while True: |
||||
crawler = self.get_crawler(crawler_name) |
||||
crawler_state = crawler["State"] |
||||
if crawler_state == "READY": |
||||
self.log.info("State: %s", crawler_state) |
||||
self.log.info("crawler_config: %s", crawler) |
||||
crawler_status = crawler["LastCrawl"]["Status"] |
||||
if crawler_status in failed_status: |
||||
raise AirflowException( |
||||
f"Status: {crawler_status}" |
||||
) # pylint: disable=raising-format-tuple |
||||
else: |
||||
metrics = self.glue_client.get_crawler_metrics( |
||||
CrawlerNameList=[crawler_name] |
||||
)["CrawlerMetricsList"][0] |
||||
self.log.info("Status: %s", crawler_status) |
||||
self.log.info( |
||||
"Last Runtime Duration (seconds): %s", |
||||
metrics["LastRuntimeSeconds"], |
||||
) |
||||
self.log.info( |
||||
"Median Runtime Duration (seconds): %s", |
||||
metrics["MedianRuntimeSeconds"], |
||||
) |
||||
self.log.info("Tables Created: %s", metrics["TablesCreated"]) |
||||
self.log.info("Tables Updated: %s", metrics["TablesUpdated"]) |
||||
self.log.info("Tables Deleted: %s", metrics["TablesDeleted"]) |
||||
|
||||
return crawler_status |
||||
|
||||
else: |
||||
self.log.info("Polling for AWS Glue crawler: %s ", crawler_name) |
||||
self.log.info("State: %s", crawler_state) |
||||
|
||||
metrics = self.glue_client.get_crawler_metrics( |
||||
CrawlerNameList=[crawler_name] |
||||
)["CrawlerMetricsList"][0] |
||||
time_left = int(metrics["TimeLeftSeconds"]) |
||||
|
||||
if time_left > 0: |
||||
self.log.info("Estimated Time Left (seconds): %s", time_left) |
||||
else: |
||||
self.log.info("Crawler should finish soon") |
||||
|
||||
sleep(poll_interval) |
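A hedged end-to-end sketch for this crawler hook; the crawler name, role, database, and S3 target are placeholders.

# Illustrative only: crawler name, role and S3 target are placeholders.
from airflow.providers.amazon.aws.hooks.glue_crawler import AwsGlueCrawlerHook

hook = AwsGlueCrawlerHook(aws_conn_id="aws_default")

config = {
    "Name": "example-crawler",
    "Role": "ExampleGlueServiceRole",
    "DatabaseName": "example_db",
    "Targets": {"S3Targets": [{"Path": "s3://example-bucket/data/"}]},
}

# Create the crawler on first use, otherwise push any changed settings.
if hook.has_crawler(config["Name"]):
    hook.update_crawler(**config)
else:
    hook.create_crawler(**config)

hook.start_crawler(config["Name"])
last_status = hook.wait_for_crawler_completion(config["Name"], poll_interval=10)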
||||
@ -0,0 +1,50 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""This module contains AWS Firehose hook""" |
||||
from typing import Iterable |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class AwsFirehoseHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS Kinesis Firehose. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
|
||||
:param delivery_stream: Name of the delivery stream |
||||
:type delivery_stream: str |
||||
""" |
||||
|
||||
def __init__(self, delivery_stream: str, *args, **kwargs) -> None: |
||||
self.delivery_stream = delivery_stream |
||||
kwargs["client_type"] = "firehose" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def put_records(self, records: Iterable): |
||||
"""Write batch records to Kinesis Firehose""" |
||||
response = self.get_conn().put_record_batch( |
||||
DeliveryStreamName=self.delivery_stream, Records=records |
||||
) |
||||
|
||||
return response |
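A brief usage sketch for the Firehose hook; the stream name and record payloads are placeholders.

# Illustrative only: delivery stream name and payloads are placeholders.
from airflow.providers.amazon.aws.hooks.kinesis import AwsFirehoseHook

hook = AwsFirehoseHook(delivery_stream="example-stream", aws_conn_id="aws_default")

# put_record_batch expects each record as a dict with a bytes payload under "Data".
records = [{"Data": b'{"event_id": 1}\n'}, {"Data": b'{"event_id": 2}\n'}]
response = hook.put_records(records)
failed_count = response["FailedPutCount"]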
||||
@ -0,0 +1,69 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""This module contains AWS Lambda hook""" |
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class AwsLambdaHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS Lambda |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
|
||||
:param function_name: AWS Lambda Function Name |
||||
:type function_name: str |
||||
:param log_type: Set to ``Tail`` to include the execution log in the invocation response |
||||
:type log_type: str |
||||
:param qualifier: AWS Lambda Function Version or Alias Name |
||||
:type qualifier: str |
||||
:param invocation_type: AWS Lambda Invocation Type (RequestResponse, Event etc) |
||||
:type invocation_type: str |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
function_name: str, |
||||
log_type: str = "None", |
||||
qualifier: str = "$LATEST", |
||||
invocation_type: str = "RequestResponse", |
||||
*args, |
||||
**kwargs, |
||||
) -> None: |
||||
self.function_name = function_name |
||||
self.log_type = log_type |
||||
self.invocation_type = invocation_type |
||||
self.qualifier = qualifier |
||||
kwargs["client_type"] = "lambda" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def invoke_lambda(self, payload: str) -> dict: |
||||
"""Invoke Lambda Function""" |
||||
response = self.conn.invoke( |
||||
FunctionName=self.function_name, |
||||
InvocationType=self.invocation_type, |
||||
LogType=self.log_type, |
||||
Payload=payload, |
||||
Qualifier=self.qualifier, |
||||
) |
||||
|
||||
return response |
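A minimal sketch of a synchronous invocation with this hook; the function name and payload are placeholders.

# Illustrative only: function name and payload are placeholders.
import json

from airflow.providers.amazon.aws.hooks.lambda_function import AwsLambdaHook

hook = AwsLambdaHook(function_name="example-function", invocation_type="RequestResponse")

response = hook.invoke_lambda(payload=json.dumps({"key": "value"}))

# For RequestResponse invocations the function's result is in the Payload stream.
result = json.loads(response["Payload"].read())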
||||
@ -0,0 +1,105 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
""" |
||||
This module contains a hook (AwsLogsHook) with some very basic |
||||
functionality for interacting with AWS CloudWatch. |
||||
""" |
||||
from typing import Dict, Generator, Optional |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class AwsLogsHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS CloudWatch Logs |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs) -> None: |
||||
kwargs["client_type"] = "logs" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def get_log_events( |
||||
self, |
||||
log_group: str, |
||||
log_stream_name: str, |
||||
start_time: int = 0, |
||||
skip: int = 0, |
||||
start_from_head: bool = True, |
||||
) -> Generator: |
||||
""" |
||||
A generator for log items in a single stream. This will yield all the |
||||
items that are available at the current moment. |
||||
|
||||
:param log_group: The name of the log group. |
||||
:type log_group: str |
||||
:param log_stream_name: The name of the specific stream. |
||||
:type log_stream_name: str |
||||
:param start_time: The time stamp value to start reading the logs from (default: 0). |
||||
:type start_time: int |
||||
:param skip: The number of log entries to skip at the start (default: 0). |
||||
This is for when there are multiple entries at the same timestamp. |
||||
:type skip: int |
||||
:param start_from_head: whether to start from the beginning (True) of the log or |
||||
at the end of the log (False). |
||||
:type start_from_head: bool |
||||
:rtype: dict |
||||
:return: | A CloudWatch log event with the following key-value pairs: |
||||
| 'timestamp' (int): The time in milliseconds of the event. |
||||
| 'message' (str): The log event data. |
||||
| 'ingestionTime' (int): The time in milliseconds the event was ingested. |
||||
""" |
||||
next_token = None |
||||
|
||||
event_count = 1 |
||||
while event_count > 0: |
||||
if next_token is not None: |
||||
token_arg: Optional[Dict[str, str]] = {"nextToken": next_token} |
||||
else: |
||||
token_arg = {} |
||||
|
||||
response = self.get_conn().get_log_events( |
||||
logGroupName=log_group, |
||||
logStreamName=log_stream_name, |
||||
startTime=start_time, |
||||
startFromHead=start_from_head, |
||||
**token_arg, |
||||
) |
||||
|
||||
events = response["events"] |
||||
event_count = len(events) |
||||
|
||||
if event_count > skip: |
||||
events = events[skip:] |
||||
skip = 0 |
||||
else: |
||||
skip -= event_count |
||||
events = [] |
||||
|
||||
yield from events |
||||
|
||||
if "nextForwardToken" in response: |
||||
next_token = response["nextForwardToken"] |
||||
else: |
||||
return |
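A short sketch of iterating a log stream with this hook; the log group and stream names are placeholders.

# Illustrative only: log group and stream names are placeholders.
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

hook = AwsLogsHook(aws_conn_id="aws_default")

# The generator yields every event currently available in the stream, oldest first.
for event in hook.get_log_events(
    log_group="/aws/lambda/example-function",
    log_stream_name="2021/01/01/[$LATEST]0123456789abcdef",
):
    print(event["timestamp"], event["message"])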
||||
@ -0,0 +1,131 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
"""Interact with AWS Redshift, using the boto3 library.""" |
||||
|
||||
from typing import List, Optional |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class RedshiftHook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS Redshift, using the boto3 library |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs) -> None: |
||||
kwargs["client_type"] = "redshift" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
# TODO: Wrap create_cluster_snapshot |
||||
def cluster_status(self, cluster_identifier: str) -> Optional[str]: |
||||
""" |
||||
Return status of a cluster |
||||
|
||||
:param cluster_identifier: unique identifier of a cluster |
||||
:type cluster_identifier: str |
||||
""" |
||||
try: |
||||
response = self.get_conn().describe_clusters( |
||||
ClusterIdentifier=cluster_identifier |
||||
)["Clusters"] |
||||
return response[0]["ClusterStatus"] if response else None |
||||
except self.get_conn().exceptions.ClusterNotFoundFault: |
||||
return "cluster_not_found" |
||||
|
||||
def delete_cluster( # pylint: disable=invalid-name |
||||
self, |
||||
cluster_identifier: str, |
||||
skip_final_cluster_snapshot: bool = True, |
||||
final_cluster_snapshot_identifier: Optional[str] = None, |
||||
): |
||||
""" |
||||
Delete a cluster and optionally create a snapshot |
||||
|
||||
:param cluster_identifier: unique identifier of a cluster |
||||
:type cluster_identifier: str |
||||
:param skip_final_cluster_snapshot: determines cluster snapshot creation |
||||
:type skip_final_cluster_snapshot: bool |
||||
:param final_cluster_snapshot_identifier: name of final cluster snapshot |
||||
:type final_cluster_snapshot_identifier: str |
||||
""" |
||||
final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or "" |
||||
|
||||
response = self.get_conn().delete_cluster( |
||||
ClusterIdentifier=cluster_identifier, |
||||
SkipFinalClusterSnapshot=skip_final_cluster_snapshot, |
||||
FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier, |
||||
) |
||||
return response["Cluster"] if response["Cluster"] else None |
||||
|
||||
def describe_cluster_snapshots( |
||||
self, cluster_identifier: str |
||||
) -> Optional[List[str]]: |
||||
""" |
||||
Gets a list of snapshots for a cluster |
||||
|
||||
:param cluster_identifier: unique identifier of a cluster |
||||
:type cluster_identifier: str |
||||
""" |
||||
response = self.get_conn().describe_cluster_snapshots( |
||||
ClusterIdentifier=cluster_identifier |
||||
) |
||||
if "Snapshots" not in response: |
||||
return None |
||||
snapshots = response["Snapshots"] |
||||
snapshots = [snapshot for snapshot in snapshots if snapshot["Status"]] |
||||
snapshots.sort(key=lambda x: x["SnapshotCreateTime"], reverse=True) |
||||
return snapshots |
||||
|
||||
def restore_from_cluster_snapshot( |
||||
self, cluster_identifier: str, snapshot_identifier: str |
||||
) -> str: |
||||
""" |
||||
Restores a cluster from its snapshot |
||||
|
||||
:param cluster_identifier: unique identifier of a cluster |
||||
:type cluster_identifier: str |
||||
:param snapshot_identifier: unique identifier for a snapshot of a cluster |
||||
:type snapshot_identifier: str |
||||
""" |
||||
response = self.get_conn().restore_from_cluster_snapshot( |
||||
ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier |
||||
) |
||||
return response["Cluster"] if response["Cluster"] else None |
||||
|
||||
def create_cluster_snapshot( |
||||
self, snapshot_identifier: str, cluster_identifier: str |
||||
) -> str: |
||||
""" |
||||
Creates a snapshot of a cluster |
||||
|
||||
:param snapshot_identifier: unique identifier for a snapshot of a cluster |
||||
:type snapshot_identifier: str |
||||
:param cluster_identifier: unique identifier of a cluster |
||||
:type cluster_identifier: str |
||||
""" |
||||
response = self.get_conn().create_cluster_snapshot( |
||||
SnapshotIdentifier=snapshot_identifier, |
||||
ClusterIdentifier=cluster_identifier, |
||||
) |
||||
return response["Snapshot"] if response["Snapshot"] else None |
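A hedged sketch of a snapshot-then-delete flow with this hook; cluster and snapshot identifiers are placeholders.

# Illustrative only: cluster and snapshot identifiers are placeholders.
from airflow.providers.amazon.aws.hooks.redshift import RedshiftHook

hook = RedshiftHook(aws_conn_id="aws_default")

if hook.cluster_status("example-cluster") == "available":
    hook.create_cluster_snapshot("example-snapshot", "example-cluster")
    hook.delete_cluster("example-cluster", skip_final_cluster_snapshot=True)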
||||
@ -0,0 +1,973 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
# pylint: disable=invalid-name |
||||
"""Interact with AWS S3, using the boto3 library.""" |
||||
import fnmatch |
||||
import gzip as gz |
||||
import io |
||||
import re |
||||
import shutil |
||||
from functools import wraps |
||||
from inspect import signature |
||||
from io import BytesIO |
||||
from tempfile import NamedTemporaryFile |
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast |
||||
from urllib.parse import urlparse |
||||
|
||||
from airflow.exceptions import AirflowException |
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
from airflow.utils.helpers import chunks |
||||
from boto3.s3.transfer import S3Transfer, TransferConfig |
||||
from botocore.exceptions import ClientError |
||||
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name |
||||
|
||||
|
||||
def provide_bucket_name(func: T) -> T: |
||||
""" |
||||
Function decorator that provides a bucket name taken from the connection |
||||
in case no bucket name has been passed to the function. |
||||
""" |
||||
function_signature = signature(func) |
||||
|
||||
@wraps(func) |
||||
def wrapper(*args, **kwargs) -> T: |
||||
bound_args = function_signature.bind(*args, **kwargs) |
||||
|
||||
if "bucket_name" not in bound_args.arguments: |
||||
self = args[0] |
||||
if self.aws_conn_id: |
||||
connection = self.get_connection(self.aws_conn_id) |
||||
if connection.schema: |
||||
bound_args.arguments["bucket_name"] = connection.schema |
||||
|
||||
return func(*bound_args.args, **bound_args.kwargs) |
||||
|
||||
return cast(T, wrapper) |
||||
|
||||
|
||||
def unify_bucket_name_and_key(func: T) -> T: |
||||
""" |
||||
Function decorator that unifies bucket name and key taken from the key |
||||
in case no bucket name and at least a key has been passed to the function. |
||||
""" |
||||
function_signature = signature(func) |
||||
|
||||
@wraps(func) |
||||
def wrapper(*args, **kwargs) -> T: |
||||
bound_args = function_signature.bind(*args, **kwargs) |
||||
|
||||
def get_key_name() -> Optional[str]: |
||||
if "wildcard_key" in bound_args.arguments: |
||||
return "wildcard_key" |
||||
if "key" in bound_args.arguments: |
||||
return "key" |
||||
raise ValueError("Missing key parameter!") |
||||
|
||||
key_name = get_key_name() |
||||
if key_name and "bucket_name" not in bound_args.arguments: |
||||
( |
||||
bound_args.arguments["bucket_name"], |
||||
bound_args.arguments[key_name], |
||||
) = S3Hook.parse_s3_url(bound_args.arguments[key_name]) |
||||
|
||||
return func(*bound_args.args, **bound_args.kwargs) |
||||
|
||||
return cast(T, wrapper) |
||||
|
||||
|
||||
class S3Hook(AwsBaseHook): |
||||
""" |
||||
Interact with AWS S3, using the boto3 library. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
conn_type = "s3" |
||||
hook_name = "S3" |
||||
|
||||
def __init__(self, *args, **kwargs) -> None: |
||||
kwargs["client_type"] = "s3" |
||||
|
||||
self.extra_args = {} |
||||
if "extra_args" in kwargs: |
||||
self.extra_args = kwargs["extra_args"] |
||||
if not isinstance(self.extra_args, dict): |
||||
raise ValueError( |
||||
f"extra_args '{self.extra_args!r}' must be of type {dict}" |
||||
) |
||||
del kwargs["extra_args"] |
||||
|
||||
self.transfer_config = TransferConfig() |
||||
if "transfer_config_args" in kwargs: |
||||
transport_config_args = kwargs["transfer_config_args"] |
||||
if not isinstance(transport_config_args, dict): |
||||
raise ValueError( |
||||
f"transfer_config_args '{transport_config_args!r} must be of type {dict}" |
||||
) |
||||
self.transfer_config = TransferConfig(**transport_config_args) |
||||
del kwargs["transfer_config_args"] |
||||
|
||||
super().__init__(*args, **kwargs) |
||||
|
||||
@staticmethod |
||||
def parse_s3_url(s3url: str) -> Tuple[str, str]: |
||||
""" |
||||
Parses the S3 Url into a bucket name and key. |
||||
|
||||
:param s3url: The S3 Url to parse. |
||||
:type s3url: str |
||||
:return: the parsed bucket name and key |
||||
:rtype: tuple of str |
||||
""" |
||||
parsed_url = urlparse(s3url) |
||||
|
||||
if not parsed_url.netloc: |
||||
raise AirflowException(f'Please provide a bucket_name instead of "{s3url}"') |
||||
|
||||
bucket_name = parsed_url.netloc |
||||
key = parsed_url.path.strip("/") |
||||
|
||||
return bucket_name, key |
||||
|
||||
@provide_bucket_name |
||||
def check_for_bucket(self, bucket_name: Optional[str] = None) -> bool: |
||||
""" |
||||
Check if bucket_name exists. |
||||
|
||||
:param bucket_name: the name of the bucket |
||||
:type bucket_name: str |
||||
:return: True if it exists and False if not. |
||||
:rtype: bool |
||||
""" |
||||
try: |
||||
self.get_conn().head_bucket(Bucket=bucket_name) |
||||
return True |
||||
except ClientError as e: |
||||
self.log.error(e.response["Error"]["Message"]) |
||||
return False |
||||
|
||||
@provide_bucket_name |
||||
def get_bucket(self, bucket_name: Optional[str] = None) -> str: |
||||
""" |
||||
Returns a boto3.S3.Bucket object |
||||
|
||||
:param bucket_name: the name of the bucket |
||||
:type bucket_name: str |
||||
:return: the bucket object to the bucket name. |
||||
:rtype: boto3.S3.Bucket |
||||
""" |
||||
s3_resource = self.get_resource_type("s3") |
||||
return s3_resource.Bucket(bucket_name) |
||||
|
||||
@provide_bucket_name |
||||
def create_bucket( |
||||
self, bucket_name: Optional[str] = None, region_name: Optional[str] = None |
||||
) -> None: |
||||
""" |
||||
Creates an Amazon S3 bucket. |
||||
|
||||
:param bucket_name: The name of the bucket |
||||
:type bucket_name: str |
||||
:param region_name: The name of the aws region in which to create the bucket. |
||||
:type region_name: str |
||||
""" |
||||
if not region_name: |
||||
region_name = self.get_conn().meta.region_name |
||||
if region_name == "us-east-1": |
||||
self.get_conn().create_bucket(Bucket=bucket_name) |
||||
else: |
||||
self.get_conn().create_bucket( |
||||
Bucket=bucket_name, |
||||
CreateBucketConfiguration={"LocationConstraint": region_name}, |
||||
) |
||||
|
||||
@provide_bucket_name |
||||
def check_for_prefix( |
||||
self, prefix: str, delimiter: str, bucket_name: Optional[str] = None |
||||
) -> bool: |
||||
""" |
||||
Checks that a prefix exists in a bucket |
||||
|
||||
:param bucket_name: the name of the bucket |
||||
:type bucket_name: str |
||||
:param prefix: a key prefix |
||||
:type prefix: str |
||||
:param delimiter: the delimiter marks key hierarchy. |
||||
:type delimiter: str |
||||
:return: False if the prefix does not exist in the bucket and True if it does. |
||||
:rtype: bool |
||||
""" |
||||
prefix = prefix + delimiter if prefix[-1] != delimiter else prefix |
||||
prefix_split = re.split(fr"(\w+[{delimiter}])$", prefix, 1) |
||||
previous_level = prefix_split[0] |
||||
plist = self.list_prefixes(bucket_name, previous_level, delimiter) |
||||
return prefix in plist |
||||
|
||||
@provide_bucket_name |
||||
def list_prefixes( |
||||
self, |
||||
bucket_name: Optional[str] = None, |
||||
prefix: Optional[str] = None, |
||||
delimiter: Optional[str] = None, |
||||
page_size: Optional[int] = None, |
||||
max_items: Optional[int] = None, |
||||
) -> list: |
||||
""" |
||||
Lists prefixes in a bucket under prefix |
||||
|
||||
:param bucket_name: the name of the bucket |
||||
:type bucket_name: str |
||||
:param prefix: a key prefix |
||||
:type prefix: str |
||||
:param delimiter: the delimiter marks key hierarchy. |
||||
:type delimiter: str |
||||
:param page_size: pagination size |
||||
:type page_size: int |
||||
:param max_items: maximum items to return |
||||
:type max_items: int |
||||
:return: a list of matched prefixes |
||||
:rtype: list |
||||
""" |
||||
prefix = prefix or "" |
||||
delimiter = delimiter or "" |
||||
config = { |
||||
"PageSize": page_size, |
||||
"MaxItems": max_items, |
||||
} |
||||
|
||||
paginator = self.get_conn().get_paginator("list_objects_v2") |
||||
response = paginator.paginate( |
||||
Bucket=bucket_name, |
||||
Prefix=prefix, |
||||
Delimiter=delimiter, |
||||
PaginationConfig=config, |
||||
) |
||||
|
||||
prefixes = [] |
||||
for page in response: |
||||
if "CommonPrefixes" in page: |
||||
for common_prefix in page["CommonPrefixes"]: |
||||
prefixes.append(common_prefix["Prefix"]) |
||||
|
||||
return prefixes |
||||
|
||||
@provide_bucket_name |
||||
def list_keys( |
||||
self, |
||||
bucket_name: Optional[str] = None, |
||||
prefix: Optional[str] = None, |
||||
delimiter: Optional[str] = None, |
||||
page_size: Optional[int] = None, |
||||
max_items: Optional[int] = None, |
||||
) -> list: |
||||
""" |
||||
Lists keys in a bucket under prefix and not containing delimiter |
||||
|
||||
:param bucket_name: the name of the bucket |
||||
:type bucket_name: str |
||||
:param prefix: a key prefix |
||||
:type prefix: str |
||||
:param delimiter: the delimiter marks key hierarchy. |
||||
:type delimiter: str |
||||
:param page_size: pagination size |
||||
:type page_size: int |
||||
:param max_items: maximum items to return |
||||
:type max_items: int |
||||
:return: a list of matched keys |
||||
:rtype: list |
||||
""" |
||||
prefix = prefix or "" |
||||
delimiter = delimiter or "" |
||||
config = { |
||||
"PageSize": page_size, |
||||
"MaxItems": max_items, |
||||
} |
||||
|
||||
paginator = self.get_conn().get_paginator("list_objects_v2") |
||||
response = paginator.paginate( |
||||
Bucket=bucket_name, |
||||
Prefix=prefix, |
||||
Delimiter=delimiter, |
||||
PaginationConfig=config, |
||||
) |
||||
|
||||
keys = [] |
||||
for page in response: |
||||
if "Contents" in page: |
||||
for k in page["Contents"]: |
||||
keys.append(k["Key"]) |
||||
|
||||
return keys |
||||
|
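As a quick illustration of the two listing helpers above, the sketch below lists "folders" (CommonPrefixes) and object keys under a prefix. It assumes the enclosing class is the provider's S3Hook and that the connection id "aws_default" and bucket "example-bucket" exist; all names are placeholders:

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")  # placeholder connection id
    # Sub-"folders" directly under data/ come back as CommonPrefixes
    folders = hook.list_prefixes("example-bucket", prefix="data/", delimiter="/")
    # All object keys under data/
    keys = hook.list_keys("example-bucket", prefix="data/")
    print(folders, keys)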
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def check_for_key(self, key: str, bucket_name: Optional[str] = None) -> bool: |
||||
""" |
||||
Checks if a key exists in a bucket |
||||
|
||||
:param key: S3 key that will point to the file |
||||
:type key: str |
||||
:param bucket_name: Name of the bucket in which the file is stored |
||||
:type bucket_name: str |
||||
:return: True if the key exists and False if not. |
||||
:rtype: bool |
||||
""" |
||||
try: |
||||
self.get_conn().head_object(Bucket=bucket_name, Key=key) |
||||
return True |
||||
except ClientError as e: |
||||
if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: |
||||
return False |
||||
else: |
||||
raise e |
||||
|
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def get_key(self, key: str, bucket_name: Optional[str] = None) -> S3Transfer: |
||||
""" |
||||
Returns a boto3.s3.Object |
||||
|
||||
:param key: the path to the key |
||||
:type key: str |
||||
:param bucket_name: the name of the bucket |
||||
:type bucket_name: str |
||||
:return: the key object from the bucket |
||||
:rtype: boto3.s3.Object |
||||
""" |
||||
obj = self.get_resource_type("s3").Object(bucket_name, key) |
||||
obj.load() |
||||
return obj |
||||
|
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def read_key(self, key: str, bucket_name: Optional[str] = None) -> str: |
||||
""" |
||||
Reads a key from S3 |
||||
|
||||
:param key: S3 key that will point to the file |
||||
:type key: str |
||||
:param bucket_name: Name of the bucket in which the file is stored |
||||
:type bucket_name: str |
||||
:return: the content of the key |
||||
:rtype: str |
||||
""" |
||||
obj = self.get_key(key, bucket_name) |
||||
return obj.get()["Body"].read().decode("utf-8") |
||||
|
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def select_key( |
||||
self, |
||||
key: str, |
||||
bucket_name: Optional[str] = None, |
||||
expression: Optional[str] = None, |
||||
expression_type: Optional[str] = None, |
||||
input_serialization: Optional[Dict[str, Any]] = None, |
||||
output_serialization: Optional[Dict[str, Any]] = None, |
||||
) -> str: |
||||
""" |
||||
Reads a key with S3 Select. |
||||
|
||||
:param key: S3 key that will point to the file |
||||
:type key: str |
||||
:param bucket_name: Name of the bucket in which the file is stored |
||||
:type bucket_name: str |
||||
:param expression: S3 Select expression |
||||
:type expression: str |
||||
:param expression_type: S3 Select expression type |
||||
:type expression_type: str |
||||
:param input_serialization: S3 Select input data serialization format |
||||
:type input_serialization: dict |
||||
:param output_serialization: S3 Select output data serialization format |
||||
:type output_serialization: dict |
||||
:return: retrieved subset of original data by S3 Select |
||||
:rtype: str |
||||
|
||||
.. seealso:: |
||||
For more details about S3 Select parameters: |
||||
http://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.select_object_content |
||||
""" |
||||
expression = expression or "SELECT * FROM S3Object" |
||||
expression_type = expression_type or "SQL" |
||||
|
||||
if input_serialization is None: |
||||
input_serialization = {"CSV": {}} |
||||
if output_serialization is None: |
||||
output_serialization = {"CSV": {}} |
||||
|
||||
response = self.get_conn().select_object_content( |
||||
Bucket=bucket_name, |
||||
Key=key, |
||||
Expression=expression, |
||||
ExpressionType=expression_type, |
||||
InputSerialization=input_serialization, |
||||
OutputSerialization=output_serialization, |
||||
) |
||||
|
||||
return "".join( |
||||
event["Records"]["Payload"].decode("utf-8") |
||||
for event in response["Payload"] |
||||
if "Records" in event |
||||
) |
||||
|
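For context, a minimal S3 Select call through the method above might look like the sketch below; the bucket, key, and connection id are assumptions, and the default CSV serialization from the method is used:

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")  # placeholder connection id
    subset = hook.select_key(
        key="data/prices.csv",         # placeholder key
        bucket_name="example-bucket",  # placeholder bucket
        expression="SELECT * FROM S3Object s LIMIT 10",
    )
    print(subset)  # concatenated Records payloads, returned as a str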
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def check_for_wildcard_key( |
||||
self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = "" |
||||
) -> bool: |
||||
""" |
||||
Checks that a key matching a wildcard expression exists in a bucket |
||||
|
||||
:param wildcard_key: the path to the key |
||||
:type wildcard_key: str |
||||
:param bucket_name: the name of the bucket |
||||
:type bucket_name: str |
||||
:param delimiter: the delimiter marks key hierarchy |
||||
:type delimiter: str |
||||
:return: True if a key exists and False if not. |
||||
:rtype: bool |
||||
""" |
||||
return ( |
||||
self.get_wildcard_key( |
||||
wildcard_key=wildcard_key, bucket_name=bucket_name, delimiter=delimiter |
||||
) |
||||
is not None |
||||
) |
||||
|
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def get_wildcard_key( |
||||
self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = "" |
||||
) -> S3Transfer: |
||||
""" |
||||
Returns a boto3.s3.Object object matching the wildcard expression |
||||
|
||||
:param wildcard_key: the path to the key |
||||
:type wildcard_key: str |
||||
:param bucket_name: the name of the bucket |
||||
:type bucket_name: str |
||||
:param delimiter: the delimiter marks key hierarchy |
||||
:type delimiter: str |
||||
:return: the key object from the bucket or None if none has been found. |
||||
:rtype: boto3.s3.Object |
||||
""" |
||||
prefix = re.split(r"[*]", wildcard_key, 1)[0] |
||||
key_list = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter) |
||||
key_matches = [k for k in key_list if fnmatch.fnmatch(k, wildcard_key)] |
||||
if key_matches: |
||||
return self.get_key(key_matches[0], bucket_name) |
||||
return None |
||||
|
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def load_file( |
||||
self, |
||||
filename: str, |
||||
key: str, |
||||
bucket_name: Optional[str] = None, |
||||
replace: bool = False, |
||||
encrypt: bool = False, |
||||
gzip: bool = False, |
||||
acl_policy: Optional[str] = None, |
||||
) -> None: |
||||
""" |
||||
Loads a local file to S3 |
||||
|
||||
:param filename: name of the file to load. |
||||
:type filename: str |
||||
:param key: S3 key that will point to the file |
||||
:type key: str |
||||
:param bucket_name: Name of the bucket in which to store the file |
||||
:type bucket_name: str |
||||
:param replace: A flag to decide whether or not to overwrite the key |
||||
if it already exists. If replace is False and the key exists, an |
||||
error will be raised. |
||||
:type replace: bool |
||||
:param encrypt: If True, the file will be encrypted on the server-side |
||||
by S3 and will be stored in an encrypted form while at rest in S3. |
||||
:type encrypt: bool |
||||
:param gzip: If True, the file will be compressed locally |
||||
:type gzip: bool |
||||
:param acl_policy: String specifying the canned ACL policy for the file being |
||||
uploaded to the S3 bucket. |
||||
:type acl_policy: str |
||||
""" |
||||
if not replace and self.check_for_key(key, bucket_name): |
||||
raise ValueError(f"The key {key} already exists.") |
||||
|
||||
extra_args = self.extra_args |
||||
if encrypt: |
||||
extra_args["ServerSideEncryption"] = "AES256" |
||||
if gzip: |
||||
with open(filename, "rb") as f_in: |
||||
filename_gz = f_in.name + ".gz" |
||||
with gz.open(filename_gz, "wb") as f_out: |
||||
shutil.copyfileobj(f_in, f_out) |
||||
filename = filename_gz |
||||
if acl_policy: |
||||
extra_args["ACL"] = acl_policy |
||||
|
||||
client = self.get_conn() |
||||
client.upload_file( |
||||
filename, |
||||
bucket_name, |
||||
key, |
||||
ExtraArgs=extra_args, |
||||
Config=self.transfer_config, |
||||
) |
||||
|
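A hedged usage sketch for the upload path above: the local path, key, bucket, and connection id are placeholders. With gzip=True the file is compressed locally before upload, as implemented in the method:

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")  # placeholder connection id
    hook.load_file(
        filename="/tmp/report.csv",           # placeholder local path
        key="reports/report.csv.gz",
        bucket_name="example-bucket",         # placeholder bucket
        replace=True,                         # overwrite an existing key
        gzip=True,                            # compress locally before upload
    )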
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def load_string( |
||||
self, |
||||
string_data: str, |
||||
key: str, |
||||
bucket_name: Optional[str] = None, |
||||
replace: bool = False, |
||||
encrypt: bool = False, |
||||
encoding: Optional[str] = None, |
||||
acl_policy: Optional[str] = None, |
||||
compression: Optional[str] = None, |
||||
) -> None: |
||||
""" |
||||
Loads a string to S3 |
||||
|
||||
This is provided as a convenience to drop a string in S3. It uses the |
||||
boto infrastructure to ship a file to s3. |
||||
|
||||
:param string_data: str to set as content for the key. |
||||
:type string_data: str |
||||
:param key: S3 key that will point to the file |
||||
:type key: str |
||||
:param bucket_name: Name of the bucket in which to store the file |
||||
:type bucket_name: str |
||||
:param replace: A flag to decide whether or not to overwrite the key |
||||
if it already exists |
||||
:type replace: bool |
||||
:param encrypt: If True, the file will be encrypted on the server-side |
||||
by S3 and will be stored in an encrypted form while at rest in S3. |
||||
:type encrypt: bool |
||||
:param encoding: The string to byte encoding |
||||
:type encoding: str |
||||
:param acl_policy: The string to specify the canned ACL policy for the |
||||
object to be uploaded |
||||
:type acl_policy: str |
||||
:param compression: Type of compression to use, currently only gzip is supported. |
||||
:type compression: str |
||||
""" |
||||
encoding = encoding or "utf-8" |
||||
|
||||
bytes_data = string_data.encode(encoding) |
||||
|
||||
# Compress string |
||||
available_compressions = ["gzip"] |
||||
if compression is not None and compression not in available_compressions: |
||||
raise NotImplementedError( |
||||
"Received {} compression type. String " |
||||
"can currently be compressed in {} " |
||||
"only.".format(compression, available_compressions) |
||||
) |
||||
if compression == "gzip": |
||||
bytes_data = gz.compress(bytes_data) |
||||
|
||||
file_obj = io.BytesIO(bytes_data) |
||||
|
||||
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy) |
||||
file_obj.close() |
||||
|
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def load_bytes( |
||||
self, |
||||
bytes_data: bytes, |
||||
key: str, |
||||
bucket_name: Optional[str] = None, |
||||
replace: bool = False, |
||||
encrypt: bool = False, |
||||
acl_policy: Optional[str] = None, |
||||
) -> None: |
||||
""" |
||||
Loads bytes to S3 |
||||
|
||||
This is provided as a convenience to drop a string in S3. It uses the |
||||
boto infrastructure to ship a file to s3. |
||||
|
||||
:param bytes_data: bytes to set as content for the key. |
||||
:type bytes_data: bytes |
||||
:param key: S3 key that will point to the file |
||||
:type key: str |
||||
:param bucket_name: Name of the bucket in which to store the file |
||||
:type bucket_name: str |
||||
:param replace: A flag to decide whether or not to overwrite the key |
||||
if it already exists |
||||
:type replace: bool |
||||
:param encrypt: If True, the file will be encrypted on the server-side |
||||
by S3 and will be stored in an encrypted form while at rest in S3. |
||||
:type encrypt: bool |
||||
:param acl_policy: The string to specify the canned ACL policy for the |
||||
object to be uploaded |
||||
:type acl_policy: str |
||||
""" |
||||
file_obj = io.BytesIO(bytes_data) |
||||
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy) |
||||
file_obj.close() |
||||
|
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def load_file_obj( |
||||
self, |
||||
file_obj: BytesIO, |
||||
key: str, |
||||
bucket_name: Optional[str] = None, |
||||
replace: bool = False, |
||||
encrypt: bool = False, |
||||
acl_policy: Optional[str] = None, |
||||
) -> None: |
||||
""" |
||||
Loads a file object to S3 |
||||
|
||||
:param file_obj: The file-like object to set as the content for the S3 key. |
||||
:type file_obj: file-like object |
||||
:param key: S3 key that will point to the file |
||||
:type key: str |
||||
:param bucket_name: Name of the bucket in which to store the file |
||||
:type bucket_name: str |
||||
:param replace: A flag that indicates whether to overwrite the key |
||||
if it already exists. |
||||
:type replace: bool |
||||
:param encrypt: If True, S3 encrypts the file on the server, |
||||
and the file is stored in encrypted form at rest in S3. |
||||
:type encrypt: bool |
||||
:param acl_policy: The string to specify the canned ACL policy for the |
||||
object to be uploaded |
||||
:type acl_policy: str |
||||
""" |
||||
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy) |
||||
|
||||
def _upload_file_obj( |
||||
self, |
||||
file_obj: BytesIO, |
||||
key: str, |
||||
bucket_name: Optional[str] = None, |
||||
replace: bool = False, |
||||
encrypt: bool = False, |
||||
acl_policy: Optional[str] = None, |
||||
) -> None: |
||||
if not replace and self.check_for_key(key, bucket_name): |
||||
raise ValueError(f"The key {key} already exists.") |
||||
|
||||
extra_args = self.extra_args |
||||
if encrypt: |
||||
extra_args["ServerSideEncryption"] = "AES256" |
||||
if acl_policy: |
||||
extra_args["ACL"] = acl_policy |
||||
|
||||
client = self.get_conn() |
||||
client.upload_fileobj( |
||||
file_obj, |
||||
bucket_name, |
||||
key, |
||||
ExtraArgs=extra_args, |
||||
Config=self.transfer_config, |
||||
) |
||||
|
||||
def copy_object( |
||||
self, |
||||
source_bucket_key: str, |
||||
dest_bucket_key: str, |
||||
source_bucket_name: Optional[str] = None, |
||||
dest_bucket_name: Optional[str] = None, |
||||
source_version_id: Optional[str] = None, |
||||
acl_policy: Optional[str] = None, |
||||
) -> None: |
||||
""" |
||||
Creates a copy of an object that is already stored in S3. |
||||
|
||||
Note: the S3 connection used here needs to have access to both |
||||
source and destination bucket/key. |
||||
|
||||
:param source_bucket_key: The key of the source object. |
||||
|
||||
It can be either full s3:// style url or relative path from root level. |
||||
|
||||
When it's specified as a full s3:// url, please omit source_bucket_name. |
||||
:type source_bucket_key: str |
||||
:param dest_bucket_key: The key of the object to copy to. |
||||
|
||||
The convention to specify `dest_bucket_key` is the same |
||||
as `source_bucket_key`. |
||||
:type dest_bucket_key: str |
||||
:param source_bucket_name: Name of the S3 bucket where the source object is in. |
||||
|
||||
It should be omitted when `source_bucket_key` is provided as a full s3:// url. |
||||
:type source_bucket_name: str |
||||
:param dest_bucket_name: Name of the S3 bucket to where the object is copied. |
||||
|
||||
It should be omitted when `dest_bucket_key` is provided as a full s3:// url. |
||||
:type dest_bucket_name: str |
||||
:param source_version_id: Version ID of the source object (OPTIONAL) |
||||
:type source_version_id: str |
||||
:param acl_policy: The string to specify the canned ACL policy for the |
||||
object to be copied which is private by default. |
||||
:type acl_policy: str |
||||
""" |
||||
acl_policy = acl_policy or "private" |
||||
|
||||
if dest_bucket_name is None: |
||||
dest_bucket_name, dest_bucket_key = self.parse_s3_url(dest_bucket_key) |
||||
else: |
||||
parsed_url = urlparse(dest_bucket_key) |
||||
if parsed_url.scheme != "" or parsed_url.netloc != "": |
||||
raise AirflowException( |
||||
"If dest_bucket_name is provided, " |
||||
+ "dest_bucket_key should be relative path " |
||||
+ "from root level, rather than a full s3:// url" |
||||
) |
||||
|
||||
if source_bucket_name is None: |
||||
source_bucket_name, source_bucket_key = self.parse_s3_url(source_bucket_key) |
||||
else: |
||||
parsed_url = urlparse(source_bucket_key) |
||||
if parsed_url.scheme != "" or parsed_url.netloc != "": |
||||
raise AirflowException( |
||||
"If source_bucket_name is provided, " |
||||
+ "source_bucket_key should be relative path " |
||||
+ "from root level, rather than a full s3:// url" |
||||
) |
||||
|
||||
copy_source = { |
||||
"Bucket": source_bucket_name, |
||||
"Key": source_bucket_key, |
||||
"VersionId": source_version_id, |
||||
} |
||||
response = self.get_conn().copy_object( |
||||
Bucket=dest_bucket_name, |
||||
Key=dest_bucket_key, |
||||
CopySource=copy_source, |
||||
ACL=acl_policy, |
||||
) |
||||
return response |
||||
|
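copy_object accepts either relative keys plus explicit bucket names, or full s3:// URLs with the bucket names omitted; the two calls in the sketch below are equivalent, with placeholder names throughout:

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")  # placeholder connection id
    # Relative keys with explicit bucket names
    hook.copy_object(
        source_bucket_key="raw/data.json",
        dest_bucket_key="archive/data.json",
        source_bucket_name="example-source",
        dest_bucket_name="example-dest",
    )
    # Full s3:// URLs, bucket names omitted
    hook.copy_object(
        source_bucket_key="s3://example-source/raw/data.json",
        dest_bucket_key="s3://example-dest/archive/data.json",
    )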
||||
@provide_bucket_name |
||||
def delete_bucket(self, bucket_name: str, force_delete: bool = False) -> None: |
||||
""" |
||||
To delete an S3 bucket, delete all of the bucket's objects and then delete the bucket itself. |
||||
|
||||
:param bucket_name: Bucket name |
||||
:type bucket_name: str |
||||
:param force_delete: Enable this to delete bucket even if not empty |
||||
:type force_delete: bool |
||||
:return: None |
||||
:rtype: None |
||||
""" |
||||
if force_delete: |
||||
bucket_keys = self.list_keys(bucket_name=bucket_name) |
||||
if bucket_keys: |
||||
self.delete_objects(bucket=bucket_name, keys=bucket_keys) |
||||
self.conn.delete_bucket(Bucket=bucket_name) |
||||
|
||||
def delete_objects(self, bucket: str, keys: Union[str, list]) -> None: |
||||
""" |
||||
Delete keys from the bucket. |
||||
|
||||
:param bucket: Name of the bucket in which you are going to delete object(s) |
||||
:type bucket: str |
||||
:param keys: The key(s) to delete from S3 bucket. |
||||
|
||||
When ``keys`` is a string, it's supposed to be the key name of |
||||
the single object to delete. |
||||
|
||||
When ``keys`` is a list, it's supposed to be the list of the |
||||
keys to delete. |
||||
:type keys: str or list |
||||
""" |
||||
if isinstance(keys, str): |
||||
keys = [keys] |
||||
|
||||
s3 = self.get_conn() |
||||
|
||||
# We can only send a maximum of 1000 keys per request. |
||||
# For details see: |
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects |
||||
for chunk in chunks(keys, chunk_size=1000): |
||||
response = s3.delete_objects( |
||||
Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]} |
||||
) |
||||
deleted_keys = [x["Key"] for x in response.get("Deleted", [])] |
||||
self.log.info("Deleted: %s", deleted_keys) |
||||
if "Errors" in response: |
||||
errors_keys = [x["Key"] for x in response.get("Errors", [])] |
||||
raise AirflowException(f"Errors when deleting: {errors_keys}") |
||||
|
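A short sketch of the deletion helpers: keys may be a single string or a list (requests are chunked at 1000 keys per call, as noted above), and delete_bucket(force_delete=True) empties the bucket before removing it. All names are placeholders:

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")  # placeholder connection id
    hook.delete_objects(bucket="example-bucket", keys="tmp/scratch.txt")
    hook.delete_objects(bucket="example-bucket", keys=["a.txt", "b.txt"])
    hook.delete_bucket(bucket_name="example-bucket", force_delete=True)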
||||
@provide_bucket_name |
||||
@unify_bucket_name_and_key |
||||
def download_file( |
||||
self, |
||||
key: str, |
||||
bucket_name: Optional[str] = None, |
||||
local_path: Optional[str] = None, |
||||
) -> str: |
||||
""" |
||||
Downloads a file from the S3 location to the local file system. |
||||
|
||||
:param key: The key path in S3. |
||||
:type key: str |
||||
:param bucket_name: The specific bucket to use. |
||||
:type bucket_name: Optional[str] |
||||
:param local_path: The local path to the downloaded file. If no path is provided it will use the |
||||
system's temporary directory. |
||||
:type local_path: Optional[str] |
||||
:return: the file name. |
||||
:rtype: str |
||||
""" |
||||
self.log.info( |
||||
"Downloading source S3 file from Bucket %s with path %s", bucket_name, key |
||||
) |
||||
|
||||
if not self.check_for_key(key, bucket_name): |
||||
raise AirflowException( |
||||
f"The source file in Bucket {bucket_name} with path {key} does not exist" |
||||
) |
||||
|
||||
s3_obj = self.get_key(key, bucket_name) |
||||
|
||||
with NamedTemporaryFile( |
||||
dir=local_path, prefix="airflow_tmp_", delete=False |
||||
) as local_tmp_file: |
||||
s3_obj.download_fileobj(local_tmp_file) |
||||
|
||||
return local_tmp_file.name |
||||
|
||||
def generate_presigned_url( |
||||
self, |
||||
client_method: str, |
||||
params: Optional[dict] = None, |
||||
expires_in: int = 3600, |
||||
http_method: Optional[str] = None, |
||||
) -> Optional[str]: |
||||
""" |
||||
Generate a presigned url given a client, its method, and arguments |
||||
|
||||
:param client_method: The client method to presign for. |
||||
:type client_method: str |
||||
:param params: The parameters normally passed to ClientMethod. |
||||
:type params: dict |
||||
:param expires_in: The number of seconds the presigned url is valid for. |
||||
By default it expires in an hour (3600 seconds). |
||||
:type expires_in: int |
||||
:param http_method: The http method to use on the generated url. |
||||
By default, the http method is whatever is used in the method's model. |
||||
:type http_method: str |
||||
:return: The presigned url. |
||||
:rtype: str |
||||
""" |
||||
s3_client = self.get_conn() |
||||
try: |
||||
return s3_client.generate_presigned_url( |
||||
ClientMethod=client_method, |
||||
Params=params, |
||||
ExpiresIn=expires_in, |
||||
HttpMethod=http_method, |
||||
) |
||||
|
||||
except ClientError as e: |
||||
self.log.error(e.response["Error"]["Message"]) |
||||
return None |
||||
|
||||
@provide_bucket_name |
||||
def get_bucket_tagging( |
||||
self, bucket_name: Optional[str] = None |
||||
) -> Optional[List[Dict[str, str]]]: |
||||
""" |
||||
Gets a List of tags from a bucket. |
||||
|
||||
:param bucket_name: The name of the bucket. |
||||
:type bucket_name: str |
||||
:return: A List containing the key/value pairs for the tags |
||||
:rtype: Optional[List[Dict[str, str]]] |
||||
""" |
||||
try: |
||||
s3_client = self.get_conn() |
||||
result = s3_client.get_bucket_tagging(Bucket=bucket_name)["TagSet"] |
||||
self.log.info("S3 Bucket Tag Info: %s", result) |
||||
return result |
||||
except ClientError as e: |
||||
self.log.error(e) |
||||
raise e |
||||
|
||||
@provide_bucket_name |
||||
def put_bucket_tagging( |
||||
self, |
||||
tag_set: Optional[List[Dict[str, str]]] = None, |
||||
key: Optional[str] = None, |
||||
value: Optional[str] = None, |
||||
bucket_name: Optional[str] = None, |
||||
) -> None: |
||||
""" |
||||
Overwrites the existing TagSet with provided tags. Must provide either a TagSet or a key/value pair. |
||||
|
||||
:param tag_set: A List containing the key/value pairs for the tags. |
||||
:type tag_set: List[Dict[str, str]] |
||||
:param key: The Key for the new TagSet entry. |
||||
:type key: str |
||||
:param value: The Value for the new TagSet entry. |
||||
:type value: str |
||||
:param bucket_name: The name of the bucket. |
||||
:type bucket_name: str |
||||
:return: None |
||||
:rtype: None |
||||
""" |
||||
self.log.info( |
||||
"S3 Bucket Tag Info:\tKey: %s\tValue: %s\tSet: %s", key, value, tag_set |
||||
) |
||||
if not tag_set: |
||||
tag_set = [] |
||||
if key and value: |
||||
tag_set.append({"Key": key, "Value": value}) |
||||
elif not tag_set or (key or value): |
||||
message = "put_bucket_tagging() requires either a predefined TagSet or a key/value pair." |
||||
self.log.error(message) |
||||
raise ValueError(message) |
||||
|
||||
try: |
||||
s3_client = self.get_conn() |
||||
s3_client.put_bucket_tagging( |
||||
Bucket=bucket_name, Tagging={"TagSet": tag_set} |
||||
) |
||||
except ClientError as e: |
||||
self.log.error(e) |
||||
raise e |
||||
|
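The tagging helpers can be driven either with a single key/value pair or a full TagSet, as sketched below with placeholder names; note that put_bucket_tagging overwrites any existing TagSet on the bucket:

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")  # placeholder connection id
    # Single key/value pair, wrapped into a one-entry TagSet
    hook.put_bucket_tagging(key="team", value="data-eng", bucket_name="example-bucket")
    # Or a predefined TagSet
    hook.put_bucket_tagging(
        tag_set=[{"Key": "env", "Value": "dev"}], bucket_name="example-bucket"
    )
    print(hook.get_bucket_tagging(bucket_name="example-bucket"))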
||||
@provide_bucket_name |
||||
def delete_bucket_tagging(self, bucket_name: Optional[str] = None) -> None: |
||||
""" |
||||
Deletes all tags from a bucket. |
||||
|
||||
:param bucket_name: The name of the bucket. |
||||
:type bucket_name: str |
||||
:return: None |
||||
:rtype: None |
||||
""" |
||||
s3_client = self.get_conn() |
||||
s3_client.delete_bucket_tagging(Bucket=bucket_name) |
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,71 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
# |
||||
|
||||
import base64 |
||||
import json |
||||
from typing import Union |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class SecretsManagerHook(AwsBaseHook): |
||||
""" |
||||
Interact with Amazon SecretsManager Service. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs): |
||||
super().__init__(client_type="secretsmanager", *args, **kwargs) |
||||
|
||||
def get_secret(self, secret_name: str) -> Union[str, bytes]: |
||||
""" |
||||
Retrieve a secret value from AWS Secrets Manager as a str or bytes, |
||||
reflecting the format in which it is stored in AWS Secrets Manager |
||||
|
||||
:param secret_name: name of the secret. |
||||
:type secret_name: str |
||||
:return: the secret value, as str or bytes |
||||
:rtype: Union[str, bytes] |
||||
""" |
||||
# Depending on whether the secret is a string or binary, one of |
||||
# these fields will be populated. |
||||
get_secret_value_response = self.get_conn().get_secret_value( |
||||
SecretId=secret_name |
||||
) |
||||
if "SecretString" in get_secret_value_response: |
||||
secret = get_secret_value_response["SecretString"] |
||||
else: |
||||
secret = base64.b64decode(get_secret_value_response["SecretBinary"]) |
||||
return secret |
||||
|
||||
def get_secret_as_dict(self, secret_name: str) -> dict: |
||||
""" |
||||
Retrieve secret value from AWS Secrets Manager in a dict representation |
||||
|
||||
:param secret_name: name of the secret. |
||||
:type secret_name: str |
||||
:return: the secret value parsed as a dict |
||||
:rtype: dict |
||||
""" |
||||
return json.loads(self.get_secret(secret_name)) |
||||
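A minimal usage sketch for the hook above, assuming the standard provider module path; the connection id and secret name are placeholders, and get_secret_as_dict assumes the secret body is a JSON string:

    from airflow.providers.amazon.aws.hooks.secrets_manager import SecretsManagerHook

    hook = SecretsManagerHook(aws_conn_id="aws_default")    # placeholder connection id
    raw = hook.get_secret("prod/db/credentials")            # str or bytes, as stored
    creds = hook.get_secret_as_dict("prod/db/credentials")  # JSON-formatted secrets only
    print(type(raw), list(creds))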
@ -0,0 +1,99 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
"""This module contains AWS SES Hook""" |
||||
from typing import Any, Dict, Iterable, List, Optional, Union |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
from airflow.utils.email import build_mime_message |
||||
|
||||
|
||||
class SESHook(AwsBaseHook): |
||||
""" |
||||
Interact with Amazon Simple Email Service. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs) -> None: |
||||
kwargs["client_type"] = "ses" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def send_email( # pylint: disable=too-many-arguments |
||||
self, |
||||
mail_from: str, |
||||
to: Union[str, Iterable[str]], |
||||
subject: str, |
||||
html_content: str, |
||||
files: Optional[List[str]] = None, |
||||
cc: Optional[Union[str, Iterable[str]]] = None, |
||||
bcc: Optional[Union[str, Iterable[str]]] = None, |
||||
mime_subtype: str = "mixed", |
||||
mime_charset: str = "utf-8", |
||||
reply_to: Optional[str] = None, |
||||
return_path: Optional[str] = None, |
||||
custom_headers: Optional[Dict[str, Any]] = None, |
||||
) -> dict: |
||||
""" |
||||
Send email using Amazon Simple Email Service |
||||
|
||||
:param mail_from: Email address to set as email's from |
||||
:param to: List of email addresses to set as email's to |
||||
:param subject: Email's subject |
||||
:param html_content: Content of email in HTML format |
||||
:param files: List of paths of files to be attached |
||||
:param cc: List of email addresses to set as email's CC |
||||
:param bcc: List of email addresses to set as email's BCC |
||||
:param mime_subtype: Can be used to specify the sub-type of the message. Default = mixed |
||||
:param mime_charset: Email's charset. Default = UTF-8. |
||||
:param reply_to: The email address to which replies will be sent. By default, replies |
||||
are sent to the original sender's email address. |
||||
:param return_path: The email address to which message bounces and complaints should be sent. |
||||
"Return-Path" is sometimes called "envelope from," "envelope sender," or "MAIL FROM." |
||||
:param custom_headers: Additional headers to add to the MIME message. |
||||
No validations are run on these values and they should be able to be encoded. |
||||
:return: Response from Amazon SES service with unique message identifier. |
||||
""" |
||||
ses_client = self.get_conn() |
||||
|
||||
custom_headers = custom_headers or {} |
||||
if reply_to: |
||||
custom_headers["Reply-To"] = reply_to |
||||
if return_path: |
||||
custom_headers["Return-Path"] = return_path |
||||
|
||||
message, recipients = build_mime_message( |
||||
mail_from=mail_from, |
||||
to=to, |
||||
subject=subject, |
||||
html_content=html_content, |
||||
files=files, |
||||
cc=cc, |
||||
bcc=bcc, |
||||
mime_subtype=mime_subtype, |
||||
mime_charset=mime_charset, |
||||
custom_headers=custom_headers, |
||||
) |
||||
|
||||
return ses_client.send_raw_email( |
||||
Source=mail_from, |
||||
Destinations=recipients, |
||||
RawMessage={"Data": message.as_string()}, |
||||
) |
||||
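A hedged sketch of sending a message through the hook above; the addresses and connection id are placeholders, and the sender address must be verified in SES for the call to succeed:

    from airflow.providers.amazon.aws.hooks.ses import SESHook

    hook = SESHook(aws_conn_id="aws_default")  # placeholder connection id
    response = hook.send_email(
        mail_from="noreply@example.com",       # placeholder, must be SES-verified
        to=["alerts@example.com"],
        subject="Daily report",
        html_content="<h3>All green</h3>",
        reply_to="data-team@example.com",
    )
    print(response["MessageId"])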
@ -0,0 +1,96 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""This module contains AWS SNS hook""" |
||||
import json |
||||
from typing import Dict, Optional, Union |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
def _get_message_attribute(o): |
||||
if isinstance(o, bytes): |
||||
return {"DataType": "Binary", "BinaryValue": o} |
||||
if isinstance(o, str): |
||||
return {"DataType": "String", "StringValue": o} |
||||
if isinstance(o, (int, float)): |
||||
return {"DataType": "Number", "StringValue": str(o)} |
||||
if hasattr(o, "__iter__"): |
||||
return {"DataType": "String.Array", "StringValue": json.dumps(o)} |
||||
raise TypeError( |
||||
"Values in MessageAttributes must be one of bytes, str, int, float, or iterable; " |
||||
f"got {type(o)}" |
||||
) |
||||
|
||||
|
||||
class AwsSnsHook(AwsBaseHook): |
||||
""" |
||||
Interact with Amazon Simple Notification Service. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs): |
||||
super().__init__(client_type="sns", *args, **kwargs) |
||||
|
||||
def publish_to_target( |
||||
self, |
||||
target_arn: str, |
||||
message: str, |
||||
subject: Optional[str] = None, |
||||
message_attributes: Optional[dict] = None, |
||||
): |
||||
""" |
||||
Publish a message to a topic or an endpoint. |
||||
|
||||
:param target_arn: either a TopicArn or an EndpointArn |
||||
:type target_arn: str |
||||
:param message: the default message you want to send |
||||
:type message: str |
||||
:param subject: subject of message |
||||
:type subject: str |
||||
:param message_attributes: additional attributes to publish for message filtering. This should be |
||||
a flat dict; the DataType to be sent depends on the type of the value: |
||||
|
||||
- bytes = Binary |
||||
- str = String |
||||
- int, float = Number |
||||
- iterable = String.Array |
||||
|
||||
:type message_attributes: dict |
||||
""" |
||||
publish_kwargs: Dict[str, Union[str, dict]] = { |
||||
"TargetArn": target_arn, |
||||
"MessageStructure": "json", |
||||
"Message": json.dumps({"default": message}), |
||||
} |
||||
|
||||
# Construct args this way because boto3 distinguishes between missing args and those set to None |
||||
if subject: |
||||
publish_kwargs["Subject"] = subject |
||||
if message_attributes: |
||||
publish_kwargs["MessageAttributes"] = { |
||||
key: _get_message_attribute(val) |
||||
for key, val in message_attributes.items() |
||||
} |
||||
|
||||
return self.get_conn().publish(**publish_kwargs) |
||||
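For illustration, the sketch below publishes to a topic with a couple of message attributes; the topic ARN and connection id are placeholders. The DataType of each attribute is derived by _get_message_attribute as described above (str becomes String, int becomes Number):

    from airflow.providers.amazon.aws.hooks.sns import AwsSnsHook

    hook = AwsSnsHook(aws_conn_id="aws_default")  # placeholder connection id
    hook.publish_to_target(
        target_arn="arn:aws:sns:eu-west-1:123456789012:example-topic",  # placeholder ARN
        message="pipeline finished",
        subject="airflow",
        message_attributes={"run_id": "2021-05-01", "attempt": 1},
    )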
@ -0,0 +1,87 @@ |
||||
# |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""This module contains AWS SQS hook""" |
||||
from typing import Dict, Optional |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class SQSHook(AwsBaseHook): |
||||
""" |
||||
Interact with Amazon Simple Queue Service. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs) -> None: |
||||
kwargs["client_type"] = "sqs" |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def create_queue(self, queue_name: str, attributes: Optional[Dict] = None) -> Dict: |
||||
""" |
||||
Create queue using connection object |
||||
|
||||
:param queue_name: name of the queue. |
||||
:type queue_name: str |
||||
:param attributes: additional attributes for the queue (default: None) |
||||
For details of the attributes parameter see :py:meth:`SQS.create_queue` |
||||
:type attributes: dict |
||||
|
||||
:return: dict with the information about the queue |
||||
For details of the returned value see :py:meth:`SQS.create_queue` |
||||
:rtype: dict |
||||
""" |
||||
return self.get_conn().create_queue( |
||||
QueueName=queue_name, Attributes=attributes or {} |
||||
) |
||||
|
||||
def send_message( |
||||
self, |
||||
queue_url: str, |
||||
message_body: str, |
||||
delay_seconds: int = 0, |
||||
message_attributes: Optional[Dict] = None, |
||||
) -> Dict: |
||||
""" |
||||
Send message to the queue |
||||
|
||||
:param queue_url: queue url |
||||
:type queue_url: str |
||||
:param message_body: the contents of the message |
||||
:type message_body: str |
||||
:param delay_seconds: seconds to delay the message |
||||
:type delay_seconds: int |
||||
:param message_attributes: additional attributes for the message (default: None) |
||||
For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message` |
||||
:type message_attributes: dict |
||||
|
||||
:return: dict with the information about the message sent |
||||
For details of the returned value see :py:meth:`botocore.client.SQS.send_message` |
||||
:rtype: dict |
||||
""" |
||||
return self.get_conn().send_message( |
||||
QueueUrl=queue_url, |
||||
MessageBody=message_body, |
||||
DelaySeconds=delay_seconds, |
||||
MessageAttributes=message_attributes or {}, |
||||
) |
||||
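A small end-to-end sketch for the hook above: create a queue, then send a message to its URL. The connection id and queue name are placeholders:

    from airflow.providers.amazon.aws.hooks.sqs import SQSHook

    hook = SQSHook(aws_conn_id="aws_default")  # placeholder connection id
    queue = hook.create_queue(queue_name="example-queue")
    hook.send_message(
        queue_url=queue["QueueUrl"],
        message_body="hello from airflow",
        delay_seconds=0,
    )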
@ -0,0 +1,82 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
import json |
||||
from typing import Optional, Union |
||||
|
||||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
||||
|
||||
|
||||
class StepFunctionHook(AwsBaseHook): |
||||
""" |
||||
Interact with an AWS Step Functions State Machine. |
||||
|
||||
Additional arguments (such as ``aws_conn_id``) may be specified and |
||||
are passed down to the underlying AwsBaseHook. |
||||
|
||||
.. seealso:: |
||||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
||||
""" |
||||
|
||||
def __init__(self, region_name: Optional[str] = None, *args, **kwargs) -> None: |
||||
kwargs["client_type"] = "stepfunctions" |
||||
# Forward region_name so an explicitly passed region is not silently ignored. |
||||
kwargs["region_name"] = region_name |
||||
super().__init__(*args, **kwargs) |
||||
|
||||
def start_execution( |
||||
self, |
||||
state_machine_arn: str, |
||||
name: Optional[str] = None, |
||||
state_machine_input: Union[dict, str, None] = None, |
||||
) -> str: |
||||
""" |
||||
Start Execution of the State Machine. |
||||
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution |
||||
|
||||
:param state_machine_arn: AWS Step Function State Machine ARN |
||||
:type state_machine_arn: str |
||||
:param name: The name of the execution. |
||||
:type name: Optional[str] |
||||
:param state_machine_input: JSON data input to pass to the State Machine |
||||
:type state_machine_input: Union[dict, str, None] |
||||
:return: Execution ARN |
||||
:rtype: str |
||||
""" |
||||
execution_args = {"stateMachineArn": state_machine_arn} |
||||
if name is not None: |
||||
execution_args["name"] = name |
||||
if state_machine_input is not None: |
||||
if isinstance(state_machine_input, str): |
||||
execution_args["input"] = state_machine_input |
||||
elif isinstance(state_machine_input, dict): |
||||
execution_args["input"] = json.dumps(state_machine_input) |
||||
|
||||
self.log.info("Executing Step Function State Machine: %s", state_machine_arn) |
||||
|
||||
response = self.conn.start_execution(**execution_args) |
||||
return response.get("executionArn") |
||||
|
||||
def describe_execution(self, execution_arn: str) -> dict: |
||||
""" |
||||
Describes a State Machine Execution |
||||
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.describe_execution |
||||
|
||||
:param execution_arn: ARN of the State Machine Execution |
||||
:type execution_arn: str |
||||
:return: Dict with Execution details |
||||
:rtype: dict |
||||
""" |
||||
return self.get_conn().describe_execution(executionArn=execution_arn) |
||||
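A sketch of starting and inspecting an execution with the hook above; the state machine ARN and connection id are placeholders, and a dict input is JSON-encoded by start_execution as shown:

    from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook

    hook = StepFunctionHook(aws_conn_id="aws_default")  # placeholder connection id
    execution_arn = hook.start_execution(
        state_machine_arn="arn:aws:states:eu-west-1:123456789012:stateMachine:example",
        state_machine_input={"run_date": "2021-05-01"},
    )
    status = hook.describe_execution(execution_arn)["status"]
    print(execution_arn, status)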
@ -0,0 +1,16 @@ |
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
Some files were not shown because too many files have changed in this diff.