How to set layer weights during training in TensorFlow


In every forward pass of the model, I want to apply l2 normalization to the columns of the softmax layer's weight matrix and then set the weights back, as described in the imprinted weights paper and this PyTorch implementation. I am using layer.set_weights() to write the normalized weights back inside the model's call() method, but this only works with eager execution; layer.set_weights() fails when the graph is being built.
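To be concrete about what I mean by column normalization, here is the operation in isolation (the shapes are just placeholders I made up):

import numpy as np

# toy kernel of a Dense layer, shape (embedding_dim, n_classes)
kernel = np.random.randn(256, 10).astype("float32")
# l2 norm of each column (one weight vector per class)
norms = np.linalg.norm(kernel, axis=0, keepdims=True)
normalized_kernel = kernel / norms  # every column now has unit l2 norm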

Here is the implementation of the model in TF 1.15:

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import Model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense

class Extractor(Model):
    def __init__(self, input_shape):
        super(Extractor, self).__init__()
        # ResNet50 backbone with global average pooling as the feature extractor
        self.basenet = ResNet50(include_top=False, weights="imagenet",
                                pooling="avg", input_shape=input_shape)

    def call(self, x):
        x = self.basenet(x)
        return x


class Embedding(Model):
    def __init__(self, num_nodes, norm=True):
        super(Embedding, self).__init__()
        self.fc = Dense(num_nodes, activation="relu")
        self.norm = norm

    def call(self, x):
        x = self.fc(x)
        if self.norm:
            # l2-normalize the embedding before it reaches the classifier
            x = tf.nn.l2_normalize(x)
        return x

class Classifier(Model):
    def __init__(self, n_classes, norm=True, bias=False):
        super(Classifier, self).__init__()
        self.n_classes = n_classes
        self.norm = norm
        self.bias = bias

    def build(self, inputs_shape):
        self.prediction = Dense(self.n_classes,
                                activation="softmax", use_bias=False)

    def call(self, x):
        if self.norm:
            # normalize the classifier weights and write them back;
            # this is the part that breaks outside of eager execution
            w = self.prediction.trainable_weights
            if w:
                w = tf.nn.l2_normalize(w, axis=2)
                self.prediction.set_weights(w)

        x = self.prediction(x)
        return x

class Net(Model):
    def __init__(self, input_shape, n_classes, num_nodes, norm=True,
                 bias=False):
        super(Net, self).__init__()
        self.n_classes = n_classes
        self.num_nodes = num_nodes
        self.norm = norm
        self.bias = bias
        self.extractor = Extractor(input_shape)
        self.embedding = Embedding(self.num_nodes, norm=self.norm)
        self.classifier = Classifier(self.n_classes, norm=self.norm,
                                     bias=self.bias)

    def call(self, x):
        x = self.extractor(x)
        x = self.embedding(x)
        x = self.classifier(x)
        return x

The weight normalization happens in the call() method of the Classifier class, where I call .set_weights() with the normalized weights.

Creating the model with model = Net(input_shape, n_classes, num_nodes) and calling model(x) works, but model.predict() and model.fit() give me errors related to .get_weights(). I can train the model in eager mode with a gradient tape, but it is extremely slow.
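For reference, this is roughly how I trigger the error (input shape, class count and node count are placeholders):

import numpy as np

model = Net(input_shape=(224, 224, 3), n_classes=10, num_nodes=256)
x = np.random.rand(4, 224, 224, 3).astype("float32")

y = model(x)        # fine with tf.enable_eager_execution()
# model.predict(x)  # raises the TypeError shown below while building the graph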

Is there another way to set the weights of a Dense layer during training on each forward call that still lets me use the model outside of eager mode? By eager mode I mean calling tf.enable_eager_execution() at the start of the program.
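One direction I have considered (only a sketch, I have not verified it in graph mode) is to drop .set_weights() entirely and keep the normalization symbolic, i.e. a Dense-like layer that normalizes its own kernel inside call():

import tensorflow as tf
from tensorflow.keras.layers import Layer

class NormalizedDense(Layer):
    # Dense-like layer whose kernel columns are l2-normalized as part
    # of the graph instead of being written back with set_weights()
    def __init__(self, units, **kwargs):
        super(NormalizedDense, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight(name="kernel",
                                      shape=(int(input_shape[-1]), self.units),
                                      initializer="glorot_uniform",
                                      trainable=True)
        super(NormalizedDense, self).build(input_shape)

    def call(self, x):
        # normalize each column (one weight vector per class) to unit norm
        w = tf.nn.l2_normalize(self.kernel, axis=0)
        return tf.nn.softmax(tf.matmul(x, w))

This avoids the numpy round trip, but it is not quite what the paper describes: the stored kernel is never actually overwritten, only a normalized copy is used in the forward pass, and I am not sure whether that difference matters for imprinting.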

Here is the error I get when I use model.predict(x) instead; I suspect set_weights() ends up calling len() on the symbolic tensor that tf.nn.l2_normalize returns in graph mode:

TypeError: len is not well defined for symbolic Tensors. (imprint_net_1/classifier/l2_normalize:0) Please call `x.shape` rather than `len(x)` for shape information.