Getting predictions with DataCollatorForCompletionOnlyLM after fine-tuning Llama-2 with the SFT trainer


I am in the process of fine-tuning Llama-2 using the SFT trainer, with quantization and LoRA. My dataset is composed of questions structured like this:

<s>[INST] 
<<SYS>> Please select the correct answer from the given multiple Options based on the given Context: <</SYS>>  
 Context: Abrasion is another type of mechanical weathering. With abrasion, one rock bumps against another rock. Gravity causes abrasion as a rock tumbles down a slope. Moving water causes abrasion it moves rocks so that they bump against one another (Figure 9.3). Strong winds cause abrasion by blasting sand against rock surfaces.  
 Question: Gravity causes erosion by all of the following except
 Options:(A) glaciers (B) moving air (C) flowing water (D) mass movement  
 Answer: [/INST] D </s>

I am currently using DataCollatorForCompletionOnlyLM so that the loss is computed only on the predicted answer. Given this instruction structure, which setup is correct?

Do I provide the context, question, and options as the instruction_template, like this:

instruction_template = "</SYS>>\n\n Context:"
response_template = "Answer: [/INST]"
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)

Or do I provide only the response_template:

response_template = "Answer: [/INST]"
collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer, mlm=False)
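
Either way, the collator is then passed to the trainer. For reference, here is roughly how everything is wired together in my script (model, tokenizer, and train_dataset are created earlier; "text" is the column holding the formatted prompt strings):

from transformers import TrainingArguments
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,                      # the quantized Llama-2 model with LoRA adapters
    args=TrainingArguments(output_dir="outputs"),
    train_dataset=train_dataset,      # dataset with the formatted prompt strings
    dataset_text_field="text",        # column containing the full <s>[INST] ... [/INST] text
    tokenizer=tokenizer,
    data_collator=collator,
    packing=False,                    # packing must stay off with this collator
)
trainer.train()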

I have tried multiple response templates but always get the error:

RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])

Can you please guide me to the correct one?

1 Answer

Answer by pyAddict:

Unlike most other tokenizers, the Llama-2 tokenizer is context-dependent: the token IDs a substring maps to depend on what precedes it. For example:

sent-1: """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
sent-2: "### Assistant:"

Here the token IDs for ### Assistant: will differ, because in sent-1 the template is preceded by \n while in sent-2 it is not, and the tokenizer splits the characters differently depending on that context.
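
You can see this directly (a minimal sketch; the checkpoint name is an assumption, and the exact IDs depend on the tokenizer version):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint

# The same template string tokenizes differently with and without preceding context:
print(tokenizer.encode("### Assistant:", add_special_tokens=False))
# [835, 4007, 22137, 29901]  <- the IDs from your error message
print(tokenizer.encode("\n### Assistant:", add_special_tokens=False))
# [29871, 13, 2277, 29937, 4007, 22137, 29901]  <- "###" now splits into different IDs

The collator searches the tokenized batch for the no-context IDs, which never occur inside the tokenized training text, hence the RuntimeError.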

Solution: tokenize the response template with the same context it has in the training text, and pass the resulting token IDs to the collator. Hugging Face explains this exact issue in detail in the TRL documentation.
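
Applied to your prompt format, the fix looks roughly like this (the exact context string "\n Answer: [/INST]" and the [2:] slice are assumptions; verify them by decoding the resulting IDs against one of your tokenized training examples):

from trl import DataCollatorForCompletionOnlyLM

# Encode the template together with the characters that precede it in the
# training text, then drop the IDs that belong only to that leading context.
response_template_with_context = "\n Answer: [/INST]"  # assumed context from the prompt format
response_template_ids = tokenizer.encode(
    response_template_with_context, add_special_tokens=False
)[2:]  # assumed offset: drops the tokens for the leading whitespace/newline

collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template_ids, tokenizer=tokenizer, mlm=False
)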