I am running this code (https://github.com/ayu-22/BPPNet-Back-Projected-Pyramid-Network/blob/master/Single_Image_Dehazing.ipynb) on a custom dataset, but I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
Please refer to the linked notebook to see where in the code the error occurs.
I am running this model on a custom dataset; the data loader code is pasted below.
import torchvision.transforms as transforms
# Preprocessing applied identically to the flare and the without-flare image
# of each pair: resize to a fixed 256x256, convert to a tensor, then normalise
# every RGB channel with mean 0.5 and std 0.5 (so ToTensor's [0, 1] range is
# mapped to [-1, 1]). Random augmentations (RandomResizedCrop,
# RandomHorizontalFlip, ColorJitter) were tried and are currently disabled.
_train_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
train_transform = transforms.Compose(_train_steps)
class Flare(Dataset):
    """Paired dataset of flare images and their flare-free counterparts.

    Each file in ``flare_dir`` is paired with the file in ``wf_dir`` whose
    stem (the text before the first ``.``) equals the flare file's stem with
    its first ``flare_prefix_len`` characters removed. For example, with the
    default of 4, ``"img_001.png"`` pairs with ``"001.jpg"``.

    Args:
        flare_dir: Directory containing the flare (input) images.
        wf_dir: Directory containing the without-flare (target) images.
        transform: Callable applied to both images of a pair. If ``None``,
            the RGB PIL images are returned unchanged.
        flare_prefix_len: Number of leading characters stripped from a flare
            file's stem before matching. Defaults to 4, the original behavior.
    """

    def __init__(self, flare_dir, wf_dir, transform=None, flare_prefix_len=4):
        self.flare_dir = flare_dir
        self.wf_dir = wf_dir
        self.transform = transform
        self.flare_prefix_len = flare_prefix_len
        # os.listdir returns entries in arbitrary order; sort so the
        # index -> file mapping is reproducible across runs and machines.
        self.flare_img = sorted(os.listdir(flare_dir))
        self.wf_img = sorted(os.listdir(wf_dir))
        # Build the stem -> filename lookup once. The previous version
        # rescanned the whole wf directory listing on every __getitem__ call.
        # setdefault keeps the first file when several share a stem, which
        # mirrors the old "break on first match" loop.
        self._wf_by_stem = {}
        for name in self.wf_img:
            self._wf_by_stem.setdefault(name.split('.')[0], name)

    def __len__(self):
        return len(self.flare_img)

    def _wf_name_for(self, flare_name):
        """Return the without-flare filename paired with ``flare_name``.

        Raises:
            FileNotFoundError: if no file in ``wf_dir`` has the matching stem.
                Previously this case surfaced as an opaque UnboundLocalError
                because ``wf_img`` was never assigned.
        """
        key = flare_name.split('.')[0][self.flare_prefix_len:]
        try:
            return self._wf_by_stem[key]
        except KeyError:
            raise FileNotFoundError(
                f"No image with stem {key!r} in {self.wf_dir!r} "
                f"to pair with {flare_name!r}"
            ) from None

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        flare_name = self.flare_img[idx]
        # Resolve the partner first so a missing pair fails before any I/O.
        wf_name = self._wf_name_for(flare_name)
        f_img = self._load_rgb(os.path.join(self.flare_dir, flare_name))
        wf_img = self._load_rgb(os.path.join(self.wf_dir, wf_name))
        # transform defaults to None; calling it unconditionally raised TypeError.
        if self.transform is not None:
            f_img = self.transform(f_img)
            wf_img = self.transform(wf_img)
        return f_img, wf_img
# Locations of the paired images.
flare_dir = '../input/flaredataset/Flare/Flare_img'
wf_dir = '../input/flaredataset/Flare/Without_Flare_'

# Sorted directory listings, used here only for the sanity-check print.
# Flare lists both directories itself, so these lists do not affect how
# images are paired.
flare_img = sorted(os.listdir(flare_dir))
wf_img = sorted(os.listdir(wf_dir))
print(wf_img[0])

train_ds = Flare(flare_dir, wf_dir, train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
To better understand the dataset class, you can compare my dataset class with the one in the notebook linked above.
