You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
216 lines
7.7 KiB
216 lines
7.7 KiB
# flake8: noqa: F841 |
|
# |
|
# Licensed to the Apache Software Foundation (ASF) under one |
|
# or more contributor license agreements. See the NOTICE file |
|
# distributed with this work for additional information |
|
# regarding copyright ownership. The ASF licenses this file |
|
# to you under the Apache License, Version 2.0 (the |
|
# "License"); you may not use this file except in compliance |
|
# with the License. You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, |
|
# software distributed under the License is distributed on an |
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
|
# KIND, either express or implied. See the License for the |
|
# specific language governing permissions and limitations |
|
# under the License. |
|
""" |
|
A template called by DataFlowPythonOperator to summarize BatchPrediction. |
|
|
|
It accepts a user function to calculate the metric(s) per instance in |
|
the prediction results, then aggregates to output as a summary. |
|
|
|
It accepts the following arguments: |
|
|
|
- ``--prediction_path``: |
|
The GCS folder that contains BatchPrediction results, containing |
|
``prediction.results-NNNNN-of-NNNNN`` files in the json format. |
|
Output will be also stored in this folder, as 'prediction.summary.json'. |
|
- ``--metric_fn_encoded``: |
|
An encoded function that calculates and returns a tuple of metric(s) |
|
for a given instance (as a dictionary). It should be encoded |
|
via ``base64.b64encode(dill.dumps(fn, recurse=True))``. |
|
- ``--metric_keys``: |
|
A comma-separated key(s) of the aggregated metric(s) in the summary |
|
output. The order and the size of the keys must match to the output |
|
of metric_fn. |
|
The summary will have an additional key, 'count', to represent the |
|
total number of instances, so the keys shouldn't include 'count'. |
|
|
|
|
|
Usage example: |
|
|
|
.. code-block:: python
|
|
|
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator |
|
|
|
|
|
def get_metric_fn(): |
|
import math # all imports must be outside of the function to be passed. |
|
def metric_fn(inst): |
|
label = float(inst["input_label"]) |
|
classes = float(inst["classes"]) |
|
prediction = float(inst["scores"][1]) |
|
log_loss = math.log(1 + math.exp( |
|
-(label * 2 - 1) * math.log(prediction / (1 - prediction)))) |
|
squared_err = (classes-label)**2 |
|
return (log_loss, squared_err) |
|
return metric_fn |
|
metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True)) |
|
DataflowCreatePythonJobOperator( |
|
task_id="summary-prediction", |
|
py_options=["-m"], |
|
py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary", |
|
options={ |
|
"prediction_path": prediction_path, |
|
"metric_fn_encoded": metric_fn_encoded, |
|
"metric_keys": "log_loss,mse" |
|
}, |
|
dataflow_default_options={ |
|
"project": "xxx", "region": "us-east1", |
|
"staging_location": "gs://yy", "temp_location": "gs://zz", |
|
} |
|
) >> dag |
|
|
|
When the input file is like the following:: |
|
|
|
{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]} |
|
{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]} |
|
{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]} |
|
{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]} |
|
|
|
The output file will be:: |
|
|
|
{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25} |
|
|
|
To test outside of the dag: |
|
|
|
.. code-block:: python |
|
|
|
subprocess.check_call(["python", |
|
"-m", |
|
"airflow.providers.google.cloud.utils.mlengine_prediction_summary", |
|
"--prediction_path=gs://...", |
|
"--metric_fn_encoded=" + metric_fn_encoded, |
|
"--metric_keys=log_loss,mse", |
|
"--runner=DataflowRunner", |
|
"--staging_location=gs://...", |
|
"--temp_location=gs://...", |
|
]) |
|
""" |
|
|
|
import argparse |
|
import base64 |
|
import json |
|
import logging |
|
import os |
|
|
|
import apache_beam as beam |
|
import dill # pylint: disable=wrong-import-order |
|
|
|
|
|
class JsonCoder:
    """Coder that serializes records to and parses them from JSON text."""

    @staticmethod
    def encode(x):
        """Serialize ``x`` to UTF-8 encoded JSON bytes."""
        serialized = json.dumps(x)
        return serialized.encode()

    @staticmethod
    def decode(x):
        """Parse a JSON document into the corresponding Python object."""
        return json.loads(x)
