
Model is not working for different dataset #22

Open
sanahtech opened this issue Sep 11, 2019 · 2 comments

Comments

@sanahtech

The SRGAN model works fine on the provided DIV2K_train_hr dataset, but when I try to train the same model on a different dataset, it does not work.

The following is the error:
```
  File "init.py", line 83, in <module>
    train(FLAGS)
  File "/home/student/usama_lahore/Memona/perceptual-reflection-removal-master/SRGANmodelnewdata/SRGAN-Keras-Implementation-master/train.py", line 21, in train
    srgan_model.train(epochs, save_interval = save_interval, batch_size = batch_size)
  File "/home/student/usama_lahore/Memona/perceptual-reflection-removal-master/SRGANmodelnewdata/SRGAN-Keras-Implementation-master/models/SRGAN.py", line 225, in train
    real_dis_loss = self.discriminator.train_on_batch(hr_imgs, real)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1211, in train_on_batch
    class_weight=class_weight)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 751, in _standardize_user_data
    exception_prefix='input')
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training_utils.py", line 128, in standardize_input_data
    'with shape ' + str(data_shape))
ValueError: Error when checking input: expected input_3 to have 4 dimensions, but got array with shape (16, 1)
```
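The discriminator expects a 4-D batch (batch, height, width, channels), but it received a 1-D array of 16 items. Keras appears to expand 1-D inputs to 2-D before this check, which would explain why the message reports `(16, 1)`. The thread does not confirm the cause. One common way to end up with such an array is a dataset whose images do not all share one shape (different sizes, or grayscale mixed with RGB), so they can only be held in a NumPy object array. Older NumPy releases could silently build such an array from `np.array(list_of_images)`, while recent releases raise an error instead. The sketch below reproduces that situation with made-up images. The helper `find_mismatched_shapes` and the expected shape `(256, 256, 3)` are assumptions for illustration and are not part of the repository:

```python
import numpy as np

def find_mismatched_shapes(images, expected=(256, 256, 3)):
    """Return (index, shape) pairs for images whose shape differs from `expected`."""
    return [(i, np.shape(img)) for i, img in enumerate(images)
            if np.shape(img) != tuple(expected)]

# Hypothetical folder contents: 14 RGB images at 256x256, one image with a
# different size and one grayscale image.
images = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(14)]
images.append(np.zeros((300, 200, 3), dtype=np.uint8))
images.append(np.zeros((256, 256), dtype=np.uint8))

# Images that do not share one shape can only be held in a 1-D object array.
batch = np.empty(len(images), dtype=object)
for i, img in enumerate(images):
    batch[i] = img

# The `/ 127.5 - 1` scaling runs element by element on the stored objects,
# so the result stays 1-D instead of becoming (16, 256, 256, 3).
norm_hr = batch / 127.5 - 1
print(norm_hr.shape)                     # (16,)
print(find_mismatched_shapes(images))    # [(14, (300, 200, 3)), (15, (256, 256))]
```

Running a check like `find_mismatched_shapes` over the loaded dataset before training points directly at the files that break batching.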

@mrciolino

I have trained new GANs with this repo using a folder of images. How does your data end up with shape (16, 1)? Could you give more details about the experiment you are running?

@sanahtech
Author

```python
norm_hr = self.high_reso_imgs[indx_high] / 127.5 - 1
print(norm_hr.shape)
```

In the `gen_pipeline(self, batch_size=16)` function, the line above prints `(16,)`, while it should print `(16, 256, 256, 3)`.
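Getting `(16, 256, 256, 3)` requires every high-resolution image to have the same shape before the images are stacked into `high_reso_imgs`. The sketch below assumes the target is 256x256 RGB (taken from the expected shape above) and coerces each image to that shape before stacking. It uses a NumPy-only nearest-neighbour resize as a stand-in for whatever image library your loader actually uses. The helpers `to_rgb`, `resize_nearest` and `build_hr_array`, and the random sampling of `indx_high`, are hypothetical and do not come from the repository:

```python
import numpy as np

TARGET_SHAPE = (256, 256, 3)  # assumed HR size, from the expected (16, 256, 256, 3)

def to_rgb(img):
    """Coerce grayscale (H, W), single-channel (H, W, 1) or RGBA (H, W, 4) to (H, W, 3)."""
    img = np.asarray(img)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]  # drop the alpha channel
    return img

def resize_nearest(img, height, width):
    """Nearest-neighbour resize using only NumPy index arithmetic."""
    rows = np.arange(height) * img.shape[0] // height
    cols = np.arange(width) * img.shape[1] // width
    return img[rows[:, None], cols[None, :]]

def build_hr_array(images, shape=TARGET_SHAPE):
    """Stack images into a single (N, H, W, 3) uint8 array with a uniform shape."""
    h, w, _ = shape
    resized = [resize_nearest(to_rgb(img), h, w) for img in images]
    return np.stack(resized).astype(np.uint8)

# Usage with made-up images of mixed sizes and channel counts.
rng = np.random.default_rng(0)
sizes = [(256, 256, 3)] * 14 + [(300, 200, 3), (256, 256)]
images = [rng.integers(0, 256, size=s, dtype=np.uint8) for s in sizes]

high_reso_imgs = build_hr_array(images)
indx_high = rng.choice(len(high_reso_imgs), size=16, replace=False)
norm_hr = high_reso_imgs[indx_high] / 127.5 - 1
print(norm_hr.shape)  # (16, 256, 256, 3)
```

Nearest-neighbour resizing keeps the example dependency-free; for real training, a smoother resampling filter or cropping patches from the original images is usually preferable, since the HR images serve as ground truth.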
