import torch
import torch.nn as nn
import torch.nn.functional as F
from smplx import SMPL
from einops import rearrange
from models.loss import Loss
from transformers import CLIPProcessor, CLIPModel
from utils.utils import get_keypoints
from models.module import MusicEncoderLayer, MotionDecoderLayer
import math
class GPT(nn.Module):
    def __init__(self, p=2, input_size=438, embed_size=512, num_layers=6, heads=8,
                 forward_expansion=4, dropout=0.1, output_size=75):
        super(GPT, self).__init__()
        max_len, max_per = 450, 6
        self.motion_pos_emb_t = nn.Parameter(torch.zeros(max_len, embed_size))
        # self.motion_pos_emb_p = nn.Parameter(torch.zeros(max_per, embed_size))
        # self.music_pose_emb_t = nn.Parameter(torch.zeros(max_len, embed_size))
        self.music_emb = nn.Linear(input_size, embed_size)
        self.motion_emb = nn.Linear(output_size, embed_size)
        self.text_encoder = TextEncoder()
        self.music_encoder = MusicEncoder(embed_size, num_layers, heads, forward_expansion, dropout)
        self.motion_decoder = MotionDecoder(embed_size, num_layers, heads, forward_expansion, dropout, output_size)
        # self.mask = generate_square_subsequent_mask(max_len, 'cuda')
        # self.mask = self.mask.masked_fill(self.mask==0, float('-inf')).masked_fill(self.mask==1, float(0.0))
        self.loss = nn.MSELoss()
        # self.loss = Loss()

    def forward(self, text, music, motion):
        # teacher forcing: decoder input is frames [0, T-1), target is frames [1, T)
        motion_src, motion_trg = motion[:, :, :-1, :], motion[:, :, 1:, :]
        b, p, t, _ = motion_src.shape
        text_encode = self.text_encoder(text)
        music_encode = self.music_encoder(self.music_emb(music[:, :-1, :])) \
            .reshape(b, 1, t, -1).repeat(1, p, 1, 1).reshape(b * p, t, -1)
        mask = torch.nn.Transformer().generate_square_subsequent_mask(t).transpose(0, 1).cuda()
        motion_emb = self.motion_emb(motion_src) + self.motion_pos_emb_t[:t, :].reshape(1, 1, t, -1).repeat(b, p, 1, 1)
        motion_pred = self.motion_decoder(motion_emb, music_encode, mask=mask).reshape(b, p, t, -1)
        loss = self.loss(motion_pred, motion_trg)
        return motion_pred, loss

    def inference(self, text, music, motion):
        self.eval()
        with torch.no_grad():
            music, motion = music[:, :-1, :], motion[:, :, :-1, :]
            b, p, t, c = motion.shape
            music_encode = self.music_encoder(self.music_emb(music)) \
                .reshape(b, 1, t, -1).repeat(1, p, 1, 1).reshape(b * p, t, -1)
            # autoregressive generation: seed frame 0 with the ground-truth first frame,
            # then fill in one predicted frame per step
            preds = torch.zeros(b, p, t, c).cuda()
            preds[:, :, 0, :] = motion[:, :, 0, :]
            mask = torch.nn.Transformer().generate_square_subsequent_mask(t).transpose(0, 1).cuda()
            for i in range(1, t):
                motion_emb = self.motion_emb(preds) + self.motion_pos_emb_t[:t, :].reshape(1, 1, t, -1).repeat(b, p, 1, 1)
                current_pred = self.motion_decoder(motion_emb, music_encode, mask=mask).reshape(b, p, t, -1)
                preds[:, :, i, :] += current_pred[:, :, i-1, :]
            motion_pred = preds.reshape(b, p, t, -1)
            print(motion_pred[0, 0, :10, :6])  # debug: inspect the first predicted frames
            import sys
            sys.exit()
            pred_keypoints = get_keypoints(motion_pred)
            return {'keypoints': pred_keypoints, 'smpl': motion_pred}

class MusicEncoder(nn.Module):
    def __init__(self, embed_size, num_layers, heads, forward_expansion, dropout):
        super(MusicEncoder, self).__init__()
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=embed_size, nhead=heads, dim_feedforward=embed_size * forward_expansion,
                                        dropout=dropout, batch_first=True) for _ in range(num_layers)]
        )

    def forward(self, x):
        b, t, _ = x.shape
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

class MotionDecoder(nn.Module):
    def __init__(self, embed_size, num_layers, heads, forward_expansion, dropout, output_size):
        super(MotionDecoder, self).__init__()
        self.num_layers = num_layers
        self.fc_out = nn.Linear(embed_size, output_size)
        self.layers = nn.ModuleList(
            [nn.TransformerDecoderLayer(d_model=embed_size, nhead=heads, dim_feedforward=embed_size * forward_expansion,
                                        dropout=dropout, batch_first=True) for _ in range(num_layers)]
        )

    def forward(self, motion_src, music_text_encode, mask=None):
        b, p, t, _ = motion_src.shape
        out = motion_src.reshape(b * p, t, -1)
        for layer in self.layers:
            out = layer(out, music_text_encode, tgt_mask=mask)
        return self.fc_out(out)

class TextEncoder(nn.Module):
    def __init__(self):
        super(TextEncoder, self).__init__()
        self.text_clip = CLIPModel.from_pretrained("./Pretrained/CLIP/Model")
        self.text_processor = CLIPProcessor.from_pretrained("./Pretrained/CLIP/Processor")

    def forward(self, texts):
        texts_process = self.text_processor(text=texts, return_tensors="pt", padding=True, truncation=True)
        texts_process = {name: tensor.to(self.text_clip.device) for name, tensor in texts_process.items()}
        text_output = self.text_clip.get_text_features(**texts_process)
        return text_output

I am trying to build a Music2Dance model: the dance is SMPL data, the music is a 439-dimensional per-frame feature, and I have aligned their FPS. The training loss decreases, but the inference result is completely wrong: every frame after the second one is identical. Above is my code, with the debug print I used still inside the inference function. Please help me find out the mistakes!
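For context, this is roughly how I call the model. It is only a minimal sketch with dummy tensors; the prompt string and the shapes are made up here (my real data comes from the dataloader), but they match the sizes used in the code above (b = batch, p = dancers, t = frames):

model = GPT().cuda()                       # loads the local CLIP weights inside TextEncoder
b, p, t = 2, 2, 240                        # t stays below max_len = 450
text = ["two dancers, hip-hop style"] * b  # made-up prompts for the CLIP text encoder
music = torch.randn(b, t, 438).cuda()      # per-frame music features (input_size in __init__)
motion = torch.randn(b, p, t, 75).cuda()   # per-frame SMPL parameters (output_size in __init__)

motion_pred, loss = model(text, music, motion)  # training step: this loss does go down
model.inference(text, music, motion)            # inference: the debug print shows every frame after the second is identical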
Thanks!
