
What will happen to layers that exist in TFA but not in other Keras frameworks and that do not work with Keras 3? (I'm interested in the WeightNormalization layer) #2869

@Kurdakov


System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 20.04): Linux Ubuntu 22.04
  • TensorFlow version and how it was installed (source or binary): TensorFlow 2.6.1, binary installation
  • TensorFlow-Addons version and how it was installed (source or binary): built from source (master branch)
  • Python version: 3.10
  • Is GPU used? (yes/no): no

Describe the bug
Although the master branch has fixed the imports for Keras 3, `class WeightNormalization(tf.keras.layers.Wrapper)` still does not work with Keras 3.

Code to reproduce the issue

```python
import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.layers import Conv1D, Embedding, MaxPooling1D, Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.optimizers import Adam

max_words = 800  # length every review is padded/truncated to

# Keep only the 1000 most frequent words of the IMDB vocabulary
(Xtrain, ytrain), (Xtest, ytest) = imdb.load_data(num_words=1000)

Xtrain = sequence.pad_sequences(Xtrain, maxlen=max_words)
Xtest = sequence.pad_sequences(Xtest, maxlen=max_words)

model = Sequential()
model.add(Embedding(1000, 500, input_length=max_words))
# The two WeightNormalization wrappers are what fail under Keras 3
model.add(tfa.layers.WeightNormalization(Conv1D(64, 3, activation='relu')))
model.add(MaxPooling1D(2, 2))
model.add(tfa.layers.WeightNormalization(Conv1D(32, 3, activation='relu')))
model.add(MaxPooling1D(2, 2))
model.add(Flatten())
model.add(Dense(10, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

model.compile(optimizer=Adam(.0001), metrics=['accuracy'], loss='binary_crossentropy')
model.fit(Xtrain, ytrain, validation_split=.2, epochs=10)
```

Problems:

`compute_output_shape(self, input_shape)` calls `as_list()`, which Keras 3 does not support; removing the `as_list()` call fixes this part.

The other problems, which I could not resolve, are in the creation of `self._naked_clone_layer`.

Essentially, the problem is that WeightNormalization does not exist in other Keras frameworks, yet the TFA version does not work with Keras 3 either.

I understand that TFA is nearing the end of support (and has already been in minimal-maintenance mode for almost a year), but then the question is: what should be used in place of the WeightNormalization layer in Keras 3?
