You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
115 lines
3.8 KiB
115 lines
3.8 KiB
# |
|
# Licensed to the Apache Software Foundation (ASF) under one |
|
# or more contributor license agreements. See the NOTICE file |
|
# distributed with this work for additional information |
|
# regarding copyright ownership. The ASF licenses this file |
|
# to you under the Apache License, Version 2.0 (the |
|
# "License"); you may not use this file except in compliance |
|
# with the License. You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, |
|
# software distributed under the License is distributed on an |
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
|
# KIND, either express or implied. See the License for the |
|
# specific language governing permissions and limitations |
|
# under the License. |
|
|
|
from typing import Any, Dict, Iterable, Optional |
|
|
|
from airflow.exceptions import AirflowException |
|
from airflow.providers.amazon.aws.hooks.emr import EmrHook |
|
from airflow.sensors.base import BaseSensorOperator |
|
from airflow.utils.decorators import apply_defaults |
|
|
|
|
|
class EmrBaseSensor(BaseSensorOperator):
    """
    Shared polling logic for sensors that watch Amazon EMR.

    A concrete sensor is expected to override the following methods:
        - ``get_emr_response()``
        - ``state_from_response()``
        - ``failure_message_from_response()``

    It must also assign the ``target_states`` and ``failed_states``
    attributes, which are initialised to ``None`` here.

    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    """

    ui_color = "#66c3ff"

    @apply_defaults
    def __init__(self, *, aws_conn_id: str = "aws_default", **kwargs):
        super().__init__(**kwargs)
        self.aws_conn_id = aws_conn_id
        # Created on the first call to get_hook() and reused afterwards.
        self.hook: Optional[EmrHook] = None
        # Both state collections are placeholders; subclasses fill them in.
        self.target_states: Optional[Iterable[str]] = None
        self.failed_states: Optional[Iterable[str]] = None

    def get_hook(self) -> EmrHook:
        """Return the cached EmrHook, building it from ``aws_conn_id`` if needed."""
        if not self.hook:
            self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
        return self.hook

    def poke(self, context):
        """
        Query EMR once and decide whether the sensor condition is met.

        :param context: task context supplied by the caller (not used here)
        :return: ``True`` when the reported state is one of ``target_states``;
            ``False`` when the HTTP status code is not 200 or the state is in
            neither ``target_states`` nor ``failed_states``
        :raises AirflowException: when the reported state is one of
            ``failed_states``
        """
        response = self.get_emr_response()

        status_code = response["ResponseMetadata"]["HTTPStatusCode"]
        if status_code != 200:
            self.log.info("Bad HTTP response: %s", response)
            return False

        state = self.state_from_response(response)
        self.log.info("Job flow currently %s", state)

        if state in self.target_states:
            return True
        if state not in self.failed_states:
            return False

        # Failed state: append the subclass-provided detail, if there is one.
        message_parts = ["EMR job failed"]
        detail = self.failure_message_from_response(response)
        if detail:
            message_parts.append(detail)
        raise AirflowException(" ".join(message_parts))

    def get_emr_response(self) -> Dict[str, Any]:
        """
        Fetch the EMR API response that ``poke`` inspects.

        Must be overridden; this base implementation always raises.

        :return: response
        :rtype: dict[str, Any]
        """
        raise NotImplementedError("Please implement get_emr_response() in subclass")

    @staticmethod
    def state_from_response(response: Dict[str, Any]) -> str:
        """
        Extract the state from a response dictionary.

        Must be overridden; this base implementation always raises.

        :param response: response from AWS API
        :type response: dict[str, Any]
        :return: state
        :rtype: str
        """
        raise NotImplementedError("Please implement state_from_response() in subclass")

    @staticmethod
    def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]:
        """
        Extract the failure message from a response dictionary.

        Must be overridden; this base implementation always raises.

        :param response: response from AWS API
        :type response: dict[str, Any]
        :return: failure message
        :rtype: Optional[str]
        """
        raise NotImplementedError(
            "Please implement failure_message_from_response() in subclass"
        )
|
|
|