I have finetuned a GPT-2 model within Databricks and saved the model and tokenizer to DBFS. I am now trying to log, register, and deploy that model to a serving endpoint, which I am not sure is currently possible. I have followed along with the examples provided here, here, here, here, and here, and while those examples work for me, I haven't been able to swap out the foundation models they use for a finetuned model of my own.
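For context, the save step after training looked roughly like this (a sketch; the DBFS path is a placeholder for the real location):

# Roughly how the finetuned model and tokenizer were written out after training.
# The path below is a placeholder on the DBFS FUSE mount.
snapshot_location = "/dbfs/<path-to-finetuned-gpt2>"
model.save_pretrained(snapshot_location)      # config.json, pytorch_model.bin, ...
tokenizer.save_pretrained(snapshot_location)  # vocab.json, merges.txt, tokenizer files

Here is some code that very closely approximates my code for logging and registering the model: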
# Define PythonModel to log with mlflow.pyfunc.log_model
class GPT(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        """
        This method initializes the tokenizer and language model.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(snapshot_location)
        config = GPT2Config.from_pretrained(snapshot_location)
        self.model = GPT2LMHeadModel.from_pretrained(snapshot_location, config=config)
        self.model.to(device='cuda')
        self.model.eval()

    def predict(self, model_input):
        """
        This method generates prediction for the given input.
        """
        generated_text = []
        for index, row in model_input.iterrows():
            prompt = row["prompt"]
            temperature = model_input.get("temperature", [0.7])[0]
            max_new_tokens = model_input.get("max_new_tokens", [100])[0]
            full_prompt = prompt
            encoded_input = self.tokenizer.encode(full_prompt, return_tensors="pt").to('cuda')
            output = self.model.generate(encoded_input,
                                         temperature=temperature,
                                         max_new_tokens=max_new_tokens)
            prompt_length = len(encoded_input[0])
            generated_text.append(self.tokenizer.batch_decode(output[:, prompt_length:], skip_special_tokens=True))
        return pd.Series(generated_text)
# Define input and output schema
input_schema = Schema([
    ColSpec(DataType.string, "prompt"),
    ColSpec(DataType.double, "temperature", optional=False),
    ColSpec(DataType.integer, "max_new_tokens", optional=False)
])
output_schema = Schema([ColSpec(DataType.string, 'output')])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
# Define input example
input_example = pd.DataFrame({
    "prompt": ["$PROMPT"],
    "temperature": [0.7],
    "max_new_tokens": [100]
})
# Log the model with its details such as artifacts, pip requirements and input example
torch_version = torch.__version__.split("+")[0]
with mlflow.start_run() as run:
    mlflow.pyfunc.log_model(
        "model",
        python_model=GPT(),
        artifacts={'repository': snapshot_location},
        pip_requirements=[f"torch=={torch_version}",
                          f"transformers=={transformers.__version__}",
                          f"accelerate=={accelerate.__version__}", "einops", "sentencepiece"],
        input_example=input_example,
        signature=signature
    )
mlflow.set_registry_uri("databricks-uc")
registered_name = "models.default.gpt" # Note that the UC model name follows the pattern <catalog_name>.<schema_name>.<model_name>, corresponding to the catalog, schema, and registered model name
# Register model in MLflow Model Registry
result = mlflow.register_model(
    "runs:/" + run.info.run_id + "/model",
    name=registered_name,
    await_registration_for=1000,
)
client = MlflowClient(registry_uri="databricks-uc")
# Choose the right model version registered in the above cell.
client.set_registered_model_alias(name=registered_name, alias="GPT_model", version=result.version)
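For completeness, this is roughly how I try to prompt the registered model from a notebook (a minimal sketch that loads the model by the alias set above):

import mlflow
import pandas as pd

# Load the registered Unity Catalog model by its alias
mlflow.set_registry_uri("databricks-uc")
loaded_model = mlflow.pyfunc.load_model(f"models:/{registered_name}@GPT_model")

# Prompt it with the same columns as the input example
print(loaded_model.predict(pd.DataFrame({
    "prompt": ["Tell me about GPT-2."],
    "temperature": [0.7],
    "max_new_tokens": [100]
})))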
I can successfully log and register the model just fine; the issues arise when I try to prompt the model within a notebook or set up a model serving endpoint. As you can tell, I have not deviated much from the provided tutorials, but I believe everything hinges on where the snapshot_location variable points. I have tried pointing it at a DBFS file path, only to get the following error in the service logs while trying to serve the model: An error occurred while loading the model. Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/dbfs/$FILEPATH'. Use 'repo_type' argument if needed. I then tried the code:
snapshot_location = os.path.expanduser("~/gpt/output")
os.makedirs(snapshot_location, exist_ok=True)
which, when I tried to prompt the model within a notebook, gave me the error: HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/root/gpt/output'. Use 'repo_type' argument if needed.
My question is: is there a way to pass the repo_type flag somewhere in there? If I can't keep the model in DBFS, should I be saving it to an S3 bucket instead? Googling the error led me to a lot of GitHub comments about including the full file path, but I am not sure how applicable that is to my problem. My next step is to move the files from DBFS onto the local disk and register the model from there (roughly the snippet below). Am I going about this the right way, though? Is the issue not with the location but with my logging and registration code? Any insight would be greatly appreciated.
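For reference, the DBFS-to-local-disk move I'm planning to try would look roughly like this (a sketch; both paths are placeholders):

import shutil

# Copy the finetuned model files from the DBFS FUSE mount onto the driver's local disk,
# then point snapshot_location at the local copy before logging the model
dbfs_dir = "/dbfs/<path-to-finetuned-gpt2>"
local_dir = "/local_disk0/gpt2-finetuned"
shutil.copytree(dbfs_dir, local_dir, dirs_exist_ok=True)
snapshot_location = local_dir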