You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
263 lines
10 KiB
263 lines
10 KiB
# |
|
# Licensed to the Apache Software Foundation (ASF) under one |
|
# or more contributor license agreements. See the NOTICE file |
|
# distributed with this work for additional information |
|
# regarding copyright ownership. The ASF licenses this file |
|
# to you under the Apache License, Version 2.0 (the |
|
# "License"); you may not use this file except in compliance |
|
# with the License. You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, |
|
# software distributed under the License is distributed on an |
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
|
# KIND, either express or implied. See the License for the |
|
# specific language governing permissions and limitations |
|
# under the License. |
|
|
|
"""This module contains AWS Athena hook""" |
|
from time import sleep |
|
from typing import Any, Dict, Optional |
|
|
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook |
|
from botocore.paginate import PageIterator |
|
|
|
|
|
class AWSAthenaHook(AwsBaseHook): |
|
""" |
|
Interact with AWS Athena to run, poll queries and return query results |
|
|
|
Additional arguments (such as ``aws_conn_id``) may be specified and |
|
are passed down to the underlying AwsBaseHook. |
|
|
|
.. seealso:: |
|
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` |
|
|
|
:param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena |
|
:type sleep_time: int |
|
""" |
|
|
|
INTERMEDIATE_STATES = ( |
|
"QUEUED", |
|
"RUNNING", |
|
) |
|
FAILURE_STATES = ( |
|
"FAILED", |
|
"CANCELLED", |
|
) |
|
SUCCESS_STATES = ("SUCCEEDED",) |
|
|
|
def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None: |
|
super().__init__(client_type="athena", *args, **kwargs) # type: ignore |
|
self.sleep_time = sleep_time |
|
|
|
def run_query( |
|
self, |
|
query: str, |
|
query_context: Dict[str, str], |
|
result_configuration: Dict[str, Any], |
|
client_request_token: Optional[str] = None, |
|
workgroup: str = "primary", |
|
) -> str: |
|
""" |
|
Run Presto query on athena with provided config and return submitted query_execution_id |
|
|
|
:param query: Presto query to run |
|
:type query: str |
|
:param query_context: Context in which query need to be run |
|
:type query_context: dict |
|
:param result_configuration: Dict with path to store results in and config related to encryption |
|
:type result_configuration: dict |
|
:param client_request_token: Unique token created by user to avoid multiple executions of same query |
|
:type client_request_token: str |
|
:param workgroup: Athena workgroup name, when not specified, will be 'primary' |
|
:type workgroup: str |
|
:return: str |
|
""" |
|
params = { |
|
"QueryString": query, |
|
"QueryExecutionContext": query_context, |
|
"ResultConfiguration": result_configuration, |
|
"WorkGroup": workgroup, |
|
} |
|
if client_request_token: |
|
params["ClientRequestToken"] = client_request_token |
|
response = self.get_conn().start_query_execution(**params) |
|
query_execution_id = response["QueryExecutionId"] |
|
return query_execution_id |
|
|
|
def check_query_status(self, query_execution_id: str) -> Optional[str]: |
|
""" |
|
Fetch the status of submitted athena query. Returns None or one of valid query states. |
|
|
|
:param query_execution_id: Id of submitted athena query |
|
:type query_execution_id: str |
|
:return: str |
|
""" |
|
response = self.get_conn().get_query_execution( |
|
QueryExecutionId=query_execution_id |
|
) |
|
state = None |
|
try: |
|
state = response["QueryExecution"]["Status"]["State"] |
|
except Exception as ex: # pylint: disable=broad-except |
|
self.log.error("Exception while getting query state %s", ex) |
|
finally: |
|
# The error is being absorbed here and is being handled by the caller. |
|
# The error is being absorbed to implement retries. |
|
return state # pylint: disable=lost-exception |
|
|
|
def get_state_change_reason(self, query_execution_id: str) -> Optional[str]: |
|
""" |
|
Fetch the reason for a state change (e.g. error message). Returns None or reason string. |
|
|
|
:param query_execution_id: Id of submitted athena query |
|
:type query_execution_id: str |
|
:return: str |
|
""" |
|
response = self.get_conn().get_query_execution( |
|
QueryExecutionId=query_execution_id |
|
) |
|
reason = None |
|
try: |
|
reason = response["QueryExecution"]["Status"]["StateChangeReason"] |
|
except Exception as ex: # pylint: disable=broad-except |
|
self.log.error("Exception while getting query state change reason: %s", ex) |
|
finally: |
|
# The error is being absorbed here and is being handled by the caller. |
|
# The error is being absorbed to implement retries. |
|
return reason # pylint: disable=lost-exception |
|
|
|
def get_query_results( |
|
self, |
|
query_execution_id: str, |
|
next_token_id: Optional[str] = None, |
|
max_results: int = 1000, |
|
) -> Optional[dict]: |
|
""" |
|
Fetch submitted athena query results. returns none if query is in intermediate state or |
|
failed/cancelled state else dict of query output |
|
|
|
:param query_execution_id: Id of submitted athena query |
|
:type query_execution_id: str |
|
:param next_token_id: The token that specifies where to start pagination. |
|
:type next_token_id: str |
|
:param max_results: The maximum number of results (rows) to return in this request. |
|
:type max_results: int |
|
:return: dict |
|
""" |
|
query_state = self.check_query_status(query_execution_id) |
|
if query_state is None: |
|
self.log.error("Invalid Query state") |
|
return None |
|
elif ( |
|
query_state in self.INTERMEDIATE_STATES |
|
or query_state in self.FAILURE_STATES |
|
): |
|
self.log.error('Query is in "%s" state. Cannot fetch results', query_state) |
|
return None |
|
result_params = { |
|
"QueryExecutionId": query_execution_id, |
|
"MaxResults": max_results, |
|
} |
|
if next_token_id: |
|
result_params["NextToken"] = next_token_id |
|
return self.get_conn().get_query_results(**result_params) |
|
|
|
def get_query_results_paginator( |
|
self, |
|
query_execution_id: str, |
|
max_items: Optional[int] = None, |
|
page_size: Optional[int] = None, |
|
starting_token: Optional[str] = None, |
|
) -> Optional[PageIterator]: |
|
""" |
|
Fetch submitted athena query results. returns none if query is in intermediate state or |
|
failed/cancelled state else a paginator to iterate through pages of results. If you |
|
wish to get all results at once, call build_full_result() on the returned PageIterator |
|
|
|
:param query_execution_id: Id of submitted athena query |
|
:type query_execution_id: str |
|
:param max_items: The total number of items to return. |
|
:type max_items: int |
|
:param page_size: The size of each page. |
|
:type page_size: int |
|
:param starting_token: A token to specify where to start paginating. |
|
:type starting_token: str |
|
:return: PageIterator |
|
""" |
|
query_state = self.check_query_status(query_execution_id) |
|
if query_state is None: |
|
self.log.error("Invalid Query state (null)") |
|
return None |
|
if ( |
|
query_state in self.INTERMEDIATE_STATES |
|
or query_state in self.FAILURE_STATES |
|
): |
|
self.log.error('Query is in "%s" state. Cannot fetch results', query_state) |
|
return None |
|
result_params = { |
|
"QueryExecutionId": query_execution_id, |
|
"PaginationConfig": { |
|
"MaxItems": max_items, |
|
"PageSize": page_size, |
|
"StartingToken": starting_token, |
|
}, |
|
} |
|
paginator = self.get_conn().get_paginator("get_query_results") |
|
return paginator.paginate(**result_params) |
|
|
|
def poll_query_status( |
|
self, query_execution_id: str, max_tries: Optional[int] = None |
|
) -> Optional[str]: |
|
""" |
|
Poll the status of submitted athena query until query state reaches final state. |
|
Returns one of the final states |
|
|
|
:param query_execution_id: Id of submitted athena query |
|
:type query_execution_id: str |
|
:param max_tries: Number of times to poll for query state before function exits |
|
:type max_tries: int |
|
:return: str |
|
""" |
|
try_number = 1 |
|
final_query_state = ( |
|
None # Query state when query reaches final state or max_tries reached |
|
) |
|
while True: |
|
query_state = self.check_query_status(query_execution_id) |
|
if query_state is None: |
|
self.log.info( |
|
"Trial %s: Invalid query state. Retrying again", try_number |
|
) |
|
elif query_state in self.INTERMEDIATE_STATES: |
|
self.log.info( |
|
"Trial %s: Query is still in an intermediate state - %s", |
|
try_number, |
|
query_state, |
|
) |
|
else: |
|
self.log.info( |
|
"Trial %s: Query execution completed. Final state is %s}", |
|
try_number, |
|
query_state, |
|
) |
|
final_query_state = query_state |
|
break |
|
if max_tries and try_number >= max_tries: # Break loop if max_tries reached |
|
final_query_state = query_state |
|
break |
|
try_number += 1 |
|
sleep(self.sleep_time) |
|
return final_query_state |
|
|
|
def stop_query(self, query_execution_id: str) -> Dict: |
|
""" |
|
Cancel the submitted athena query |
|
|
|
:param query_execution_id: Id of submitted athena query |
|
:type query_execution_id: str |
|
:return: dict |
|
""" |
|
return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)
|
|
|