You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
277 lines
9.3 KiB
277 lines
9.3 KiB
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
|
|
|
""" |
|
Example Airflow DAG for Google ML Engine service. |
|
""" |
|
import os |
|
from typing import Dict |
|
|
|
from airflow import models |
|
from airflow.operators.bash import BashOperator |
|
from airflow.providers.google.cloud.operators.mlengine import ( |
|
MLEngineCreateModelOperator, |
|
MLEngineCreateVersionOperator, |
|
MLEngineDeleteModelOperator, |
|
MLEngineDeleteVersionOperator, |
|
MLEngineGetModelOperator, |
|
MLEngineListVersionsOperator, |
|
MLEngineSetDefaultVersionOperator, |
|
MLEngineStartBatchPredictionJobOperator, |
|
MLEngineStartTrainingJobOperator, |
|
) |
|
from airflow.providers.google.cloud.utils import mlengine_operator_utils |
|
from airflow.utils.dates import days_ago |
|
|
|
# Each setting below is read from the named environment variable; the second
# argument is the fallback value used when that variable is not set.
PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")

MODEL_NAME = os.environ.get("GCP_MLENGINE_MODEL_NAME", "model_name")

SAVED_MODEL_PATH = os.environ.get(
    "GCP_MLENGINE_SAVED_MODEL_PATH", "gs://test-airflow-mlengine/saved-model/"
)
JOB_DIR = os.environ.get(
    "GCP_MLENGINE_JOB_DIR", "gs://test-airflow-mlengine/keras-job-dir"
)
PREDICTION_INPUT = os.environ.get(
    "GCP_MLENGINE_PREDICTION_INPUT", "gs://test-airflow-mlengine/prediction_input.json"
)
PREDICTION_OUTPUT = os.environ.get(
    "GCP_MLENGINE_PREDICTION_OUTPUT", "gs://test-airflow-mlengine/prediction_output"
)
TRAINER_URI = os.environ.get(
    "GCP_MLENGINE_TRAINER_URI", "gs://test-airflow-mlengine/trainer.tar.gz"
)
TRAINER_PY_MODULE = os.environ.get(
    "GCP_MLENGINE_TRAINER_TRAINER_PY_MODULE", "trainer.task"
)

SUMMARY_TMP = os.environ.get(
    "GCP_MLENGINE_DATAFLOW_TMP", "gs://test-airflow-mlengine/tmp/"
)
SUMMARY_STAGING = os.environ.get(
    "GCP_MLENGINE_DATAFLOW_STAGING", "gs://test-airflow-mlengine/staging/"
)

