You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
93 lines
3.5 KiB
93 lines
3.5 KiB
# |
|
# Licensed to the Apache Software Foundation (ASF) under one |
|
# or more contributor license agreements. See the NOTICE file |
|
# distributed with this work for additional information |
|
# regarding copyright ownership. The ASF licenses this file |
|
# to you under the Apache License, Version 2.0 (the |
|
# "License"); you may not use this file except in compliance |
|
# with the License. You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, |
|
# software distributed under the License is distributed on an |
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
|
# KIND, either express or implied. See the License for the |
|
# specific language governing permissions and limitations |
|
# under the License. |
|
from typing import Optional, Set |
|
|
|
from airflow.exceptions import AirflowException |
|
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook |
|
from airflow.sensors.base import BaseSensorOperator |
|
from airflow.utils.decorators import apply_defaults |
|
|
|
|
|
class SageMakerBaseSensor(BaseSensorOperator):
    """
    Contains general sensor behavior for SageMaker.

    Subclasses should implement the get_sagemaker_response() and
    state_from_response() methods.
    Subclasses should also implement the non_terminal_states() and
    failed_states() methods.
    get_failed_reason_from_response() may optionally be overridden; it
    returns "Unknown" by default.

    :param aws_conn_id: connection id handed to SageMakerHook when the hook
        is created by get_hook()
    """

    ui_color = "#ededed"

    @apply_defaults
    def __init__(self, *, aws_conn_id: str = "aws_default", **kwargs):
        super().__init__(**kwargs)
        self.aws_conn_id = aws_conn_id
        # Populated lazily by get_hook() and reused on later calls.
        self.hook: Optional[SageMakerHook] = None

    def get_hook(self) -> SageMakerHook:
        """Get SageMakerHook, creating and caching it on first use."""
        if self.hook:
            return self.hook

        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        return self.hook

    def poke(self, context):
        """
        Check the job state once.

        :param context: poke context; not used by this implementation
        :return: False when the HTTP status code of the response is not 200
            or the state is one of non_terminal_states(); True when the state
            is neither non-terminal nor failed
        :raises AirflowException: when the state is one of failed_states()
        """
        response = self.get_sagemaker_response()

        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            self.log.info("Bad HTTP response: %s", response)
            return False

        state = self.state_from_response(response)

        self.log.info("Job currently %s", state)

        # Non-terminal states are checked first, so a state listed in both
        # sets keeps the sensor poking instead of raising.
        if state in self.non_terminal_states():
            return False

        if state in self.failed_states():
            failed_reason = self.get_failed_reason_from_response(response)
            raise AirflowException(
                f"Sagemaker job failed for the following reason: {failed_reason}"
            )
        return True

    def non_terminal_states(self) -> Set[str]:
        """Placeholder for returning states which should not terminate."""
        raise NotImplementedError("Please implement non_terminal_states() in subclass")

    def failed_states(self) -> Set[str]:
        """Placeholder for returning states which are considered failed."""
        raise NotImplementedError("Please implement failed_states() in subclass")

    def get_sagemaker_response(self) -> Optional[dict]:
        """Placeholder for checking status of a SageMaker task."""
        raise NotImplementedError(
            "Please implement get_sagemaker_response() in subclass"
        )

    def get_failed_reason_from_response(
        self, response: dict
    ) -> str:  # pylint: disable=unused-argument
        """Placeholder for extracting the reason for failure from an AWS response."""
        return "Unknown"

    def state_from_response(self, response: dict) -> str:
        """Placeholder for extracting the state from an AWS response."""
        raise NotImplementedError("Please implement state_from_response() in subclass")
|
|
|