Title: Generating Sentences with TRL while Maintaining Sentiment - Issue with "AutoModelForCausalLMWithValueHead"


I am working on generating sentences with TRL (Transformer Reinforcement Learning) that preserve the sentiment of a set of sample sentences. However, the TRL code I am using relies on AutoModelForCausalLMWithValueHead, which appears to be intended for generating responses to a query rather than for generating standalone sample text.
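In TRL's PPO loop, the "response" is just the continuation the model generates after the query tokens. One way to get sample text, then, is to put the sample sentence inside the query and treat the continuation as the rewrite. The sketch below illustrates this. `REWRITE_TEMPLATE`, `build_query`, and `extract_rewrite` are hypothetical helpers, and both the prompt wording and the cut at the first newline are assumptions, not anything TRL provides. Plain GPT-2 is not instruction-tuned, so it will not reliably follow such a template by itself. The PPO reward is what would push it toward sentiment-preserving rewrites.

```python
# Hypothetical helpers: the template wording and the first-newline cut-off
# are illustrative assumptions, not part of TRL.
REWRITE_TEMPLATE = (
    "Rewrite the sentence with the same sentiment.\n"
    "Sentence: {sample}\n"
    "Rewrite:"
)


def build_query(sample: str) -> str:
    """Embed a sample sentence in a prompt; the model's continuation is the new sentence."""
    return REWRITE_TEMPLATE.format(sample=sample.strip())


def extract_rewrite(response: str) -> str:
    """Keep only the first generated line, dropping anything the model adds afterwards."""
    return response.strip().split("\n", 1)[0].strip()


print(build_query("I really like this movie"))
print(extract_rewrite(" I truly enjoyed this film.\nSentence: something else"))
```

The string returned by `build_query` would take the place of `query_txt` in step 3. `extract_rewrite` would be applied to `response_txt` after decoding.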

I would greatly appreciate any guidance or suggestions on how to address this issue and modify the TRL code appropriately for generating sample text while preserving sentiment.

Thank you in advance for your valuable insights!

Here is the code:

```python
# 0. imports
import torch
from transformers import GPT2Tokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer


# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# 2. initialize trainer
ppo_config = {"batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

# 3. encode a query
query_txt = "I want to rewrite this sentence with the same sentiment; ex. I really like this movie "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)

# 4. generate model response
generation_kwargs = {
    "min_length": -1,
    "top_k": 0,  # 0 disables top-k filtering (an int rather than 0.0)
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])

# 5. define a reward for a response
# (this could be any reward such as human feedback or output from another model)
# A constant reward like this is only a placeholder; it carries no sentiment signal.
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]

# 6. train model with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
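The constant reward of 1.0 in step 5 gives PPO no information about sentiment. One option is to score both the sample sentence and the generated response with a sentiment classifier and reward label agreement. A transformers `pipeline("sentiment-analysis")`, for example, returns dicts with `label` and `score` keys. The sketch below takes such dicts as input, so it does not depend on a particular model. `sentiment_match_reward` and `batch_rewards` are hypothetical names. The reward shape (the classifier's confidence when labels match, its negative otherwise) is one illustrative choice, not TRL's method.

```python
from typing import Dict, List, Union

# A classifier prediction such as {"label": "POSITIVE", "score": 0.98}.
Prediction = Dict[str, Union[str, float]]


def sentiment_match_reward(sample_pred: Prediction, response_pred: Prediction) -> float:
    """Reward in [-1, 1]: +confidence if the response keeps the sample's label, else -confidence.

    Labels are compared as plain strings, so both predictions must come
    from the same classifier (hypothetical reward design).
    """
    confidence = float(response_pred["score"])
    same_label = response_pred["label"] == sample_pred["label"]
    return confidence if same_label else -confidence


def batch_rewards(sample_preds: List[Prediction], response_preds: List[Prediction]) -> List[float]:
    """Pair each sample prediction with its response prediction and score them."""
    if len(sample_preds) != len(response_preds):
        raise ValueError("need exactly one sample prediction per response prediction")
    return [sentiment_match_reward(s, r) for s, r in zip(sample_preds, response_preds)]


sample = {"label": "POSITIVE", "score": 0.99}
responses = [{"label": "POSITIVE", "score": 0.9}, {"label": "NEGATIVE", "score": 0.8}]
print(batch_rewards([sample, sample], responses))  # [0.9, -0.8]
```

Before passing these values to `ppo_trainer.step`, each float would need to be wrapped as a tensor, e.g. `[torch.tensor(r) for r in batch_rewards(...)]`, matching the list of tensors used for `reward` in step 5. Scaling by the classifier's confidence gives PPO a graded signal instead of a binary one.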