#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
from collections import namedtuple
from time import sleep
from typing import Any, Dict, List, Optional, Sequence

from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.azure_container_instance import (
    AzureContainerInstanceHook,
)
from airflow.providers.microsoft.azure.hooks.azure_container_registry import (
    AzureContainerRegistryHook,
)
from airflow.providers.microsoft.azure.hooks.azure_container_volume import (
    AzureContainerVolumeHook,
)
from airflow.utils.decorators import apply_defaults
from azure.mgmt.containerinstance.models import (
    Container,
    ContainerGroup,
    ContainerPort,
    EnvironmentVariable,
    IpAddress,
    ResourceRequests,
    ResourceRequirements,
    Volume as AzureVolume,  # aliased: ``Volume`` below is a namedtuple
    VolumeMount,
)
from msrestazure.azure_exceptions import CloudError
Volume = namedtuple(
    "Volume",
    ["conn_id", "account_name", "share_name", "mount_path", "read_only"],
)
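
# Illustrative only: a ``Volume`` tuple describing a read-only Azure Fileshare
# mount. The connection id, account, and share names below are made-up values:
#
#     Volume(
#         conn_id="azure_wasb_default",
#         account_name="mystorageaccount",
#         share_name="my-fileshare",
#         mount_path="/input-data",
#         read_only=True,
#     )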

DEFAULT_ENVIRONMENT_VARIABLES: Dict[str, str] = {}
DEFAULT_SECURED_VARIABLES: Sequence[str] = []
DEFAULT_VOLUMES: Sequence[Volume] = []
DEFAULT_MEMORY_IN_GB = 2.0
DEFAULT_CPU = 1.0


# pylint: disable=too-many-instance-attributes
class AzureContainerInstancesOperator(BaseOperator):
"""
Start a container on Azure Container Instances
:param ci_conn_id: connection id of a service principal which will be used
to start the container instance
:type ci_conn_id: str
:param registry_conn_id: connection id of a user which can login to a
private docker registry. If None, we assume a public registry
:type registry_conn_id: Optional[str]
:param resource_group: name of the resource group wherein this container
instance should be started
:type resource_group: str
:param name: name of this container instance. Please note this name has
to be unique in order to run containers in parallel.
:type name: str
:param image: the docker image to be used
:type image: str
:param region: the region wherein this container instance should be started
:type region: str
:param environment_variables: key,value pairs containing environment
variables which will be passed to the running container
:type environment_variables: Optional[dict]
:param secured_variables: names of environmental variables that should not
be exposed outside the container (typically passwords).
:type secured_variables: Optional[str]
:param volumes: list of ``Volume`` tuples to be mounted to the container.
Currently only Azure Fileshares are supported.
:type volumes: list[<conn_id, account_name, share_name, mount_path, read_only>]
:param memory_in_gb: the amount of memory to allocate to this container
:type memory_in_gb: double
:param cpu: the number of cpus to allocate to this container
:type cpu: double
:param gpu: GPU Resource for the container.
:type gpu: azure.mgmt.containerinstance.models.GpuResource
:param command: the command to run inside the container
:type command: Optional[List[str]]
:param container_timeout: max time allowed for the execution of
the container instance.
:type container_timeout: datetime.timedelta
:param tags: azure tags as dict of str:str
:type tags: Optional[dict[str, str]]
:param os_type: The operating system type required by the containers
in the container group. Possible values include: 'Windows', 'Linux'
:type os_type: str
:param restart_policy: Restart policy for all containers within the container group.
Possible values include: 'Always', 'OnFailure', 'Never'
:type restart_policy: str
:param ip_address: The IP address type of the container group.
:type ip_address: IpAddress
**Example**::
AzureContainerInstancesOperator(
ci_conn_id = "azure_service_principal",
registry_conn_id = "azure_registry_user",
resource_group = "my-resource-group",
name = "my-container-name-{{ ds }}",
image = "myprivateregistry.azurecr.io/my_container:latest",
region = "westeurope",
environment_variables = {"MODEL_PATH": "my_value",
"POSTGRES_LOGIN": "{{ macros.connection('postgres_default').login }}",
"POSTGRES_PASSWORD": "{{ macros.connection('postgres_default').password }}",
"JOB_GUID": "{{ ti.xcom_pull(task_ids='task1', key='guid') }}" },
secured_variables = ['POSTGRES_PASSWORD'],
volumes = [("azure_wasb_conn_id",
"my_storage_container",
"my_fileshare",
"/input-data",
True),],
memory_in_gb=14.0,
cpu=4.0,
gpu=GpuResource(count=1, sku='K80'),
command=["/bin/echo", "world"],
task_id="start_container"
)
"""
    template_fields = ("name", "image", "command", "environment_variables")

    # pylint: disable=too-many-arguments
    @apply_defaults
    def __init__(
        self,
        *,
        ci_conn_id: str,
        registry_conn_id: Optional[str],
        resource_group: str,
        name: str,
        image: str,
        region: str,
        environment_variables: Optional[dict] = None,
        secured_variables: Optional[Sequence[str]] = None,
        volumes: Optional[list] = None,
        memory_in_gb: Optional[Any] = None,
        cpu: Optional[Any] = None,
        gpu: Optional[Any] = None,
        command: Optional[List[str]] = None,
        remove_on_error: bool = True,
        fail_if_exists: bool = True,
        tags: Optional[Dict[str, str]] = None,
        os_type: str = "Linux",
        restart_policy: str = "Never",
        ip_address: Optional[IpAddress] = None,
        ports: Optional[List[ContainerPort]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.ci_conn_id = ci_conn_id
        self.resource_group = resource_group
        self.name = self._check_name(name)
        self.image = image
        self.region = region
        self.registry_conn_id = registry_conn_id
        self.environment_variables = environment_variables or DEFAULT_ENVIRONMENT_VARIABLES
        self.secured_variables = secured_variables or DEFAULT_SECURED_VARIABLES
        self.volumes = volumes or DEFAULT_VOLUMES
        self.memory_in_gb = memory_in_gb or DEFAULT_MEMORY_IN_GB
        self.cpu = cpu or DEFAULT_CPU
        self.gpu = gpu
        self.command = command
        self.remove_on_error = remove_on_error
        self.fail_if_exists = fail_if_exists
        self._ci_hook: Any = None
        self.tags = tags

        self.os_type = os_type
        if self.os_type not in ["Linux", "Windows"]:
            raise AirflowException(
                "Invalid value for the os_type argument. "
                "Please set 'Linux' or 'Windows' as the os_type. "
                f"Found `{self.os_type}`."
            )

        self.restart_policy = restart_policy
        if self.restart_policy not in ["Always", "OnFailure", "Never"]:
            raise AirflowException(
                "Invalid value for the restart_policy argument. "
                "Please set one of 'Always', 'OnFailure', 'Never' as the restart_policy. "
                f"Found `{self.restart_policy}`"
            )

        self.ip_address = ip_address
        self.ports = ports
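
    # execute() provisions the container group, streams its logs until the
    # container terminates, and returns the container's exit code; a non-zero
    # exit code raises AirflowException and fails the task.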
    def execute(self, context: dict) -> int:
        # Check name again in case it was templated.
        self._check_name(self.name)

        self._ci_hook = AzureContainerInstanceHook(self.ci_conn_id)

        if self.fail_if_exists:
            self.log.info("Testing if container group already exists")
            if self._ci_hook.exists(self.resource_group, self.name):
                raise AirflowException("Container group exists")

        if self.registry_conn_id:
            registry_hook = AzureContainerRegistryHook(self.registry_conn_id)
            image_registry_credentials: Optional[list] = [registry_hook.connection]
        else:
            image_registry_credentials = None

        environment_variables = []
        for key, value in self.environment_variables.items():
            if key in self.secured_variables:
                e = EnvironmentVariable(name=key, secure_value=value)
            else:
                e = EnvironmentVariable(name=key, value=value)
            environment_variables.append(e)

        volumes: List[AzureVolume] = []
        volume_mounts: List[VolumeMount] = []
        for conn_id, account_name, share_name, mount_path, read_only in self.volumes:
            hook = AzureContainerVolumeHook(conn_id)
            mount_name = "mount-%d" % len(volumes)
            volumes.append(
                hook.get_file_volume(mount_name, share_name, account_name, read_only)
            )
            volume_mounts.append(
                VolumeMount(name=mount_name, mount_path=mount_path, read_only=read_only)
            )
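
        # exit_code defaults to 1 so that any exception path is treated as a
        # failure by the cleanup logic in the finally block below.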
        exit_code = 1
        try:
            self.log.info(
                "Starting container group with %.1f cpu %.1f mem",
                self.cpu,
                self.memory_in_gb,
            )
            if self.gpu:
                self.log.info("GPU count: %.1f, GPU SKU: %s", self.gpu.count, self.gpu.sku)

            resources = ResourceRequirements(
                requests=ResourceRequests(
                    memory_in_gb=self.memory_in_gb, cpu=self.cpu, gpu=self.gpu
                )
            )

            if self.ip_address and not self.ports:
                self.ports = [ContainerPort(port=80)]
                self.log.info("Default port set. Container will listen on port 80")

            container = Container(
                name=self.name,
                image=self.image,
                resources=resources,
                command=self.command,
                environment_variables=environment_variables,
                volume_mounts=volume_mounts,
                ports=self.ports,
            )

            container_group = ContainerGroup(
                location=self.region,
                containers=[container],
                image_registry_credentials=image_registry_credentials,
                volumes=volumes,
                restart_policy=self.restart_policy,
                os_type=self.os_type,
                tags=self.tags,
                ip_address=self.ip_address,
            )

            self._ci_hook.create_or_update(self.resource_group, self.name, container_group)

            self.log.info("Container group started %s/%s", self.resource_group, self.name)

            exit_code = self._monitor_logging(self.resource_group, self.name)

            self.log.info("Container had exit code: %s", exit_code)
            if exit_code != 0:
                raise AirflowException(f"Container had a non-zero exit code, {exit_code}")
            return exit_code
        except CloudError:
            self.log.exception("Could not start container group")
            raise AirflowException("Could not start container group")
        finally:
            if exit_code == 0 or self.remove_on_error:
                self.on_kill()
    def on_kill(self) -> None:
        if self.remove_on_error:
            self.log.info("Deleting container group")
            try:
                self._ci_hook.delete(self.resource_group, self.name)
            except Exception:  # pylint: disable=broad-except
                self.log.exception("Could not delete container group")
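
    # Polls the container group once per second, logging state transitions,
    # instance events, and any new log lines, until the container terminates
    # or provisioning fails; returns the container's exit code.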
    def _monitor_logging(self, resource_group: str, name: str) -> int:
        last_state = None
        last_message_logged = None
        last_line_logged = None

        # pylint: disable=too-many-nested-blocks
        while True:
            try:
                cg_state = self._ci_hook.get_state(resource_group, name)
                instance_view = cg_state.containers[0].instance_view
                if instance_view is not None:
                    c_state = instance_view.current_state
                    state, exit_code, detail_status = (
                        c_state.state,
                        c_state.exit_code,
                        c_state.detail_status,
                    )
                    messages = [event.message for event in instance_view.events]
                    last_message_logged = self._log_last(messages, last_message_logged)
                else:
                    # No instance view yet: fall back to the provisioning state.
                    state = cg_state.provisioning_state
                    exit_code = 0
                    detail_status = "Provisioning"

                if state != last_state:
                    self.log.info("Container group state changed to %s", state)
                    last_state = state

                if state in ["Running", "Terminated"]:
                    try:
                        logs = self._ci_hook.get_logs(resource_group, name)
                        last_line_logged = self._log_last(logs, last_line_logged)
                    except CloudError:
                        self.log.exception(
                            "Exception while getting logs from container instance, retrying..."
                        )

                if state == "Terminated":
                    self.log.error("Container exited with detail_status %s", detail_status)
                    return exit_code

                if state == "Failed":
                    self.log.error("Azure provision failure")
                    return 1

            except AirflowTaskTimeout:
                raise
            except CloudError as err:
                if "ResourceNotFound" in str(err):
                    self.log.warning(
                        "ResourceNotFound, container is probably removed "
                        "by another process "
                        "(make sure that the name is unique)."
                    )
                    return 1
                else:
                    self.log.exception("Exception while getting container groups")
            except Exception:  # pylint: disable=broad-except
                self.log.exception("Exception while getting container groups")

            sleep(1)
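
    # Emits only the lines that were not already logged on a previous poll by
    # scanning backwards for the last previously-logged entry.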
    def _log_last(self, logs: Optional[list], last_line_logged: Any) -> Optional[Any]:
        if logs:
            # determine the last line which was logged before
            last_line_index = 0
            for i in range(len(logs) - 1, -1, -1):
                if logs[i] == last_line_logged:
                    # this line is the same, hence print from i+1
                    last_line_index = i + 1
                    break

            # log all new ones
            for line in logs[last_line_index:]:
                self.log.info(line.rstrip())
            return logs[-1]
        return None
    @staticmethod
    def _check_name(name: str) -> str:
        if "{{" in name:
            # Let macros pass as they cannot be checked at construction time
            return name
        regex_check = re.match("[a-z0-9]([-a-z0-9]*[a-z0-9])?", name)
        if regex_check is None or regex_check.group() != name:
            raise AirflowException(
                'ACI name must match regex [a-z0-9]([-a-z0-9]*[a-z0-9])? (like "my-name")'
            )
        if len(name) > 63:
            raise AirflowException("ACI name cannot be longer than 63 characters")
        return name
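
# Illustrative only: a minimal sketch of wiring this operator into a DAG. The
# connection id, resource group, image, and region below are made-up values,
# not part of this module:
#
#     from datetime import datetime
#     from airflow import DAG
#
#     with DAG("aci_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
#         start_container = AzureContainerInstancesOperator(
#             task_id="start_container",
#             ci_conn_id="azure_default",
#             registry_conn_id=None,
#             resource_group="my-resource-group",
#             name="aci-example-{{ ds }}",
#             image="ubuntu:latest",
#             region="westeurope",
#             command=["/bin/echo", "hello"],
#         )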