def create_hparams():
return trainer_lib.create_hparams(
FLAGS.hparams_set,
FLAGS.hparams,
data_dir=os.path.expanduser(FLAGS.data_dir),
problem_name=FLAGS.problem)
def create_decode_hparams():
decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
decode_hp.shards = FLAGS.decode_shards
decode_hp.shard_id = FLAGS.worker_id
decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
decode_hp.decode_in_memory = decode_in_memory
decode_hp.decode_to_file = FLAGS.decode_to_file
decode_hp.decode_reference = FLAGS.decode_reference
return decode_hp
hp = create_hparams()
decode_hp = create_decode_hparams()
run_conf = t2t_trainer.create_run_config(hp)
estimator = trainer_lib.create_estimator(
FLAGS.model,
hp,
run_conf,
decode_hparams=decode_hp,
use_tpu=FLAGS.use_tpu)
print(run_conf.session_config)
def input_fn():
inputs = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="inputs")
input_tensor = {'inputs': inputs }
return tf.estimator.export.ServingInputReceiver(input_tensor, input_tensor)
predictor=tf.contrib.predictor.from_estimator(estimator, input_fn)
I got output of
InvalidArgumentError: Cannot assign a device for operation transformer/body/parallel_0/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention/attention: Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available. Colocation Debug Info: Colocation group had the following types and supported devices: Root Member(assigned_device_name_index_=-1 requested_device_name_='/device:GPU:0' assigned_device_name_='' resource_device_name_='' supported_device_types_=[CPU] possible_devices_=[] ImageSummary: CPU
Colocation members, user-requested devices, and framework assigned devices, if any:
transformer/body/parallel_0/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention/attention (ImageSummary) /device:GPU:0Op: ImageSummary Node attrs: max_images=1, T=DT_FLOAT, bad_color=Tensor Registered kernels: device='CPU'
when i print the run_conf.session_config, I got allow_soft_placement: true. Many people said it can solve the problem of InvalidArgumentError but seems not work on me.