
I am currently implementing a function to compute a custom cross-entropy loss. The function is defined in the following image:

cited from "Deep Ordinal Regression Network for Monocular Depth Estimation", Huan Fu et al., CVPR 2018

My code is as follows:

output = output.permute(0, 2, 3, 1)
target = target.permute(0, 2, 3, 1)

batch, height, width, channel = output.size()

total_loss = 0.
for b in range(batch): # for each batch
    o = output[b]
    t = target[b]
    loss = 0.
    for w in range(width):
        for h in range(height): # for every pixel([h,w]) in the image
            sid_t = t[h][w][0]
            sid_o_candi = o[h][w]
            part1 = 0. # to store the first sigma 
            part2 = 0. # to store the second sigma

            for k in range(0, sid_t):
                p = torch.sum(sid_o_candi[k:]) # to get Pk(w,h)
                part1 += torch.log(p + 1e-12).item()

            for k in range(sid_t, intervals):
                p = torch.sum(sid_o_candi[k:]) # to get Pk(w,h)
                part2 += torch.log(1-p + 1e-12).item()

            loss += part1 + part2

    loss /= width * height * (-1)
    total_loss += loss
total_loss /= batch
return torch.tensor(total_loss, dtype=torch.float32)

I am wondering whether any optimization could be done with this code.


1 Answer


I'm not sure whether sid_t = t[h][w][0] is the same for every pixel. If it is, you can get rid of all the for loops, which will greatly speed up the loss computation.

Don't use .item(), because it returns a plain Python value and drops the grad_fn tracking. Without the graph you can't call loss.backward() to compute the gradients.
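To see why, here is a minimal sketch (the variable names are just for illustration): a tensor produced by tensor operations carries a grad_fn, while .item() returns a bare Python float with no graph attached.

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
y = (x * 3).sum()

# y is still attached to the autograd graph, so backward works:
print(y.grad_fn is not None)  # True
y.backward()
print(x.grad)                 # tensor([3.])

# .item() returns a plain Python float -- no grad_fn, no graph:
z = (x * 3).sum().item()
print(type(z))                # <class 'float'>
```

So accumulating part1/part2 with .item() silently cuts the loss off from the parameters you want to train.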

If sid_t = t[h][w][0] is not the same for every pixel, here is a modification that gets rid of at least the innermost for loops:



batch, height, width, channel = output.size()

total_loss = 0.
for b in range(batch): # for each batch
    o = output[b]
    t = target[b]
    loss = 0.
    for w in range(width):
        for h in range(height): # for every pixel ([h, w]) in the image
            sid_t = int(t[h][w][0]) # cast to int so it can be used for slicing
            sid_o_candi = o[h][w]

            # reverse cumulative sum over the *full* vector, so that
            # P[k] == torch.sum(sid_o_candi[k:]) as in the original code
            P = sid_o_candi.flip(dims=(0,)).cumsum(dim=0).flip(dims=(0,))

            part1 = torch.sum(torch.log(P[:sid_t] + 1e-12))
            part2 = torch.sum(torch.log(1 - P[sid_t:intervals] + 1e-12))

            loss += part1 + part2

    loss /= width * height * (-1)
    total_loss += loss
total_loss /= batch
return total_loss # keep it a tensor so the grad_fn is preserved

How it works:

x = torch.arange(10)
print(x)

x_flip = x.flip(dims=(0,))
print(x_flip)

x_inverse_cumsum = x_flip.cumsum(dim=0).flip(dims=(0,))
print(x_inverse_cumsum)

# output
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
tensor([45, 45, 44, 42, 39, 35, 30, 24, 17,  9])

Hope it helps.
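Even when sid_t varies per pixel, all the loops can be removed. Here is a hedged sketch (the function name ordinal_loss and the eps keyword are mine, not from the question): it takes output in the original (B, C, H, W) layout, before the permute, and target as (B, 1, H, W) integer labels, and assumes P_k = sum over j >= k of o_j, matching torch.sum(sid_o_candi[k:]) above.

```python
import torch

def ordinal_loss(output, target, eps=1e-12):
    """Fully vectorized form of the per-pixel loss above.

    output: (B, C, H, W) per-interval scores o_k (layout before the permute)
    target: (B, 1, H, W) integer labels sid_t per pixel
    """
    B, C, H, W = output.shape
    # reverse cumulative sum over the channel dim:
    # P[:, k] == output[:, k:].sum(dim=1)
    P = output.flip(dims=(1,)).cumsum(dim=1).flip(dims=(1,))
    # mask is True where k < sid_t (the "part1" terms), False elsewhere
    k = torch.arange(C, device=output.device).view(1, C, 1, 1)
    mask = k < target  # broadcasts to (B, C, H, W)
    per_term = torch.where(mask, torch.log(P + eps), torch.log(1 - P + eps))
    # sum over intervals, then average over batch and pixels with the minus
    # sign, matching loss /= width * height * (-1) and total_loss /= batch
    return -per_term.sum(dim=1).mean()
```

This stays entirely on tensor operations, so the grad_fn is preserved and the whole thing runs in a handful of kernels instead of Python loops.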


1 Comment

Is there a typo in the statement "part2 = torch.sum(torch.(torch.log(1 - sid2_cumsum + 1e-12)))"? My interpreter reports errors on it.
