I have a teacher model, a student model, and an adapter NER model. The adapter is a standalone BiLSTM that takes the discrete outputs from the teacher/student models and produces logits.

I want to do knowledge distillation, but instead of computing the loss by comparing the teacher and student output logits directly, I want the loss to be computed from the teacher and student outputs after they have been passed through the adapter NER model. The problem is that in PyTorch, once the discrete outputs (the inputs to the NER model) are generated, the computation graph breaks somewhere and I cannot backpropagate the loss to the student model. When I do
criterion = nn.CrossEntropyLoss()
ce_loss = criterion(student_ner_logits, teacher_ner_logits.softmax(dim=-1))
ce_loss.backward()
no gradient flows back to the student model. I want to know whether this is logically possible to implement, and if so, how.
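For reference, here is a minimal, runnable sketch of the kind of pipeline I am describing (the module definitions and names are simplified placeholders, not my actual models). The argmax step is where I believe the graph gets cut:

```python
import torch
import torch.nn as nn

# Simplified placeholder modules -- illustrative only, not my real models.

class StudentTagger(nn.Module):
    """Stand-in for the student: maps token IDs to per-token tag logits."""
    def __init__(self, vocab_size=1000, hidden=128, num_tags=9):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.encoder = nn.LSTM(hidden, hidden, batch_first=True)
        self.classifier = nn.Linear(hidden, num_tags)

    def forward(self, token_ids):
        x = self.embed(token_ids)
        x, _ = self.encoder(x)
        return self.classifier(x)  # (batch, seq_len, num_tags)


class BiLSTMAdapter(nn.Module):
    """Stand-in for the adapter: consumes discrete tag IDs, emits NER logits."""
    def __init__(self, num_tags=9, hidden=64):
        super().__init__()
        self.tag_embed = nn.Embedding(num_tags, hidden)
        self.bilstm = nn.LSTM(hidden, hidden, batch_first=True, bidirectional=True)
        self.out = nn.Linear(2 * hidden, num_tags)

    def forward(self, tag_ids):
        x = self.tag_embed(tag_ids)
        x, _ = self.bilstm(x)
        return self.out(x)


student = StudentTagger()
adapter = BiLSTMAdapter()

tokens = torch.randint(0, 1000, (4, 20))  # dummy batch of token IDs
student_logits = student(tokens)          # differentiable w.r.t. student params

# This is the step I suspect cuts the graph: argmax is not differentiable,
# so the discrete tag IDs carry no grad_fn back to the student.
student_tags = student_logits.argmax(dim=-1)

student_ner_logits = adapter(student_tags)
print(student_ner_logits.requires_grad)  # True, but only via the adapter params
print(student_tags.requires_grad)        # False -- no path back to the student
# teacher_ner_logits is produced the same way from the (frozen) teacher.
```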