I have created a custom Keras metric, similar to the demo implementation below:
import tensorflow as tf
class MyMetric(tf.keras.metrics.Mean):
    def __init__(self, name='my_metric', dtype=None):
        super(MyMetric, self).__init__(name=name, dtype=dtype)
    def update_state(self, y_true, y_pred, sample_weight=None):
        return super(MyMetric, self).update_state(
            y_pred, sample_weight=sample_weight)
I have turned the implementation into a Python module with the init/main files and added the path to the system's PYTHONPATH.
I can use the metric when I train the Keras model.
Unfortunately, I haven't found a way to make the custom metric available to TensorFlow Model Analysis (TFMA).
In my interactive context notebook, I can load the metric when I create the eval_config.
import tensorflow as tf
import tensorflow_model_analysis as tfma
from tfx.components import Evaluator
from mymetric.metric import MyMetric
metrics = [MyMetric()]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='label_xf')],
    metrics_specs=metrics_specs,
    slicing_specs=[tfma.SlicingSpec()]
)
evaluator = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config)
When I try to execute the evaluator, the metric is listed in the metric specifications:
metrics_specs {
  metrics {
    class_name: "MyMetric"
    config: "{\"dtype\": \"float32\", \"name\": \"my_metric\"}"
    threshold {
    }
  }
}
but the execution fails with the error:
ValueError: Unknown metric function: MyMetric
Since the metric calculation is executed via Apache Beam's executor.Do function, I assume that Beam can't find the module (even though it is on the PYTHONPATH). If that is the case, how can I make the module available to Apache Beam beyond the PYTHONPATH configuration?
Traceback:
/usr/local/lib/python3.6/dist-packages/tensorflow_model_analysis/metrics/metric_specs.py in _deserialize_tf_metric(metric_config, custom_objects)
    741   cls_name, cfg = _tf_class_and_config(metric_config)
    742   with tf.keras.utils.custom_object_scope(custom_objects):
--> 743     return tf.keras.metrics.deserialize({'class_name': cls_name, 'config': cfg})
    744 
    745 
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/metrics.py in deserialize(config, custom_objects)
   3441       module_objects=globals(),
   3442       custom_objects=custom_objects,
-> 3443       printable_module_name='metric function')
   3444 
   3445 
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    345     config = identifier
    346     (cls, cls_config) = class_and_config_for_serialized_keras_object(
--> 347         config, module_objects, custom_objects, printable_module_name)
    348 
    349     if hasattr(cls, 'from_config'):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
    294   cls = get_registered_object(class_name, custom_objects, module_objects)
    295   if cls is None:
--> 296     raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
    297 
    298   cls_config = config['config']
ValueError: Unknown metric function: MyMetric
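The last frames of the traceback point at a lookup by class name: the serialized spec only carries the string "MyMetric", and the error is raised when that name is found neither among the built-in metrics nor in custom_objects. The following is a simplified, standard-library-only sketch of that kind of lookup. It is not the Keras source; BUILT_IN_METRICS, deserialize_metric and the two placeholder classes are hypothetical names used only for illustration.

```python
# Simplified stand-in for a name-based metric lookup.
# NOT the Keras implementation: every name below is hypothetical.


class Mean:
    """Placeholder for a built-in metric class."""

    def __init__(self, name="mean", dtype=None):
        self.name = name
        self.dtype = dtype


class MyMetric(Mean):
    """Placeholder for the custom metric from the question."""

    def __init__(self, name="my_metric", dtype=None):
        super().__init__(name=name, dtype=dtype)


# Names the "framework" knows about without any help from the user.
BUILT_IN_METRICS = {"Mean": Mean}


def deserialize_metric(config, custom_objects=None):
    """Rebuild a metric from {'class_name': ..., 'config': {...}}."""
    class_name = config["class_name"]
    candidates = dict(BUILT_IN_METRICS)
    candidates.update(custom_objects or {})
    cls = candidates.get(class_name)
    if cls is None:
        # Same shape of message as in the traceback above.
        raise ValueError("Unknown metric function: " + class_name)
    return cls(**config["config"])


spec = {
    "class_name": "MyMetric",
    "config": {"name": "my_metric", "dtype": "float32"},
}

# Without custom_objects the class name cannot be resolved.
try:
    deserialize_metric(spec)
    lookup_error = None
except ValueError as err:
    lookup_error = str(err)

# Once the class is supplied under its name, the same spec deserializes.
restored = deserialize_metric(spec, custom_objects={"MyMetric": MyMetric})
```

In this sketch the class being importable is not enough on its own; the deserializing code has to be handed the class under its name, which is consistent with the error appearing even though the module is on the PYTHONPATH.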
 
                        
You need to specify the module so that TFX knows where to find your MyMetric class. One way of doing this is to specify it as part of the metric specs:
from tensorflow_model_analysis import config
metric_config = [config.MetricConfig(class_name='MyMetric', module='mymodule.mymetric')]
metrics_specs = [config.MetricsSpec(metrics=metric_config)]
You will also need to create a module called mymodule and put your MyMetric class in mymetric.py for this to work. With the layout from the question (from mymetric.metric import MyMetric), the module value would presumably be 'mymetric.metric' instead. Also make sure that the module is accessible from where you are executing the code (which should be the case if you have added it to your PYTHONPATH).
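To make the role of the module field more concrete, here is a minimal sketch of how a module name plus a class name can be turned into a class with the standard library. It assumes that the class is resolved by importing the module and reading the attribute; I have not checked TFMA's internals against it, and resolve_class is a hypothetical helper, not a TFMA function. It is also a quick way to check that the module really is importable from the interpreter that runs the evaluator.

```python
import importlib
import importlib.util


def resolve_class(module_name, class_name):
    """Import module_name and return its attribute called class_name.

    Hypothetical helper for illustration; raises ImportError with a
    readable message if either the module or the class is missing.
    """
    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        # Raised when a parent package of a dotted name cannot be found.
        spec = None
    if spec is None:
        raise ImportError(
            "module %r is not importable from this interpreter" % module_name
        )
    module = importlib.import_module(module_name)
    try:
        return getattr(module, class_name)
    except AttributeError:
        raise ImportError(
            "module %r has no attribute %r" % (module_name, class_name)
        )


# Demonstrated with a standard-library class so the snippet runs anywhere.
ordered_dict_cls = resolve_class("collections", "OrderedDict")
```

With the package from the question, the equivalent call would be resolve_class('mymetric.metric', 'MyMetric'). If that raises ImportError in the notebook kernel, the PYTHONPATH change has not reached that process. If the pipeline later runs on remote Beam workers, the same check would need to pass there as well.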