MedCPTRetriever.retrieve:v4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import weave
from medrag_multi_modal.retrieval.common import SimilarityMetric
import torch
from medrag_multi_modal.utils import get_torch_backend
import torch.nn.functional as F
from medrag_multi_modal.retrieval.common import argsort_scores
@weave.op()
def retrieve(
self,
query: str,
top_k: int = 2,
metric: SimilarityMetric = SimilarityMetric.COSINE,
):
"""
Retrieves the top-k most relevant chunks for a given query using the specified similarity metric.
This method encodes the input query into an embedding and computes similarity scores between
the query embedding and the precomputed vector index. The similarity metric can be either
cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores
are returned as a list of dictionaries, each containing a chunk and its corresponding score.
Args:
query (str): The input query string to search for relevant chunks.
top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.