Minimum working example of training hyperprior; Weights not updating #138
I am trying to create a dummy example to train the hyperprior of an entropy model. I used bls2017.py as my reference. The issue seems to be that the dummy model doesn't see the trainable variables in the prior. Any thoughts on what I am missing? My environment:
My dummy example:

import tensorflow as tf
import tensorflow_compression as tfc

class DummyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.prior = tfc.NoisyDeepFactorized()
        self.build((None, 10))

    def call(self, inputs):
        entropy_model = tfc.ContinuousBatchedEntropyModel(self.prior, coding_rank=1, compression=False)
        _, bits = entropy_model(inputs, training=True)
        return tf.reduce_mean(bits)

model = DummyModel()
model.compile(tf.keras.optimizers.Adam(0.1), tf.keras.losses.MeanAbsoluteError())
x_train = tf.random.normal([10**6, 10], mean=5.0, stddev=0.5)
y_train = tf.zeros(shape=(x_train.shape[0], 1))
init_vars = [v.numpy().mean() for v in model.prior.trainable_variables]
print('Prior weights:', len(model.prior.trainable_variables))
print('Model weights:', len(model.trainable_weights))
history = model.fit(x_train, y_train, batch_size=1024, epochs=2)
print(history.history)
unchanged = init_vars == [v.numpy().mean() for v in model.prior.trainable_variables]
print('Prior weights unchanged?', unchanged)

Output:

I am aiming to see the training step take the gradient of the average bits estimate (i.e. the MAE loss relative to a 0 target) w.r.t. the prior weights and then apply some update to those weights.
Replies: 1 comment 1 reply
Hi, you are experiencing a bug in older TF releases.
tf.keras.Model classes didn't collect trainable variables from all nested objects that inherit from tf.Module, only from ones that inherit from tf.keras.layers.Layer. Distribution objects fall into the former category: they are tf.Module subclasses but not Keras layers, so the model never saw their variables. This was fixed in a later TF version; I think it was fixed in 2.5. I'd recommend using the latest version (2.8; 2.9 should probably be released end of this week). If that's not possible, there is a workaround; check out this commit.
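For anyone who cannot upgrade, one generic way to sidestep automatic variable collection is to ask the nested tf.Module for its variables directly and take the gradient step by hand. The following is only a minimal sketch of that idea and is not necessarily what the linked commit does. ToyPrior, neg_log_prob and train_step are hypothetical names, and a simple Gaussian stands in for tfc.NoisyDeepFactorized so that the example depends on TensorFlow alone.

```python
import tensorflow as tf


class ToyPrior(tf.Module):
    """Hypothetical stand-in for a prior: a tf.Module that owns variables."""

    def __init__(self):
        super().__init__()
        self.loc = tf.Variable(0.0, name="loc")
        self.log_scale = tf.Variable(0.0, name="log_scale")

    def neg_log_prob(self, x):
        # Negative log-density of a Gaussian, up to an additive constant.
        scale = tf.exp(self.log_scale)
        return 0.5 * tf.square((x - self.loc) / scale) + self.log_scale


def train_step(prior, x, learning_rate=0.01):
    # Take the variables from the module itself instead of relying on
    # tf.keras.Model to discover them.
    variables = list(prior.trainable_variables)
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(prior.neg_log_prob(x))
    grads = tape.gradient(loss, variables)
    # Plain SGD update, so the sketch does not depend on any optimizer class.
    for var, grad in zip(variables, grads):
        var.assign_sub(learning_rate * grad)
    return loss


prior = ToyPrior()
x = tf.random.normal([1024, 10], mean=5.0, stddev=0.5)
before = [v.numpy() for v in prior.trainable_variables]
losses = [float(train_step(prior, x)) for _ in range(20)]
after = [v.numpy() for v in prior.trainable_variables]
print("Prior variables:", len(before))
print("Changed:", [bool(b != a) for b, a in zip(before, after)])
```

In the original example the same pattern would presumably apply to model.prior.trainable_variables, with the mean bits estimate as the loss, at the cost of giving up model.fit for a hand-written loop.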