|
|
|
|
|
@beam.ptransform_fn
def MakeSummary(pcoll, metric_fn, metric_keys):  # pylint: disable=invalid-name
    """Summary PTransform used in Dataflow.

    Applies ``metric_fn`` to every instance, sums each metric column
    together with an instance count, and emits one dictionary mapping
    each key in ``metric_keys`` to its mean plus a ``count`` entry.
    """
    # One summed column per metric, plus one for the instance count.
    num_columns = len(metric_keys) + 1

    metrics = pcoll | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn)
    counted = metrics | "PairWith1" >> beam.Map(lambda values: values + (1,))
    totals = counted | "SumTuple" >> beam.CombineGlobally(
        beam.combiners.TupleCombineFn(*([sum] * num_columns))
    )

    def _averaged_dict(sums):
        # Last element is the instance count; divide each metric sum by it.
        count = sums[-1]
        summary = {name: sums[i] / count for i, name in enumerate(metric_keys)}
        summary["count"] = count
        return summary

    return totals | "AverageAndMakeDict" >> beam.Map(_averaged_dict)
|
|
|
|
|
def run(argv=None):
    """Obtain a prediction summary by running a Beam pipeline.

    Parses command line flags, decodes the user-supplied metric function,
    then reads BatchPrediction result files, aggregates the per-instance
    metrics via :func:`MakeSummary`, and writes a single summary JSON file
    (``prediction.summary.json``) into the same GCS folder.

    :param argv: command line arguments to parse; when ``None`` the
        arguments are taken from ``sys.argv``.
    :raises ValueError: if ``--metric_fn_encoded`` does not decode to a
        callable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prediction_path",
        required=True,
        help=(
            "The GCS folder that contains BatchPrediction results, containing "
            "prediction.results-NNNNN-of-NNNNN files in the json format. "
            # Fixed: a separating space was missing here, rendering the help
            # text as "a file'prediction.summary.json'".
            "Output will be also stored in this folder, as a file "
            "'prediction.summary.json'."
        ),
    )
    parser.add_argument(
        "--metric_fn_encoded",
        required=True,
        help=(
            "An encoded function that calculates and returns a tuple of "
            "metric(s) for a given instance (as a dictionary). It should be "
            "encoded via base64.b64encode(dill.dumps(fn, recurse=True))."
        ),
    )
    parser.add_argument(
        "--metric_keys",
        required=True,
        help=(
            "A comma-separated keys of the aggregated metric(s) in the summary "
            "output. The order and the size of the keys must match to the "
            "output of metric_fn. The summary will have an additional key, "
            "'count', to represent the total number of instances, so this flag "
            "shouldn't include 'count'."
        ),
    )
    known_args, pipeline_args = parser.parse_known_args(argv)

    # NOTE(security): dill deserialization executes arbitrary code. The
    # encoded metric function must only ever come from a trusted caller
    # (here, the operator that launched this template).
    metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded))
    if not callable(metric_fn):
        raise ValueError("--metric_fn_encoded must be an encoded callable.")
    metric_keys = known_args.metric_keys.split(",")

    with beam.Pipeline(options=beam.pipeline.PipelineOptions(pipeline_args)) as pipe:
        # pylint: disable=no-value-for-parameter
        prediction_result_pattern = os.path.join(
            known_args.prediction_path, "prediction.results-*-of-*"
        )
        prediction_summary_path = os.path.join(
            known_args.prediction_path, "prediction.summary.json"
        )
        # Assigning the pipeline result to "_" is apache-beam ptransform's
        # convention.
        _ = (
            pipe
            | "ReadPredictionResult"
            >> beam.io.ReadFromText(prediction_result_pattern, coder=JsonCoder())
            | "Summary" >> MakeSummary(metric_fn, metric_keys)
            | "Write"
            >> beam.io.WriteToText(
                prediction_summary_path,
                shard_name_template="",  # without trailing -NNNNN-of-NNNNN.
                coder=JsonCoder(),
            )
        )
|
|
|
|
|
# Entry point when executed as a module, e.g. via
# `python -m airflow.providers.google.cloud.utils.mlengine_prediction_summary`.
if __name__ == "__main__":
    # Dataflow does not print anything on the screen by default. Good practice says to configure the logger
    # to be able to track the progress. This code is run in a separate process, so it's safe.
    logging.getLogger().setLevel(logging.INFO)
    run()
|
|
|