-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprototype.py
More file actions
95 lines (73 loc) · 2.89 KB
/
prototype.py
File metadata and controls
95 lines (73 loc) · 2.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Phase 1: Single-file prototype proving the full chain.
Hardcoded expressions → Claude API → Qwen 3.5 (untrained LoRA) → output.
Validates connectivity and data flow without training.
Usage:
python prototype.py
# On H100 node (for local model):
python prototype.py --with-local-model
"""
from __future__ import annotations
import argparse
def main():
    """Drive the full prototype chain: expression → Claude → (optional) local adapter.

    Parses CLI flags, evaluates each hardcoded test expression, queries Claude
    for its answer, prints a per-expression OK/WRONG report, and — when
    ``--with-local-model`` is given — feeds the same transcripts through the
    untrained local Qwen adapter for a qualitative look at its raw output.
    """
    arg_parser = argparse.ArgumentParser(description="API Adapter prototype")
    arg_parser.add_argument(
        "--with-local-model",
        action="store_true",
        help="Also run through local Qwen model (requires GPU)",
    )
    opts = arg_parser.parse_args()

    # Project-local imports are deferred to keep `--help` fast and dependency-free.
    from api_adapter.symbols import evaluate
    from api_adapter.api_client import get_client, query_claude, parse_answer

    # Test expressions: mix of standard and custom
    cases = [
        ("3 + 5", "standard"),
        ("12 * 4 - 7", "standard"),
        ("10 / 2 + 3", "standard"),
        ("3 θ 5", "custom"),  # 3 + 5 = 8
        ("12 γ 4 α 7", "custom"),  # 12 * 4 - 7 = 41
        ("10 β 2 θ 3", "custom"),  # 10 / 2 + 3 = 8
    ]

    llm_client = get_client()
    print("=" * 60)
    print("API Adapter Prototype - End-to-End Chain Test")
    print("=" * 60)

    records = []
    for expression, kind in cases:
        expected = evaluate(expression)  # ground truth from the symbol evaluator
        raw_response = query_claude(expression, client=llm_client)
        parsed = parse_answer(raw_response)
        matched = parsed == expected
        records.append(
            {
                "expression": expression,
                "type": kind,
                "correct": expected,
                "claude_response": raw_response,
                "claude_answer": parsed,
                "claude_correct": matched,
            }
        )
        label = "OK" if matched else "WRONG"
        print(f"\n[{label}] {expression} (type={kind})")
        print(f" Correct answer: {expected}")
        print(f" Claude says: {raw_response} (parsed: {parsed})")

    # Summary
    num_right = sum(1 for rec in records if rec["claude_correct"])
    print(f"\nClaude accuracy: {num_right}/{len(records)}")

    if opts.with_local_model:
        print("\n" + "=" * 60)
        print("Running through local adapter model (untrained)...")
        print("=" * 60)
        # Deferred import: only needed on a GPU node.
        from api_adapter.local_model import load_model, format_adapter_prompt, generate

        adapter_model, adapter_tokenizer = load_model()
        adapter_prompts = [
            format_adapter_prompt(rec["expression"], rec["claude_response"], include_symbols=True)
            for rec in records
        ]
        generations = generate(adapter_model, adapter_tokenizer, adapter_prompts)
        for rec, text in zip(records, generations):
            print(f"\n Expression: {rec['expression']}")
            print(f" Claude: {rec['claude_response']}")
            print(f" Adapter: {text}")
            print(f" Correct: {rec['correct']}")

    print("\nPrototype complete.")


if __name__ == "__main__":
    main()