Adding the following layer right after the model's inputs solves the issue by explicitly reshaping the inputs (appending a trailing axis of size 1), as suggested in the comments:
class ReshapeLayer(keras.layers.Layer):
    """Insert a size-1 axis into the layer's input via ``tf.expand_dims``.

    With the default ``axis=-1`` this appends a trailing dimension, e.g. a
    2-D input of shape ``(batch, steps)`` becomes ``(batch, steps, 1)``.
    The layer has no trainable weights.

    Args:
        axis: Position at which the new size-1 axis is inserted. Defaults
            to ``-1`` (append at the end), matching the original behavior.
        **kwargs: Standard Keras layer keyword arguments (``name``,
            ``dtype``, ...), forwarded to the base ``Layer``.
    """

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def build(self, input_shape=None):
        # Keras passes the input shape here on first call; the old
        # zero-argument signature raised TypeError. There are no weights to
        # create, so just let the base class mark the layer as built.
        # `input_shape` defaults to None so a bare `layer.build()` still works.
        super().build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` with a new size-1 axis at ``self.axis``."""
        return tf.expand_dims(inputs, axis=self.axis)

    def get_config(self):
        # Include `axis` so the layer round-trips through model save/load.
        config = super().get_config()
        config.update({"axis": self.axis})
        return config