Skip to main content
added 22 characters in body
Source Link
josca
  • 71
  • 1
  • 2

Another easy workaround is:

def create_model(batch_size):
    model = Sequential()
    model.add(LSTM(1, batch_input_shape=(batch_size, 1, sl), stateful=True))
    model.add(Dense(1))
    return model

model_train = create_model(batch_size=50)

model_train.compile(loss='mean_squared_error', optimizer='adam')
model_train.fit(trainX, trainY, epochs=epochs, batch_size=batch_size)

model_predict = create_model(batch_size=1)

weights = model_train.get_weights()
model_predict.set_weights(weights)

Another easy workaround is:

def create_model(batch_size):
    model = Sequential()
    model.add(LSTM(1, batch_input_shape=(batch_size, 1, sl), stateful=True))
    model.add(Dense(1))

model_train = create_model(batch_size=50)

model_train.compile(loss='mean_squared_error', optimizer='adam')
model_train.fit(trainX, trainY, epochs=epochs, batch_size=batch_size)

model_predict = create_model(batch_size=1)

weights = model_train.get_weights()
model_predict.set_weights(weights)

Another easy workaround is:

def create_model(batch_size):
    model = Sequential()
    model.add(LSTM(1, batch_input_shape=(batch_size, 1, sl), stateful=True))
    model.add(Dense(1))
    return model

model_train = create_model(batch_size=50)

model_train.compile(loss='mean_squared_error', optimizer='adam')
model_train.fit(trainX, trainY, epochs=epochs, batch_size=batch_size)

model_predict = create_model(batch_size=1)

weights = model_train.get_weights()
model_predict.set_weights(weights)
Post Undeleted by josca
Post Deleted by josca
Source Link
josca
  • 71
  • 1
  • 2

Another easy workaround is:

def create_model(batch_size):
    model = Sequential()
    model.add(LSTM(1, batch_input_shape=(batch_size, 1, sl), stateful=True))
    model.add(Dense(1))

model_train = create_model(batch_size=50)

model_train.compile(loss='mean_squared_error', optimizer='adam')
model_train.fit(trainX, trainY, epochs=epochs, batch_size=batch_size)

model_predict = create_model(batch_size=1)

weights = model_train.get_weights()
model_predict.set_weights(weights)