#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Databricks hook.
This hook enables submitting and running jobs on the Databricks platform. Internally the
operators talk to the ``api/2.0/jobs/runs/submit``
`endpoint <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_.
"""
from time import sleep
from urllib.parse import urlparse
import requests
from airflow import __version__
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from requests import PreparedRequest
from requests import exceptions as requests_exceptions
from requests.auth import AuthBase
# (HTTP method, endpoint path) pairs for each Databricks REST API 2.0 call
# issued by DatabricksHook._do_api_call below.
RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
RUN_NOW_ENDPOINT = ("POST", "api/2.0/jobs/run-now")
SUBMIT_RUN_ENDPOINT = ("POST", "api/2.0/jobs/runs/submit")
GET_RUN_ENDPOINT = ("GET", "api/2.0/jobs/runs/get")
CANCEL_RUN_ENDPOINT = ("POST", "api/2.0/jobs/runs/cancel")
# Sent with every request so Databricks can identify the calling Airflow version.
USER_AGENT_HEADER = {"user-agent": f"airflow-{__version__}"}
INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install")
UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall")
class RunState:
    """Utility class for the run state concept of Databricks runs."""

    def __init__(
        self, life_cycle_state: str, result_state: str, state_message: str
    ) -> None:
        self.life_cycle_state = life_cycle_state
        self.result_state = result_state
        self.state_message = state_message

    @property
    def is_terminal(self) -> bool:
        """True if the current state is a terminal state."""
        if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES:
            # An unknown state most likely means the Databricks API has grown
            # a new life cycle state this code does not know about yet.
            raise AirflowException(
                f"Unexpected life cycle state: {self.life_cycle_state}: If the state has "
                "been introduced recently, please check the Databricks user "
                "guide for troubleshooting information"
            )
        return self.life_cycle_state in {"TERMINATED", "SKIPPED", "INTERNAL_ERROR"}

    @property
    def is_successful(self) -> bool:
        """True if the result state is SUCCESS"""
        return self.result_state == "SUCCESS"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, RunState):
            return NotImplemented
        mine = (self.life_cycle_state, self.result_state, self.state_message)
        theirs = (other.life_cycle_state, other.result_state, other.state_message)
        return mine == theirs

    def __repr__(self) -> str:
        return f"{self.__dict__}"
class DatabricksHook(BaseHook):  # noqa
    """
    Interact with Databricks.

    :param databricks_conn_id: The name of the databricks connection to use.
    :type databricks_conn_id: str
    :param timeout_seconds: The amount of time in seconds the requests library
        will wait before timing-out.
    :type timeout_seconds: int
    :param retry_limit: The number of times to retry the connection in case of
        service outages.
    :type retry_limit: int
    :param retry_delay: The number of seconds to wait between retries (it
        might be a floating point number).
    :type retry_delay: float
    """

    conn_name_attr = "databricks_conn_id"
    default_conn_name = "databricks_default"
    conn_type = "databricks"
    hook_name = "Databricks"

    def __init__(
        self,
        databricks_conn_id: str = default_conn_name,
        timeout_seconds: int = 180,
        retry_limit: int = 3,
        retry_delay: float = 1.0,
    ) -> None:
        super().__init__()
        self.databricks_conn_id = databricks_conn_id
        self.databricks_conn = self.get_connection(databricks_conn_id)
        self.timeout_seconds = timeout_seconds
        if retry_limit < 1:
            raise ValueError("Retry limit must be greater than equal to 1")
        self.retry_limit = retry_limit
        self.retry_delay = retry_delay

    @staticmethod
    def _parse_host(host: str) -> str:
        """
        The purpose of this function is to be robust to improper connections
        settings provided by users, specifically in the host field.

        For example -- when users supply ``https://xx.cloud.databricks.com`` as the
        host, we must strip out the protocol to get the host.::

            h = DatabricksHook()
            assert h._parse_host('https://xx.cloud.databricks.com') == \
                'xx.cloud.databricks.com'

        In the case where users supply the correct ``xx.cloud.databricks.com`` as the
        host, this function is a no-op.::

            assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com'
        """
        urlparse_host = urlparse(host).hostname
        if urlparse_host:
            # In this case, host = https://xx.cloud.databricks.com
            return urlparse_host
        # In this case, host = xx.cloud.databricks.com
        return host

    def _do_api_call(self, endpoint_info, json):
        """
        Utility function to perform an API call with retries.

        :param endpoint_info: Tuple of method and endpoint
        :type endpoint_info: tuple[string, string]
        :param json: Parameters for this API call.
        :type json: dict
        :return: If the api call returns a OK status code,
            this function returns the response in JSON. Otherwise,
            we throw an AirflowException.
        :rtype: dict
        """
        method, endpoint = endpoint_info

        # A token in the connection's "extra" field takes precedence over basic
        # auth; "extra" may also override the connection host.
        if "token" in self.databricks_conn.extra_dejson:
            self.log.info("Using token auth. ")
            auth = _TokenAuth(self.databricks_conn.extra_dejson["token"])
            if "host" in self.databricks_conn.extra_dejson:
                host = self._parse_host(self.databricks_conn.extra_dejson["host"])
            else:
                host = self.databricks_conn.host
        else:
            self.log.info("Using basic auth. ")
            auth = (self.databricks_conn.login, self.databricks_conn.password)
            host = self.databricks_conn.host

        url = f"https://{self._parse_host(host)}/{endpoint}"

        request_func = {
            "GET": requests.get,
            "POST": requests.post,
            "PATCH": requests.patch,
        }.get(method)
        if request_func is None:
            raise AirflowException("Unexpected HTTP Method: " + method)

        for attempt_num in range(1, self.retry_limit + 1):
            try:
                response = request_func(
                    url,
                    json=json if method in ("POST", "PATCH") else None,
                    params=json if method == "GET" else None,
                    auth=auth,
                    headers=USER_AGENT_HEADER,
                    timeout=self.timeout_seconds,
                )
                response.raise_for_status()
                return response.json()
            except requests_exceptions.RequestException as e:
                if not _retryable_error(e):
                    # In this case, the user probably made a mistake.
                    # Don't retry. Some RequestExceptions (e.g. redirect
                    # errors) carry no response at all -- guard against that
                    # so we don't mask the real error with an AttributeError.
                    if e.response is not None:
                        raise AirflowException(
                            f"Response: {e.response.content}, Status Code: {e.response.status_code}"
                        )
                    raise AirflowException(f"API request to Databricks failed: {e}")
                self._log_request_error(attempt_num, e)

            # Only sleep between attempts, not after the final failure.
            if attempt_num < self.retry_limit:
                sleep(self.retry_delay)

        raise AirflowException(
            f"API requests to Databricks failed {self.retry_limit} times. Giving up."
        )

    def _log_request_error(self, attempt_num: int, error: str) -> None:
        # Lazy %-style args keep formatting off the hot path when the level is off.
        self.log.error(
            "Attempt %s API Request to Databricks failed with reason: %s",
            attempt_num,
            error,
        )

    def run_now(self, json: dict) -> str:
        """
        Utility function to call the ``api/2.0/jobs/run-now`` endpoint.

        :param json: The data used in the body of the request to the ``run-now`` endpoint.
        :type json: dict
        :return: the run_id as a string
        :rtype: str
        """
        response = self._do_api_call(RUN_NOW_ENDPOINT, json)
        return response["run_id"]

    def submit_run(self, json: dict) -> str:
        """
        Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint.

        :param json: The data used in the body of the request to the ``submit`` endpoint.
        :type json: dict
        :return: the run_id as a string
        :rtype: str
        """
        response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
        return response["run_id"]

    def get_run_page_url(self, run_id: str) -> str:
        """
        Retrieves run_page_url.

        :param run_id: id of the run
        :return: URL of the run page
        """
        json = {"run_id": run_id}
        response = self._do_api_call(GET_RUN_ENDPOINT, json)
        return response["run_page_url"]

    def get_job_id(self, run_id: str) -> str:
        """
        Retrieves job_id from run_id.

        :param run_id: id of the run
        :type run_id: str
        :return: Job id for given Databricks run
        """
        json = {"run_id": run_id}
        response = self._do_api_call(GET_RUN_ENDPOINT, json)
        return response["job_id"]

    def get_run_state(self, run_id: str) -> RunState:
        """
        Retrieves run state of the run.

        :param run_id: id of the run
        :return: state of the run
        """
        json = {"run_id": run_id}
        response = self._do_api_call(GET_RUN_ENDPOINT, json)
        state = response["state"]
        life_cycle_state = state["life_cycle_state"]
        # result_state may not be in the state if not terminal
        result_state = state.get("result_state", None)
        state_message = state["state_message"]
        return RunState(life_cycle_state, result_state, state_message)

    def cancel_run(self, run_id: str) -> None:
        """
        Cancels the run.

        :param run_id: id of the run
        """
        json = {"run_id": run_id}
        self._do_api_call(CANCEL_RUN_ENDPOINT, json)

    def restart_cluster(self, json: dict) -> None:
        """
        Restarts the cluster.

        :param json: json dictionary containing cluster specification.
        """
        self._do_api_call(RESTART_CLUSTER_ENDPOINT, json)

    def start_cluster(self, json: dict) -> None:
        """
        Starts the cluster.

        :param json: json dictionary containing cluster specification.
        """
        self._do_api_call(START_CLUSTER_ENDPOINT, json)

    def terminate_cluster(self, json: dict) -> None:
        """
        Terminates the cluster.

        :param json: json dictionary containing cluster specification.
        """
        self._do_api_call(TERMINATE_CLUSTER_ENDPOINT, json)

    def install(self, json: dict) -> None:
        """
        Install libraries on the cluster.

        Utility function to call the ``2.0/libraries/install`` endpoint.

        :param json: json dictionary containing cluster_id and an array of library
        :type json: dict
        """
        self._do_api_call(INSTALL_LIBS_ENDPOINT, json)

    def uninstall(self, json: dict) -> None:
        """
        Uninstall libraries on the cluster.

        Utility function to call the ``2.0/libraries/uninstall`` endpoint.

        :param json: json dictionary containing cluster_id and an array of library
        :type json: dict
        """
        self._do_api_call(UNINSTALL_LIBS_ENDPOINT, json)
def _retryable_error(exception) -> bool:
    """Return True when *exception* is worth retrying: connection problems,
    timeouts, or any response carrying a 5xx (server-side) status code."""
    retryable_types = (
        requests_exceptions.ConnectionError,
        requests_exceptions.Timeout,
    )
    if isinstance(exception, retryable_types):
        return True
    response = exception.response
    return response is not None and response.status_code >= 500
# Every life cycle state the Databricks jobs API is known to emit;
# ``RunState.is_terminal`` raises on anything not in this list.
RUN_LIFE_CYCLE_STATES = [
    "PENDING",
    "RUNNING",
    "TERMINATING",
    "TERMINATED",
    "SKIPPED",
    "INTERNAL_ERROR",
]
class _TokenAuth(AuthBase):
    """
    Helper class for requests Auth field. AuthBase requires you to implement the __call__
    magic function.
    """

    def __init__(self, token: str) -> None:
        self.token = token

    def __call__(self, r: PreparedRequest) -> PreparedRequest:
        # Attach the bearer token to the outgoing request.
        r.headers["Authorization"] = f"Bearer {self.token}"
        return r