
Starting with the end in mind:

  • Is there any way I can write this custom loss function by calling PyTorch functions, thus preserving the autograd graph?
  • How can I ensure my loss function is "wired" to autograd?
  • If that cannot be done, how can I backpropagate the loss score?

Hi, I'm creating a neural network in PyTorch, and I need to define a custom loss function.

I want to train a model to predict outputs for which there are no known labels. I do not really care about the actual output values, but about their semantics; in other words, I'd like to verify that the output "makes sense". Thus, I decided to design my own loss function.

The input_data is an (m, 4) tensor and the output_data is an (m, m) tensor. From the output_data and the pre_input_data I then construct post_output_data, a list of m UUIDs. The pre-input data is a short JSON object with a few fields, such as a UUID. If the model predicts a value that is out of bounds, the corresponding post-output entry is None.
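To make the pipeline concrete, here is a simplified sketch of that post-processing step. The decoding rule (rounding the row sum to an index) is a placeholder for my actual logic, not the real rule:

import torch

def build_post_output_data(output_data, pre_input_data):
    # output_data: (m, m) tensor; pre_input_data: list of m dicts, each with a "uuid"
    post = []
    for row in output_data:
        idx = int(torch.round(row.sum()))   # placeholder decoding rule
        if 0 <= idx < len(pre_input_data):
            post.append(pre_input_data[idx]["uuid"])
        else:
            post.append(None)               # prediction decoded out of bounds
    return post

Note that the decoding step (rounding, indexing, Python control flow) happens outside of tensor operations.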

The ValidityLoss function just computes the sum of out-of-bound results and returns it.

My problem is that the loss function breaks autograd, and the weights of my model do not update at all.


import torch
import torch.nn as nn

# ValidityLoss is computed as the sum of `out-of-bound` results.
class ValidityLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        outputs: tuple[list[Output], list[Output]],
    ):
        bef, aft = outputs

        loss = self.__compute_loss(bef.output)
        loss += self.__compute_loss(aft.output)

        # Wrap the Python int in a fresh tensor; this is where the
        # connection to the model's parameters is lost.
        return torch.tensor(float(loss), requires_grad=True)

    def __compute_loss(self, x) -> int:
        # Count post-output entries that came back as None (out of bounds).
        loss = 0
        for item in x:
            if item is None:
                loss += 1
        return loss
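A minimal repro of what seems to go wrong: torch.tensor(float(loss), requires_grad=True) creates a brand-new leaf tensor from a plain Python number, so it has no grad_fn linking it back to the model's parameters, and backward() never reaches them:

import torch

w = torch.randn(3, requires_grad=True)
y = (w * 2).sum()                               # y.grad_fn is set: part of the graph
z = torch.tensor(float(y), requires_grad=True)  # brand-new leaf tensor

print(y.grad_fn)  # <SumBackward0 object ...>
print(z.grad_fn)  # None -> backward() from z can never reach w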
  • The problem is that your loss is not differentiable
    – Dr. Snoopy
    Commented Mar 20 at 22:41
  • Your loss needs to take in the direct output of your model and compute the loss using PyTorch operations
    – Karl
    Commented Mar 20 at 23:30
  • Thank you both for your answers. I'll try to find a way to compute that loss using PyTorch operations. Commented Mar 21 at 10:02
  • Just out of curiosity, if my loss is not differentiable is there a way I could manually compute the grad or something? Commented Mar 21 at 10:03
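Following the suggestion in the comments, one way to keep the loss differentiable is to penalize the raw model output directly with tensor operations, before any decoding to UUIDs. The sketch below assumes "out of bounds" can be expressed as falling outside a numeric range [lo, hi]; the bounds and the hinge-style penalty are assumptions, not the original validity rule:

import torch
import torch.nn as nn

# Differentiable stand-in for ValidityLoss: instead of counting None
# entries after post-processing, penalize how far each raw output
# strays outside an assumed valid range [lo, hi].
class SoftValidityLoss(nn.Module):
    def __init__(self, lo: float = 0.0, hi: float = 1.0):
        super().__init__()
        self.lo = lo
        self.hi = hi

    def forward(self, raw_outputs: torch.Tensor) -> torch.Tensor:
        # Zero inside [lo, hi], grows linearly outside; built entirely
        # from tensor ops, so autograd can reach the model's parameters.
        below = torch.relu(self.lo - raw_outputs)
        above = torch.relu(raw_outputs - self.hi)
        return (below + above).sum()

params = torch.rand(5, 5, requires_grad=True)
model_out = params * 3 - 1          # some values fall outside [0, 1]
loss = SoftValidityLoss()(model_out)
loss.backward()
print(params.grad is not None)      # True: gradients flow back

As for manually computing gradients: torch.autograd.Function does let you define a custom backward(), but a hard count of None values has zero gradient almost everywhere, so a smooth surrogate like the one above is usually the more practical route.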
