[Refactor] Harden and speed up the JIT cache-key computation #597
sjfeng1999 wants to merge 5 commits into
Per-call cache key (jit_function.py):
- Require __cache_signature__ on every JitArgument; drop the str(ir.Type) fallback so unknown types raise instead of silently colliding under one key.
- Fold module globals into the key: a cross-process-stable snapshot plus Triton-style in-process drift detection that raises on change (no auto-recompile). Composite (name, module) identity so same-named globals across modules don't collide.
- target = (GPUTarget, device_id), read live per call so device switches and arch/env changes participate in the key (device_id via the active DeviceRuntime).
- Whitelisted code-gen env vars are re-read live into the key (os._Environ._data fast path with a public-API fallback).
- Performance: memoize the recursive global-ref discovery and the globals key segment; each call only re-snapshots values and runs a lean drift loop.

jit_argument.py:
- Construct JitArgument-annotated params (e.g. Stream) via the annotation; remove the int+Stream special-case and the type fallback.

Tests:
- test_jit_cache_key_completeness.py: env drift, globals snapshot in key, globals drift raises, required __cache_signature__ protocol.
- test_compile_hints.py: _target_ entry is now (GPUTarget, device_id).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
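For readers following the globals handling described above, here is a minimal sketch of a snapshot keyed by composite (name, module) identity, with drift detection that raises rather than recompiling. It is not the code from this PR: `GlobalsDriftError`, `snapshot_globals`, `globals_key_segment`, and `check_globals_drift` are hypothetical names, and using `repr()` as the stable value encoding is an assumption.

```python
import types


class GlobalsDriftError(RuntimeError):
    """Hypothetical error raised when a tracked global changes after the baseline."""


def snapshot_globals(refs):
    """Snapshot tracked globals.

    `refs` maps (name, module_name) -> the namespace dict that owns the name,
    so same-named globals in different modules get distinct entries.
    repr() as the value encoding is an assumption made for this sketch.
    """
    return {ident: repr(ns.get(ident[0])) for ident, ns in refs.items()}


def globals_key_segment(snapshot):
    # Sort the entries so the segment does not depend on dict insertion order.
    return ("_globals_", tuple(sorted(snapshot.items())))


def check_globals_drift(baseline, refs):
    # Compare live values against the baseline and raise on any change,
    # instead of silently recompiling.
    current = snapshot_globals(refs)
    for ident, old in baseline.items():
        new = current.get(ident)
        if new != old:
            name, module = ident
            raise GlobalsDriftError(
                f"global {name!r} in module {module!r} changed: {old} -> {new}"
            )


# Two illustrative modules that both define a global named BLOCK.
mod_a = types.ModuleType("mod_a")
mod_a.BLOCK = 64
mod_b = types.ModuleType("mod_b")
mod_b.BLOCK = 128

refs = {("BLOCK", "mod_a"): vars(mod_a), ("BLOCK", "mod_b"): vars(mod_b)}
baseline = snapshot_globals(refs)
segment = globals_key_segment(baseline)
```

Because the identity is the (name, module) pair, the two `BLOCK` values occupy separate key entries; rebinding either one after the baseline is taken makes `check_globals_drift(baseline, refs)` raise.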
Pull request overview
This PR hardens JIT cache-key construction so generated artifacts better reflect codegen-affecting inputs such as environment, target/device, and referenced globals.
Changes:
- Adds env-var, GPU target/device, and globals snapshot segments to JIT cache keys.
- Refactors `JitArgument` handling to require explicit cache signatures and annotation-driven construction.
- Adds/updates unit tests for cache-key completeness and target key shape.
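As a rough illustration of how the segments listed above might combine into one key, the sketch below requires a `__cache_signature__` method on every argument, pairs the target with a device id, and re-reads whitelisted environment variables on each call. All names here are assumptions: `CODEGEN_ENV_VARS` and its entries are placeholders (the PR does not list the real whitelist), the target is a plain string standing in for `GPUTarget`, and treating `__cache_signature__` as a zero-argument method is a guess.

```python
import os

# Placeholder whitelist; the real variable names are not given in this PR.
CODEGEN_ENV_VARS = ("EXAMPLE_OPT_LEVEL", "EXAMPLE_DEBUG")


def env_key_segment(environ=None):
    # Read the whitelisted variables live on every call (public API only).
    environ = os.environ if environ is None else environ
    return ("_env_", tuple((name, environ.get(name)) for name in CODEGEN_ENV_VARS))


def arg_signature(arg):
    # No string fallback: an argument without the protocol is an error,
    # so unknown types cannot silently share one key.
    sig = getattr(arg, "__cache_signature__", None)
    if sig is None:
        raise TypeError(
            f"{type(arg).__name__} does not implement __cache_signature__"
        )
    return sig()


def build_cache_key(args, target, device_id, environ=None):
    return (
        tuple(arg_signature(a) for a in args),
        ("_target_", (target, device_id)),
        env_key_segment(environ),
    )


class Int32Arg:
    """Illustrative argument type that implements the protocol."""

    def __cache_signature__(self):
        return ("int32",)


key = build_cache_key([Int32Arg()], "example-arch", 0, environ={})
```

Passing a different `device_id`, or an `environ` mapping that sets one of the whitelisted names, yields a different key; passing an object without `__cache_signature__` raises `TypeError`.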
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| python/flydsl/compiler/jit_function.py | Expands cache-key construction and dependency/global tracking. |
| python/flydsl/compiler/jit_argument.py | Adjusts JitArgument conversion dispatch to prefer annotated argument types. |
| tests/unit/test_jit_cache_key_completeness.py | Adds regression coverage for env drift, globals drift, device id, and cache signature requirements. |
| tests/unit/test_compile_hints.py | Updates target cache-key assertions for (GPUTarget, device_id). |
coderfeli left a comment
I found two issues that seem worth addressing before merging:
- This may regress the launch hot path. `_build_full_cache_key()` runs before checking `_call_state_cache`, and `_resolve_and_make_cache_key()` now constructs the registered JitArgument for raw `torch.Tensor` arguments in order to compute `cache_signature()`. That means even a fully warmed `CallState` cache hit still creates a `TensorAdaptor` and goes through DLPack/adaptor setup just to build the key.
The CallState fast path still avoids DLPack for argument packing/execution, but it no longer avoids this cache-key construction cost. Since this PR is also intended to speed up key computation, can we restore a lightweight tensor metadata signature path, e.g. TensorAdaptor.cache_signature_from_tensor() / the old raw-signature behavior, and only build the full adaptor on miss/compile?
- There seems to be a drift-detection hole for class-bound/inherited JIT methods. `_global_refs_cache` and `_globals_prefix_cache` are keyed by `owner_cls`, but `_used_global_vals` is shared across all owners. If the same inherited `JitFunction` is first called on `Base` and later on `Sub`, and `Sub` overrides a helper that references an additional global, that new global is not in the original baseline. `_check_globals_drift()` then skips it via `_NOT_IN_BASELINE`, while `_globals_prefix_cache[Sub]` can still be memoized from the first `Sub` call.
After that, mutating the Sub-specific global may neither raise nor update the cached _globals_ key segment, allowing stale reuse. I think _used_global_vals should be keyed by owner_cls as well, matching the refs/prefix caches, or missing/new refs should invalidate the owner-specific globals prefix / raise.
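One way to picture the proposed fix is a baseline keyed by owner class, so that each owner's first call records its own set of referenced globals. The sketch below is an assumption-laden illustration: `GlobalsTracker`, `DriftError`, and the `live` mapping passed by the caller are hypothetical and do not mirror the actual `_used_global_vals` structure.

```python
class DriftError(RuntimeError):
    """Hypothetical error for a global that changed after the baseline."""


class GlobalsTracker:
    def __init__(self):
        # owner_cls -> {global name: value recorded on that owner's first call}
        self._baselines = {}

    def check(self, owner_cls, live):
        """`live` maps each global referenced for `owner_cls` to its current value."""
        baseline = self._baselines.get(owner_cls)
        if baseline is None:
            # First call for this owner: record its own baseline, including
            # globals that only this owner's overridden helpers reference.
            self._baselines[owner_cls] = dict(live)
            return
        for name, old in baseline.items():
            if live.get(name) != old:
                raise DriftError(f"{owner_cls.__name__}: global {name!r} changed")


class Base:
    pass


class Sub(Base):
    pass


tracker = GlobalsTracker()
tracker.check(Base, {"SCALE": 2})              # Base only references SCALE
tracker.check(Sub, {"SCALE": 2, "EXTRA": 1})   # Sub's helper also references EXTRA
```

With a single shared baseline taken from the `Base` call, `EXTRA` would never be recorded; here a later `tracker.check(Sub, {"SCALE": 2, "EXTRA": 5})` raises `DriftError` because `Sub` has its own baseline.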
Co-authored-by: Cursor <cursoragent@cursor.com>