I'm trying to use the Keras SKLearnClassifier wrapper to do grid search and cross-validation with scikit-learn, but I can't get the model to work properly.
from typing import List

import keras

def build_model(X, y, n_neurons: List[int], learning_rate: float):
    # Two hidden ReLU layers followed by a 10-way softmax output
    model = keras.models.Sequential()
    model.add(keras.Input(shape=(28*28,)))
    model.add(keras.layers.Dense(n_neurons[0], activation="relu"))
    model.add(keras.layers.Dense(n_neurons[1], activation="relu"))
    model.add(keras.layers.Dense(10, activation="softmax"))
    optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=optimizer,
                  metrics=["accuracy"])
    return model
sk_train = X_train.reshape((X_train.shape[0], X_train.shape[1] * X_train.shape[2]))
sk_val = X_val.reshape((X_val.shape[0], X_val.shape[1] * X_val.shape[2]))
model = keras.wrappers.SKLearnClassifier(model=build_model, model_kwargs={
    "n_neurons": [300, 100],
    "learning_rate": 3e-4
})
model.fit(sk_train, y_train, epochs=30, validation_data=(sk_val, y_val))
The error I get is:
ValueError: Argument `output` must have rank (ndim) `target.ndim - 1`. Received: target.shape=(None, 10), output.shape=(None, 10)
The error message seems to be saying that it expects an output of shape (None, 10) and that it received (None, 10), which doesn't make sense to me. The model works fine if I just call the function and fit the model directly, without the wrapper:
dummy_model = build_model(X_train, y_train, [300, 100], learning_rate=3e-4)
dummy_model.summary()
dummy_model.fit(sk_train, y_train, epochs=30, validation_data=(sk_val, y_val))
I've also tried not reshaping the data, keeping the original 28*28 shape, but that just makes it worse.
Use print() to check what you have in the variables on this line. You could also check .ndim for the data on this line - it seems it has the wrong number of dimensions.
The message shows that it gets target with (None, 10) and output with (None, 10), but that is the wrong value - it seems output has to be smaller - maybe (10,) instead of (None, 10).
sparse_categorical_crossentropy expects integer labels. Try converting your labels from one-hot to integer. Also try removing X and y from the build_model function to see if that helps. Thanks!
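To make the shape discussion concrete, here is a minimal pure-Python sketch (no Keras involved) of the two label formats the error message is about. The helper functions below are hypothetical stand-ins written for illustration, not the Keras implementations. It assumes, based on target.shape=(None, 10) in the error, that the targets reaching the loss are one-hot encoded; whether the wrapper is what performs that encoding is something to confirm against the Keras documentation.

```python
import math

def one_hot(labels, num_classes):
    """Turn integer class ids into one-hot rows (rank 1 -> rank 2)."""
    return [[1.0 if k == label else 0.0 for k in range(num_classes)]
            for label in labels]

def sparse_categorical_crossentropy(labels, probs):
    """Mean loss for integer labels: one class id per sample."""
    return -sum(math.log(p[label]) for label, p in zip(labels, probs)) / len(labels)

def categorical_crossentropy(targets, probs):
    """Mean loss for one-hot targets: one row of num_classes values per sample."""
    total = 0.0
    for t_row, p_row in zip(targets, probs):
        total -= sum(t * math.log(p) for t, p in zip(t_row, p_row))
    return total / len(targets)

# Hypothetical softmax outputs for 2 samples and 3 classes: shape (2, 3).
probs = [[0.7, 0.2, 0.1],
         [0.1, 0.3, 0.6]]

int_labels = [0, 2]                     # shape (2,): rank is one less than the output
onehot_labels = one_hot(int_labels, 3)  # shape (2, 3): same rank as the output

sparse_loss = sparse_categorical_crossentropy(int_labels, probs)
dense_loss = categorical_crossentropy(onehot_labels, probs)

# Both formats describe the same labels, so the two losses agree.
print(math.isclose(sparse_loss, dense_loss))
```

The sparse loss needs a target whose rank is one less than the output's, which is the check that fails when both have shape (None, 10). If the targets really are one-hot by the time they reach the loss, the matching Keras loss string would be "categorical_crossentropy" rather than "sparse_categorical_crossentropy".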