Hello,
The following code snippet shows a custom distance that scales the plain dot-product similarity by the reward associated with each reference embedding.
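import torch
from pytorch_metric_learning.distances import BaseDistance


class CustomDistance(BaseDistance):
    def __init__(self, params, **kwargs):
        # Dot products are similarities, so the distance is marked as inverted
        super().__init__(is_inverted=True, **kwargs)
        assert self.is_inverted
        # dict mapping each ref_emb id to its reward
        self.rewards = params.rewards

    def compute_mat(self, embeddings_ids, embeddings, ref_emb_ids, ref_emb):
        # Scaled dot product: mat[i, j] = 20 * <embeddings[i], ref_emb[j]>
        mat = 20 * torch.einsum("ab,cb->ac", embeddings, ref_emb)
        # ref_emb_ids maps column index -> ref_emb id; scale each column by that id's reward
        for col, ref_emb_idx in ref_emb_ids.items():
            mat[:, col] *= self.rewards[ref_emb_idx]
        return mat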
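As an aside on the scaling itself: the per-column loop in compute_mat multiplies column j of mat by the reward of the j-th reference embedding. The sketch below assumes the rewards have already been gathered into a 1-D tensor col_rewards aligned with the columns (the function names are illustrative, not part of any library) and checks that the loop and a single broadcast multiplication give the same result.

```python
import torch


def scale_columns_loop(mat, col_rewards):
    # Same idea as the loop in the snippet: one in-place multiply per column
    out = mat.clone()
    for col in range(out.size(1)):
        out[:, col] *= col_rewards[col]
    return out


def scale_columns_broadcast(mat, col_rewards):
    # col_rewards has shape (num_ref,); unsqueeze(0) makes it (1, num_ref),
    # so column j of mat is multiplied by col_rewards[j]
    return mat * col_rewards.unsqueeze(0)


query = torch.randn(4, 8)
ref = torch.randn(5, 8)
mat = 20 * torch.einsum("ab,cb->ac", query, ref)
col_rewards = torch.tensor([0.5, 1.0, 2.0, 0.0, 3.0])

same = torch.allclose(
    scale_columns_loop(mat, col_rewards), scale_columns_broadcast(mat, col_rewards)
)
print(same)
```

The broadcast form also avoids a Python-level loop over the columns and does not modify mat in place.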
However, the loss function calls the distance as mat = self.distance(embeddings, ref_emb), and the distance in turn calls compute_mat with only the two embedding tensors. The overridden compute_mat(self, embeddings_ids, embeddings, ref_emb_ids, ref_emb) is therefore never called with the ids, which are needed to look up each embedding's reward and scale the corresponding distance value.
Is there a workaround?
Thank you.