You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
112 lines
4.2 KiB
112 lines
4.2 KiB
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import time |
|
from typing import Optional |
|
|
|
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook |
|
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor |
|
from airflow.utils.decorators import apply_defaults |
|
|
|
|
|
class SageMakerTrainingSensor(SageMakerBaseSensor):
    """
    Asks for the state of the training job until it reaches a terminal state.
    If it fails the sensor errors, failing the task.

    :param job_name: name of the SageMaker training job to check the state of
    :type job_name: str
    :param print_log: if the operator should print the cloudwatch log
    :type print_log: bool
    """

    template_fields = ["job_name"]
    template_ext = ()

    @apply_defaults
    def __init__(self, *, job_name: str, print_log: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.job_name = job_name
        self.print_log = print_log
        # Bookkeeping handed to ``describe_training_job_with_log`` on every
        # ``get_sagemaker_response`` call; ``state``, ``last_description`` and
        # ``last_describe_job_call`` are overwritten with what it returns.
        self.positions = {}
        self.stream_names = []
        self.instance_count: Optional[int] = None
        self.state: Optional[int] = None
        self.last_description = None
        self.last_describe_job_call = None
        # Flipped to True by ``init_log_resource`` so it only runs once.
        self.log_resource_inited = False

    def init_log_resource(self, hook: SageMakerHook) -> None:
        """
        Set tailing LogState for associated training job.

        Describes the job once, records its instance count and picks the
        initial log state: ``LogState.TAILING`` while the job status is
        non-terminal, ``LogState.COMPLETE`` otherwise.

        :param hook: hook used to describe the training job
        :type hook: SageMakerHook
        """
        description = hook.describe_training_job(self.job_name)
        self.instance_count = description["ResourceConfig"]["InstanceCount"]

        still_running = description["TrainingJobStatus"] in self.non_terminal_states()
        self.state = LogState.TAILING if still_running else LogState.COMPLETE

        self.last_description = description
        self.last_describe_job_call = time.monotonic()
        self.log_resource_inited = True

    def non_terminal_states(self):
        """Return the job statuses that ``SageMakerHook`` treats as non-terminal."""
        return SageMakerHook.non_terminal_states

    def failed_states(self):
        """Return the job statuses that ``SageMakerHook`` treats as failed."""
        return SageMakerHook.failed_states

    def get_sagemaker_response(self):
        """
        Describe the training job and return the latest description.

        With ``print_log`` enabled the description is obtained through
        ``describe_training_job_with_log`` (after a one-time
        ``init_log_resource``); otherwise through ``describe_training_job``.
        When the resulting status is neither non-terminal nor failed, the
        billable seconds are logged if they can be computed.

        :return: the most recent training job description
        """
        hook = self.get_hook()
        if self.print_log:
            if not self.log_resource_inited:
                self.init_log_resource(hook)
            (
                self.state,
                self.last_description,
                self.last_describe_job_call,
            ) = hook.describe_training_job_with_log(
                self.job_name,
                self.positions,
                self.stream_names,
                self.instance_count,
                self.state,
                self.last_description,
                self.last_describe_job_call,
            )
        else:
            self.last_description = hook.describe_training_job(self.job_name)

        status = self.state_from_response(self.last_description)
        if status not in self.non_terminal_states() and status not in self.failed_states():
            self._log_billable_seconds(self.last_description)

        return self.last_description

    def _log_billable_seconds(self, description) -> None:
        """Log the billable seconds of a finished job when both timestamps are present."""
        started_at = description.get("TrainingStartTime")
        ended_at = description.get("TrainingEndTime")
        if started_at is None or ended_at is None:
            # Both timestamps are optional in the DescribeTrainingJob response
            # (presumably absent when a job is stopped before training starts);
            # skip the informational log line rather than fail with KeyError.
            self.log.info(
                "Billable seconds unavailable for %s: training start/end time missing",
                self.job_name,
            )
            return

        billable_time = (ended_at - started_at) * description["ResourceConfig"]["InstanceCount"]
        self.log.info("Billable seconds: %s", int(billable_time.total_seconds()) + 1)

    def get_failed_reason_from_response(self, response):
        """Return the ``FailureReason`` entry of a training job description."""
        return response["FailureReason"]

    def state_from_response(self, response):
        """Return the ``TrainingJobStatus`` entry of a training job description."""
        return response["TrainingJobStatus"]
|
|
|