Allow additional parameters in compute_mat #697

Open
@celsofranssa

Description

Hello,

The following code snippet shows a custom distance function that scales a simple dot-product distance by the reward associated with each reference embedding.

import torch
from pytorch_metric_learning.distances import BaseDistance


class CustomDistance(BaseDistance):

    def __init__(self, params):
        super().__init__(params, is_inverted=True)
        assert self.is_inverted
        # dict of ref_emb rewards
        self.rewards = params.rewards

    def compute_mat(self, embeddings_ids, embeddings, ref_emb_ids, ref_emb):
        mat = 20 * torch.einsum("ab,cb->ac", embeddings, ref_emb)
        for col, ref_emb_idx in ref_emb_ids.items():
            mat[:, col] *= self.rewards[ref_emb_idx]
        return mat

However, the loss function's call, mat = self.distance(embeddings, ref_emb), does not allow calling the overridden method compute_mat(self, embeddings_ids, embeddings, ref_emb_ids, ref_emb) (which takes the embeddings and ref_emb ids) needed to look up each embedding's reward and scale the corresponding distance value.

Is there a workaround?
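For reference, one workaround that seems possible (a sketch, not a confirmed library API) is to keep compute_mat's standard two-argument signature and stash the ids on the distance object as an attribute just before invoking the loss; compute_mat then reads them from self. The rewards dict, the ref_emb_ids mapping, and the 20x scale come from the snippet above; the plain class (rather than a BaseDistance subclass) and the id values are illustrative so the sketch stays self-contained.

```python
import torch


class CustomDistance:
    """Sketch of the attribute-stashing workaround.

    In practice this would subclass
    pytorch_metric_learning.distances.BaseDistance; a plain class is
    used here so the sketch runs on its own.
    """

    def __init__(self, rewards):
        # dict mapping ref_emb id -> reward scale
        self.rewards = rewards
        # mapping column index -> ref_emb id, set externally
        # before each forward pass
        self.ref_emb_ids = None

    def compute_mat(self, query_emb, ref_emb):
        # standard two-argument signature the loss expects
        mat = 20 * torch.einsum("ab,cb->ac", query_emb, ref_emb)
        for col, ref_emb_idx in self.ref_emb_ids.items():
            mat[:, col] *= self.rewards[ref_emb_idx]
        return mat


distance = CustomDistance(rewards={"doc0": 0.5, "doc1": 2.0})

emb = torch.ones(2, 4)
ref = torch.ones(2, 4)

# stash the ids before the loss (and hence compute_mat) is called
distance.ref_emb_ids = {0: "doc0", 1: "doc1"}
mat = distance.compute_mat(emb, ref)
# base dot product is 4, scaled by 20 -> 80; column 0 scaled by
# 0.5 -> 40, column 1 scaled by 2.0 -> 160
```

This keeps the call site mat = self.distance(embeddings, ref_emb) unchanged, at the cost of mutable state on the distance object that must be refreshed every iteration.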

Thank you.

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Projects: No projects
Development: No branches or pull requests