diff --git a/src/3_evals/1_llm_judge/run_eval.py b/src/3_evals/1_llm_judge/run_eval.py index 7687103..a6beb0b 100644 --- a/src/3_evals/1_llm_judge/run_eval.py +++ b/src/3_evals/1_llm_judge/run_eval.py @@ -18,10 +18,6 @@ from src.utils.langfuse.shared_client import flush_langfuse, langfuse_client -load_dotenv(verbose=True) -set_up_logging() - - SYSTEM_MESSAGE = """\ Answer the question using the search tool. \ EACH TIME before invoking the function, you must explain your reasons for doing so. \ @@ -148,6 +144,16 @@ async def run_and_evaluate( async def _main() -> None: + main_agent = agents.Agent( + name="Wikipedia Agent", + instructions=SYSTEM_MESSAGE, + tools=[agents.function_tool(client_manager.knowledgebase.search_knowledgebase)], + model=agents.OpenAIChatCompletionsModel( + model=client_manager.configs.default_planner_model, + openai_client=client_manager.openai_client, + ), + ) + coros = [ run_and_evaluate( run_name=args.run_name, main_agent=main_agent, lf_dataset_item=_item @@ -181,22 +187,15 @@ async def _main() -> None: parser.add_argument("--limit", type=int) args = parser.parse_args() - lf_dataset_items = langfuse_client.get_dataset(args.langfuse_dataset_name).items - if args.limit is not None: - lf_dataset_items = lf_dataset_items[: args.limit] - - client_manager = AsyncClientManager() + load_dotenv(verbose=True) + set_up_logging() setup_langfuse_tracer() - main_agent = agents.Agent( - name="Wikipedia Agent", - instructions=SYSTEM_MESSAGE, - tools=[agents.function_tool(client_manager.knowledgebase.search_knowledgebase)], - model=agents.OpenAIChatCompletionsModel( - model=client_manager.configs.default_planner_model, - openai_client=client_manager.openai_client, - ), - ) + client_manager = AsyncClientManager() + + lf_dataset_items = langfuse_client.get_dataset(args.langfuse_dataset_name).items + if args.limit is not None: + lf_dataset_items = lf_dataset_items[: args.limit] asyncio.run(_main()) diff --git a/src/3_evals/1_llm_judge/upload_data.py b/src/3_evals/1_llm_judge/upload_data.py index f794c82..b6ecbf7 100644 --- a/src/3_evals/1_llm_judge/upload_data.py +++ b/src/3_evals/1_llm_judge/upload_data.py @@ -15,17 +15,15 @@ from src.utils.langfuse.shared_client import langfuse_client -load_dotenv(verbose=True) - - -parser = argparse.ArgumentParser() -parser.add_argument("--source_dataset", required=True) -parser.add_argument("--langfuse_dataset_name", required=True) -parser.add_argument("--limit", type=int) - - if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--source_dataset", required=True) + parser.add_argument("--langfuse_dataset_name", required=True) + parser.add_argument("--limit", type=int) args = parser.parse_args() + + load_dotenv(verbose=True) + configs = Configs() set_up_langfuse_otlp_env_vars() diff --git a/src/3_evals/2_synthetic_data/annotate_diversity.py b/src/3_evals/2_synthetic_data/annotate_diversity.py index 80445e0..2e9ed3d 100644 --- a/src/3_evals/2_synthetic_data/annotate_diversity.py +++ b/src/3_evals/2_synthetic_data/annotate_diversity.py @@ -28,13 +28,6 @@ from langfuse._client.datasets import DatasetItemClient -parser = argparse.ArgumentParser() -parser.add_argument("--langfuse_dataset_name", required=True) -parser.add_argument("--run_name", default="cosine_similarity") -parser.add_argument("--limit", type=int) -parser.add_argument("--embed_batch_size", type=int, default=18) - - class EmbeddingResult(pydantic.BaseModel): """Tracks trace_id and embedding vector for an instance.""" @@ -91,6 +84,11 @@ def _avg_cosine_similarity(matrix: np.ndarray) -> np.ndarray: if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--langfuse_dataset_name", required=True) + parser.add_argument("--run_name", default="cosine_similarity") + parser.add_argument("--limit", type=int) + parser.add_argument("--embed_batch_size", type=int, default=18) args = parser.parse_args() assert args.embed_batch_size > 0, "args.embed_batch_size must be at least 1." diff --git a/src/3_evals/2_synthetic_data/gradio_visualize_diversity.py b/src/3_evals/2_synthetic_data/gradio_visualize_diversity.py index f4128d5..5f94a81 100644 --- a/src/3_evals/2_synthetic_data/gradio_visualize_diversity.py +++ b/src/3_evals/2_synthetic_data/gradio_visualize_diversity.py @@ -123,17 +123,17 @@ async def get_projection_plot( ) -viewer = gr.Interface( - fn=get_projection_plot, - inputs=[ - gr.Textbox(label="Dataset name"), - gr.Radio(["tsne", "pca"], label="Dimensionality Reduction Method"), - gr.Number(value=18, label="Number of rows to plot", minimum=1), - ], - outputs=gr.Plot(label="2D Embedding Plot"), - title="3.2 Text Embedding Visualizer", - description="Select a method to visualize 256-D embeddings of text snippets.", -) - if __name__ == "__main__": + viewer = gr.Interface( + fn=get_projection_plot, + inputs=[ + gr.Textbox(label="Dataset name"), + gr.Radio(["tsne", "pca"], label="Dimensionality Reduction Method"), + gr.Number(value=18, label="Number of rows to plot", minimum=1), + ], + outputs=gr.Plot(label="2D Embedding Plot"), + title="3.2 Text Embedding Visualizer", + description="Select a method to visualize 256-D embeddings of text snippets.", + ) + viewer.launch(share=True) diff --git a/src/3_evals/2_synthetic_data/synthesize_data.py b/src/3_evals/2_synthetic_data/synthesize_data.py index fe02a37..b5ad73f 100644 --- a/src/3_evals/2_synthetic_data/synthesize_data.py +++ b/src/3_evals/2_synthetic_data/synthesize_data.py @@ -42,12 +42,6 @@ {json_schema} """ -parser = argparse.ArgumentParser() -parser.add_argument("--source_dataset", required=True) -parser.add_argument("--langfuse_dataset_name", required=True) -parser.add_argument("--limit", type=int, default=18) -parser.add_argument("--max_concurrency", type=int, default=3) - class _Citation(pydantic.BaseModel): """Represents one cited source/article.""" @@ -112,6 +106,11 @@ async def generate_synthetic_test_cases( if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--source_dataset", required=True) + parser.add_argument("--langfuse_dataset_name", required=True) + parser.add_argument("--limit", type=int, default=18) + parser.add_argument("--max_concurrency", type=int, default=3) args = parser.parse_args() load_dotenv(verbose=True)