MedCPTRetriever.retrieve:v2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import weave
from medrag_multi_modal.retrieval.common import SimilarityMetric
import torch
from medrag_multi_modal.utils import get_torch_backend
import torch.nn.functional as F
from medrag_multi_modal.retrieval.common import argsort_scores
@weave.op()
def retrieve(
self,
query: str,
top_k: int = 2,
metric: SimilarityMetric = SimilarityMetric.COSINE,
):
query = [query]
device = torch.device(get_torch_backend())
with torch.no_grad():
encoded = self._query_tokenizer(
query,
truncation=True,
padding=True,
return_tensors="pt",
)
query_embedding = self._query_encoder_model(**encoded).last_hidden_state[
:, 0, :