How to convert a PyTorch model to half precision for inference in C++


I am writing C++ code to load a TorchScript model and run inference. The saved model is float32, and I need to convert it to half precision, just like this Python code:

model = torch.jit.load(model_path)
model.half()

It seems that the PyTorch C++ API has no direct counterpart of module.half(). In Python, module.half() iterates over all parameters and buffers and converts the floating-point tensors to half. torch::jit::Module has named_parameters() and named_buffers(), but the tensors they return cannot be reassigned in the module. torch::jit::Module also provides the apply and register_parameter APIs, so I tried the following code:

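auto model = torch::jit::load(model_path);
model.apply([](torch::jit::Module& module) {
  // Re-register each floating-point parameter of this (sub)module as a half tensor.
  for (auto x : module.named_parameters(/*recurse=*/false)) {
    if (x.value.is_floating_point()) {
      module.register_parameter(x.name, x.value.to(torch::kHalf), /*is_buffer=*/false);
    }
  }
  // Same for buffers; is_buffer=true keeps them registered as buffers.
  for (auto x : module.named_buffers(/*recurse=*/false)) {
    if (x.value.is_floating_point()) {
      module.register_parameter(x.name, x.value.to(torch::kHalf), /*is_buffer=*/true);
    }
  }
});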

The code above compiles and runs, and I checked that every dtype is correct. But after moving the model to the GPU with model.to(device), the forward pass fails with "RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at cpu". With float32, everything works perfectly.

So what is the correct way to convert a module to half precision in C++?

Answer by fei.sun:

After reading the PyTorch source code, I found the following solution; posting it here in case someone else has the same issue:

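// Adapted from the PyTorch source, see:
// https://github.com/pytorch/pytorch/blob/3e354ef3e3f81a87158364dec0e527aee69236c1/torch/csrc/jit/api/module.cpp#L182
#include <torch/script.h>

// Converts all floating-point parameters and buffers of `model` to half, in place.
void ToHalf(torch::jit::Module& model) {
  auto convert_to_half = [](const torch::autograd::Variable& variable) {
    // Keep the tensor on its current device; change only the dtype.
    auto new_data = variable.to(variable.device(), torch::kHalf, /*non_blocking=*/false);
    // set_data replaces the tensor's data in place instead of creating a new
    // tensor, so every existing handle to it sees the converted data.
    variable.set_data(new_data);
  };

  // parameters() and buffers() recurse into submodules by default.
  for (auto e : model.parameters()) {
    if (e.is_floating_point()) {
      convert_to_half(e);
    }
  }

  for (auto e : model.buffers()) {
    if (e.is_floating_point()) {
      convert_to_half(e);
    }
  }
}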