MedCPTRetriever.predict:v0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import weave
from medrag_multi_modal.retrieval.common import SimilarityMetric
@weave.op()
def predict(
self,
query: str,
top_k: int = 2,
metric: SimilarityMetric = SimilarityMetric.COSINE,
):
"""
Predicts the most relevant chunks for a given query.
This function uses the `retrieve` method to find the top-k relevant chunks
from the dataset based on the input query. It allows specifying the number
of top relevant chunks to retrieve and the similarity metric to use for scoring.
!!! example "Example Usage"
```python
import weave
from dotenv import load_dotenv
import wandb
from medrag_multi_modal.retrieval import MedCPTRetriever