Let's say that I have the following, very straightforward pipeline:
import os

from tfx import v1 as tfx

_dataset_folder = './tfrecords/train/*'
_pipeline_data_folder = './pipeline_data'
_serving_model_dir = os.path.join(_pipeline_data_folder, 'serving_model')

# Ingest the TFRecord files.
example_gen = tfx.components.ImportExampleGen(input_base=_dataset_folder)

# Compute statistics, infer a schema from them, and validate the examples.
statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs['examples'])
schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True)
example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])

# Feature engineering; note that statistics_gen and schema_gen are also passed in via custom_config.
_transform_module_file = 'preprocessing_fn.py'
transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath(_transform_module_file),
    custom_config={'statistics_gen': statistics_gen.outputs['statistics'],
                   'schema_gen': schema_gen.outputs['schema']})

# Train on the transformed examples.
_trainer_module_file = 'run_fn.py'
trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_trainer_module_file),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=10),
    eval_args=tfx.proto.EvalArgs(num_steps=6))

# Push the trained model to the serving directory.
pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))

components = [
    example_gen,
    statistics_gen,
    schema_gen,
    example_validator,
    transform,
    trainer,
    pusher,
]

pipeline = tfx.dsl.Pipeline(
    pipeline_name='straightforward_pipeline',
    pipeline_root=_pipeline_data_folder,
    metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
        f'{_pipeline_data_folder}/metadata.db'),
    components=components)

tfx.orchestration.LocalDagRunner().run(pipeline)
The only part that's a bit out of the ordinary in the snippet above is that I'm passing statistics_gen and schema_gen to the Transform step through its custom_config argument. What I'm hoping to achieve here is to iterate over the dataset's list of features in order to transform them.
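For reference, the module file that _transform_module_file points to is currently just a pass-through stub. As I understand it, Transform forwards the custom_config dict to preprocessing_fn when the function declares that argument, so the stub looks roughly like this (a sketch; whether the artifacts I'm passing actually arrive in a usable form is part of what I'm unsure about):

# preprocessing_fn.py -- current pass-through stub
def preprocessing_fn(inputs, custom_config):
    # 'inputs' maps feature names to Tensors/SparseTensors.
    # For now every feature is passed through unchanged.
    outputs = {}
    for name, tensor in inputs.items():
        outputs[name] = tensor
    return outputs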
This is what I need for that (see the sketch after this list for the kind of iteration I have in mind):
- The list of features in the dataset (I don't want to hardcode/assume this list; I want my code to discover it automatically)
- Each feature's type (again, I don't want to hardcode/assume them)
- Each feature's statistical attributes, like min and max (and once more, I don't want to hardcode/assume them!)
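For illustration, this is roughly the kind of iteration I'm after, written as if the Schema and DatasetFeatureStatisticsList protos were somehow already available inside my transform code (describe_features is just a placeholder name; how to actually get those protos in there is exactly what I'm asking):

import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import schema_pb2, statistics_pb2

def describe_features(schema: schema_pb2.Schema,
                      stats: statistics_pb2.DatasetFeatureStatisticsList):
    for feature in schema.feature:
        name = feature.name
        # Feature type as declared in the schema, e.g. 'INT', 'FLOAT', 'BYTES'.
        dtype = schema_pb2.FeatureType.Name(feature.type)
        feature_stats = tfdv.get_feature_stats(stats.datasets[0],
                                               tfdv.FeaturePath([name]))
        if feature_stats.HasField('num_stats'):
            # Numeric features carry min/max in their NumericStatistics.
            print(name, dtype, feature_stats.num_stats.min, feature_stats.num_stats.max)
        else:
            print(name, dtype)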
My question is: how can I do this inside the preprocessing_fn of my preprocessing_fn.py module?
BTW, I know how to do this if I have access to the CSV version of the dataset:
import tensorflow_data_validation as tfdv
dataset_stats = tfdv.generate_statistics_from_csv(examples_file)
feature_1_stats = tfdv.get_feature_stats(dataset_stats.datasets[0],
                                         tfdv.FeaturePath(['feature_1']))
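From the returned FeatureNameStatistics proto I can then read the attributes I listed above, for example (assuming feature_1 is numeric):

feature_1_min = feature_1_stats.num_stats.min
feature_1_max = feature_1_stats.num_stats.max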
But there is a problem: this approach extracts all of that information from the raw dataset, while in my pipeline I believe it has already been extracted by the statistics_gen and schema_gen steps, and I don't want to redo the whole process. I just need to learn how to use those two steps to get the info I need.