How to get the list of features alongside their schema and stats using TFX


Let's say that I have the following, very straightforward pipeline:

import os
from tfx import v1 as tfx


_dataset_folder = './tfrecords/train/*'
_pipeline_data_folder = './pipeline_data'
_serving_model_dir = os.path.join(_pipeline_data_folder, 'serving_model')

example_gen = tfx.components.ImportExampleGen(input_base=_dataset_folder)
statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs['examples'])
schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True)
example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])

_transform_module_file = 'preprocessing_fn.py'
transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath(_transform_module_file),
    custom_config={'statistics_gen': statistics_gen.outputs['statistics'],
                   'schema_gen': schema_gen.outputs['schema']})

_trainer_module_file = 'run_fn.py'
trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_trainer_module_file),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=10),
    eval_args=tfx.proto.EvalArgs(num_steps=6))


pusher = tfx.components.Pusher(
  model=trainer.outputs['model'],
  push_destination=tfx.proto.PushDestination(
    filesystem=tfx.proto.PushDestination.Filesystem(
        base_directory=_serving_model_dir)))

components = [
    example_gen,
    statistics_gen,
    schema_gen,
    example_validator,
    transform,
    trainer,
    pusher,
]

pipeline = tfx.dsl.Pipeline(
    pipeline_name='straightforward_pipeline',
    pipeline_root=_pipeline_data_folder,
    metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
        f'{_pipeline_data_folder}/metadata.db'),
    components=components)

tfx.orchestration.LocalDagRunner().run(pipeline)

The only slightly unusual part of the snippet above is that I'm passing the outputs of statistics_gen and schema_gen to the Transform step through the custom_config argument. What I'm hoping to achieve is to iterate over the list of features in the dataset in order to transform them.

This is what I need for that:

  1. The list of features in the dataset (I don't want to hardcode or assume this list; I want my code to discover it automatically)
  2. Each feature's type (again, without hardcoding or assuming it)
  3. Each feature's statistical attributes, like min and max (again, without hardcoding or assuming them)

My question is, how can I do this in my preprocessing_fn.py function?

BTW, I know how to do this if I have access to the CSV version of the dataset:

import tensorflow_data_validation as tfdv

dataset_stats = tfdv.generate_statistics_from_csv(examples_file)
feature_1_stats = tfdv.get_feature_stats(dataset_stats.datasets[0],
                                         tfdv.FeaturePath(['feature_1']))

But there is a problem: this recomputes everything from the raw dataset, while in my pipeline that work has already been done by the statistics_gen and schema_gen steps. I don't want to redo the whole process; I just need to know how to consume the outputs of those two steps to get the info I need.
