MedQAAssistant.retrieve_chunks_for_options:v0
1
2
3
4
5
6
7
8
9
10
11
12
13
import weave
from medrag_multi_modal.retrieval.text_retrieval.bm25s_retrieval import BM25sRetriever
@weave.op()
def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]:
    """Retrieve context chunks for every answer option and pool them.

    Each option string is sent to ``self.retriever.predict`` as its own
    query, requesting ``self.top_k_chunks_for_options`` results. When the
    retriever is not a ``BM25sRetriever``, ``self.retrieval_similarity_metric``
    is also forwarded as the ``metric`` keyword argument.

    Args:
        options: Answer option strings, each used as a separate query.

    Returns:
        Every chunk returned for every option, concatenated in the order
        the options were given.
    """
    query_kwargs = {"top_k": self.top_k_chunks_for_options}
    # The `metric` argument is only forwarded to non-BM25s retrievers;
    # BM25sRetriever.predict is called without it.
    if not isinstance(self.retriever, BM25sRetriever):
        query_kwargs["metric"] = self.retrieval_similarity_metric

    # Flatten per-option results into a single list, preserving option order.
    return [
        chunk
        for query_text in options
        for chunk in self.retriever.predict(query=query_text, **query_kwargs)
    ]