# flake8: noqa: F841
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
A template called by DataFlowPythonOperator to summarize BatchPrediction.

It accepts a user function to calculate the metric(s) per instance in
the prediction results, then aggregates to output as a summary.

It accepts the following arguments:

- ``--prediction_path``:
  The GCS folder that contains BatchPrediction results, containing
  ``prediction.results-NNNNN-of-NNNNN`` files in JSON format.
  Output will also be stored in this folder, as ``prediction.summary.json``.
- ``--metric_fn_encoded``:
  An encoded function that calculates and returns a tuple of metric(s)
  for a given instance (as a dictionary). It should be encoded via
  ``base64.b64encode(dill.dumps(fn, recurse=True))``.
- ``--metric_keys``:
  Comma-separated keys of the aggregated metric(s) in the summary output.
  The order and number of the keys must match the output of metric_fn.
  The summary will have an additional key, ``count``, to represent the
  total number of instances, so the keys shouldn't include ``count``.

Usage example:

.. code-block:: python

    import base64

    import dill
    from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator

    def get_metric_fn():
        import math  # all imports must be outside of the function to be passed (metric_fn).

        def metric_fn(inst):
            label = float(inst["input_label"])
            classes = float(inst["classes"])
            prediction = float(inst["scores"][1])
            log_loss = math.log(1 + math.exp(-(label * 2 - 1) * math.log(prediction / (1 - prediction))))
            squared_err = (classes - label) ** 2
            return (log_loss, squared_err)

        return metric_fn

    metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True)).decode()

    DataflowCreatePythonJobOperator(
        task_id="summary-prediction",
        py_options=["-m"],
        py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary",
        options={
            "prediction_path": prediction_path,
            "metric_fn_encoded": metric_fn_encoded,
            "metric_keys": "log_loss,mse",
        },
        dataflow_default_options={
            "project": "xxx",
            "region": "us-east1",
            "staging_location": "gs://yy",
            "temp_location": "gs://zz",
        },
    )

When the input file is like the following::

    {"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
    {"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
    {"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
    {"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}

The output file will be::

    {"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
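
As a quick sanity check of that summary (illustrative arithmetic, reading the
leading field of each ``inputs`` string as the ground-truth label): only the
third record's ``classes`` disagrees with its label, so the per-instance
squared errors are 0, 0, 1, 0, and ``count`` is simply the number of records:

.. code-block:: python

    squared_errors = [(1 - 1) ** 2, (0 - 0) ** 2, (0 - 1) ** 2, (1 - 1) ** 2]
    mse = sum(squared_errors) / len(squared_errors)  # 0.25, as in the summary
    count = len(squared_errors)  # 4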

To test outside of the DAG:

.. code-block:: python

    subprocess.check_call(
        [
            "python",
            "-m",
            "airflow.providers.google.cloud.utils.mlengine_prediction_summary",
            "--prediction_path=gs://...",
            "--metric_fn_encoded=" + metric_fn_encoded,
            "--metric_keys=log_loss,mse",
            "--runner=DataflowRunner",
            "--staging_location=gs://...",
            "--temp_location=gs://...",
        ]
    )
"""

import argparse
import base64
import json
import logging
import os

import apache_beam as beam
import dill  # pylint: disable=wrong-import-order


class JsonCoder:
    """JSON encoder/decoder."""

    @staticmethod
    def encode(x):
        """JSON encoder."""
        return json.dumps(x).encode()

    @staticmethod
    def decode(x):
        """JSON decoder."""
        return json.loads(x)


@beam.ptransform_fn
def MakeSummary(pcoll, metric_fn, metric_keys):  # pylint: disable=invalid-name
    """Summary PTransform used in Dataflow.

    Appends a 1 to each per-instance metric tuple so that the global sum of
    the last element yields the instance count, then divides each metric sum
    by that count to produce the averages.
    """
    return (
        pcoll
        | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn)
        | "PairWith1" >> beam.Map(lambda tup: tup + (1,))
        | "SumTuple"
        >> beam.CombineGlobally(beam.combiners.TupleCombineFn(*([sum] * (len(metric_keys) + 1))))
        | "AverageAndMakeDict"
        >> beam.Map(
            lambda tup: dict(
                [(name, tup[i] / tup[-1]) for i, name in enumerate(metric_keys)]
                + [("count", tup[-1])]
            )
        )
    )


def run(argv=None):
    """Helper for obtaining the prediction summary."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prediction_path",
        required=True,
        help=(
            "The GCS folder that contains BatchPrediction results, containing "
            "prediction.results-NNNNN-of-NNNNN files in JSON format. "
            "Output will also be stored in this folder, as a file "
            "'prediction.summary.json'."
        ),
    )
    parser.add_argument(
        "--metric_fn_encoded",
        required=True,
        help=(
            "An encoded function that calculates and returns a tuple of "
            "metric(s) for a given instance (as a dictionary). It should be "
            "encoded via base64.b64encode(dill.dumps(fn, recurse=True))."
        ),
    )
    parser.add_argument(
        "--metric_keys",
        required=True,
        help=(
            "Comma-separated keys of the aggregated metric(s) in the summary "
            "output. The order and number of the keys must match the output "
            "of metric_fn. The summary will have an additional key, 'count', "
            "to represent the total number of instances, so this flag "
            "shouldn't include 'count'."
        ),
    )
    known_args, pipeline_args = parser.parse_known_args(argv)

    metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded))
    if not callable(metric_fn):
        raise ValueError("--metric_fn_encoded must be an encoded callable.")
    metric_keys = known_args.metric_keys.split(",")

    with beam.Pipeline(options=beam.pipeline.PipelineOptions(pipeline_args)) as pipe:
        # pylint: disable=no-value-for-parameter
        prediction_result_pattern = os.path.join(
            known_args.prediction_path, "prediction.results-*-of-*"
        )
        prediction_summary_path = os.path.join(
            known_args.prediction_path, "prediction.summary.json"
        )
        # Assigning the sink to `_` is the apache-beam ptransform convention.
        _ = (
            pipe
            | "ReadPredictionResult"
            >> beam.io.ReadFromText(prediction_result_pattern, coder=JsonCoder())
            | "Summary" >> MakeSummary(metric_fn, metric_keys)
            | "Write"
            >> beam.io.WriteToText(
                prediction_summary_path,
                shard_name_template="",  # without trailing -NNNNN-of-NNNNN.
                coder=JsonCoder(),
            )
        )


if __name__ == "__main__":
    # Dataflow does not print anything on the screen by default. Good practice says to configure the logger
    # to be able to track the progress. This code is run in a separate process, so it's safe.
    logging.getLogger().setLevel(logging.INFO)
    run()
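
# A minimal sketch of exercising ``MakeSummary`` locally with Beam's
# DirectRunner, kept as a comment so that importing this module has no side
# effects. The metric function, field names, and expected values below are
# illustrative assumptions, not part of this module's API:
#
#     import apache_beam as beam
#     from apache_beam.testing.util import assert_that, equal_to
#
#     def metric_fn(inst):
#         # average of the positive-class score, one metric per instance
#         return (float(inst["scores"][1]),)
#
#     with beam.Pipeline() as p:  # DirectRunner by default
#         summary = (
#             p
#             | beam.Create([{"scores": [0.4, 0.6]}, {"scores": [0.8, 0.2]}])
#             | MakeSummary(metric_fn, ["avg_score"])
#         )
#         assert_that(summary, equal_to([{"avg_score": 0.4, "count": 2}]))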