# NVEmbed2Retriever.retrieve:v1
import weave
from medrag_multi_modal.retrieval.common import SimilarityMetric
import torch
from medrag_multi_modal.utils import get_torch_backend
import torch.nn.functional as F
from medrag_multi_modal.retrieval.common import argsort_scores
@weave.op()
def retrieve(
self,
query: list[str],
top_k: int = 2,
metric: SimilarityMetric = SimilarityMetric.COSINE,
):
device = torch.device(get_torch_backend())
with torch.no_grad():
query_embedding = self._model.encode(
self.add_eos(query), normalize_embeddings=True
)
if metric == SimilarityMetric.EUCLIDEAN:
scores = torch.squeeze(query_embedding @ self._vector_index.T)
else:
scores = F.cosine_similarity(query_embedding, self._vector_index)
scores = scores.cpu().numpy().tolist()
scores = argsort_scores(scores, descending=True)[:top_k]