# Holds the ``params.model_name`` value that the templated ``job_id`` strings
# in the DAG below refer to.
default_args = {"params": {"model_name": MODEL_NAME}}
|
|
|
with models.DAG(
    "example_gcp_mlengine",
    # default_args carries ``params.model_name``, which the templated
    # ``job_id`` values of the training, prediction and evaluate tasks use.
    default_args=default_args,
    schedule_interval=None,  # Override to match your needs
    start_date=days_ago(1),
    tags=["example"],
) as dag:
    # [START howto_operator_gcp_mlengine_training]
    training = MLEngineStartTrainingJobOperator(
        task_id="training",
        project_id=PROJECT_ID,
        region="us-central1",
        job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}",
        runtime_version="1.15",
        python_version="3.7",
        job_dir=JOB_DIR,
        package_uris=[TRAINER_URI],
        training_python_module=TRAINER_PY_MODULE,
        training_args=[],
        labels={"job_type": "training"},
    )
    # [END howto_operator_gcp_mlengine_training]

    # [START howto_operator_gcp_mlengine_create_model]
    create_model = MLEngineCreateModelOperator(
        task_id="create-model",
        project_id=PROJECT_ID,
        model={
            "name": MODEL_NAME,
        },
    )
    # [END howto_operator_gcp_mlengine_create_model]

    # [START howto_operator_gcp_mlengine_get_model]
    get_model = MLEngineGetModelOperator(
        task_id="get-model",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_get_model]

    # [START howto_operator_gcp_mlengine_print_model]
    # Echoes whatever the "get-model" task pushed to XCom.
    get_model_result = BashOperator(
        bash_command="echo \"{{ task_instance.xcom_pull('get-model') }}\"",
        task_id="get-model-result",
    )
    # [END howto_operator_gcp_mlengine_print_model]

    # [START howto_operator_gcp_mlengine_create_version1]
    # Version "v1" is deployed from the export directory under the training JOB_DIR.
    create_version = MLEngineCreateVersionOperator(
        task_id="create-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v1",
            "description": "First-version",
            "deployment_uri": f"{JOB_DIR}/keras_export/",
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version1]

    # [START howto_operator_gcp_mlengine_create_version2]
    # Version "v2" is deployed from the pre-existing SAVED_MODEL_PATH.
    create_version_2 = MLEngineCreateVersionOperator(
        task_id="create-version-2",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v2",
            "description": "Second version",
            "deployment_uri": SAVED_MODEL_PATH,
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version2]

    # [START howto_operator_gcp_mlengine_default_version]
    set_defaults_version = MLEngineSetDefaultVersionOperator(
        task_id="set-default-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version_name="v2",
    )
    # [END howto_operator_gcp_mlengine_default_version]

    # [START howto_operator_gcp_mlengine_list_versions]
    list_version = MLEngineListVersionsOperator(
        task_id="list-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_list_versions]

    # [START howto_operator_gcp_mlengine_print_versions]
    # Echoes whatever the "list-version" task pushed to XCom.
    list_version_result = BashOperator(
        bash_command="echo \"{{ task_instance.xcom_pull('list-version') }}\"",
        task_id="list-version-result",
    )
    # [END howto_operator_gcp_mlengine_print_versions]

    # [START howto_operator_gcp_mlengine_get_prediction]
    prediction = MLEngineStartBatchPredictionJobOperator(
        task_id="prediction",
        project_id=PROJECT_ID,
        job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}",
        region="us-central1",
        model_name=MODEL_NAME,
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        output_path=PREDICTION_OUTPUT,
        labels={"job_type": "prediction"},
    )
    # [END howto_operator_gcp_mlengine_get_prediction]

    # [START howto_operator_gcp_mlengine_delete_version]
    delete_version = MLEngineDeleteVersionOperator(
        task_id="delete-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version_name="v1",
    )
    # [END howto_operator_gcp_mlengine_delete_version]

    # [START howto_operator_gcp_mlengine_delete_model]
    delete_model = MLEngineDeleteModelOperator(
        task_id="delete-model",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        delete_contents=True,
    )
    # [END howto_operator_gcp_mlengine_delete_model]

    # Task ordering: both versions wait for training and for the model to
    # exist; deletion tasks run only after everything that uses them.
    training >> create_version
    training >> create_version_2
    create_model >> get_model >> [get_model_result, delete_model]
    create_model >> create_version >> create_version_2 >> set_defaults_version >> list_version
    create_version >> prediction
    create_version_2 >> prediction
    prediction >> delete_version
    list_version >> list_version_result
    list_version >> delete_version
    delete_version >> delete_model

    # [START howto_operator_gcp_mlengine_get_metric]
    def get_metric_fn_and_keys():
        """
        Return the metric function and metric keys used to generate the summary.

        :return: a ``(metric_fn, metric_keys)`` pair. ``metric_fn`` takes one
            instance dict and returns a 1-tuple; ``metric_keys`` names the
            tuple elements in the same order.
        """

        def normalize_value(inst: Dict):
            # Only the first element stored under "dense_4" is used.
            val = float(inst["dense_4"][0])
            return (val,)  # returns a tuple.

        return normalize_value, ["val"]  # key order must match.

    # [END howto_operator_gcp_mlengine_get_metric]

    # [START howto_operator_gcp_mlengine_validate_error]
    def validate_err_and_count(summary: Dict) -> Dict:
        """
        Validate the summary result.

        :param summary: mapping that must contain the keys ``"val"`` and ``"count"``.
        :return: the same ``summary`` object, unchanged, when all checks pass.
        :raises ValueError: if ``summary["val"]`` is greater than 1 or less
            than 0, or if ``summary["count"]`` is not 20.
        """
        if summary["val"] > 1:
            raise ValueError(f"Too high val>1; summary={summary}")
        if summary["val"] < 0:
            raise ValueError(f"Too low val<0; summary={summary}")
        if summary["count"] != 20:
            # The message names "count" because that is the field being checked.
            raise ValueError(f"Invalid value count != 20; summary={summary}")
        return summary

    # [END howto_operator_gcp_mlengine_validate_error]

    # [START howto_operator_gcp_mlengine_evaluate]
    (
        evaluate_prediction,
        evaluate_summary,
        evaluate_validation,
    ) = mlengine_operator_utils.create_evaluate_ops(
        task_prefix="evaluate-ops",
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        prediction_path=PREDICTION_OUTPUT,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
        project_id=PROJECT_ID,
        region="us-central1",
        dataflow_options={
            "project": PROJECT_ID,
            "tempLocation": SUMMARY_TMP,
            "stagingLocation": SUMMARY_STAGING,
        },
        model_name=MODEL_NAME,
        version_name="v1",
        py_interpreter="python3",
    )
    # [END howto_operator_gcp_mlengine_evaluate]

    # Evaluation targets version "v1", so it runs after that version is
    # created and must finish before the version is deleted.
    create_model >> create_version >> evaluate_prediction
    evaluate_validation >> delete_version
|
|
|