[Refactor] Harden and speed up the JIT cache-key computation #597

Open
sjfeng1999 wants to merge 5 commits into main from pr/enh-cache-key
@sjfeng1999
Collaborator

Per-call cache key (jit_function.py):

  • Require __cache_signature__ on every JitArgument; drop the str(ir.Type) fallback so unknown types raise instead of silently colliding under one key.
  • Fold module globals into the key: a cross-process-stable snapshot plus Triton-style in-process drift detection that raises on change (no auto-recompile). Composite (name, module) identity so same-named globals across modules don't collide.
  • target = (GPUTarget, device_id), read live per call so device switches and arch/env changes participate in the key (device_id via the active DeviceRuntime).
  • Whitelisted code-gen env vars re-read live into the key (os._Environ._data fast path with a public-API fallback).
  • Performance: memoize the recursive global-ref discovery and the globals key segment; per call only re-snapshots values and runs a lean drift loop.
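
The first two bullets can be illustrated with a minimal sketch. It is not the code in jit_function.py: `arg_signature`, `GlobalsTracker`, and `DriftError` are hypothetical names, and the sketch assumes the tracked global values are hashable and comparable with `==`.

```python
import types


class DriftError(RuntimeError):
    """Raised when a tracked global no longer matches its first-call value."""


def arg_signature(arg):
    # No str(type) fallback: an argument without the protocol is an error,
    # so two unknown types can never end up under the same key.
    sig_fn = getattr(arg, "__cache_signature__", None)
    if sig_fn is None:
        raise TypeError(f"{type(arg).__name__} does not define __cache_signature__")
    return sig_fn()


class GlobalsTracker:
    """Snapshots referenced globals and raises if they change later (hypothetical helper)."""

    def __init__(self, refs):
        # refs: iterable of (name, module) pairs discovered once and memoized.
        self._refs = tuple(refs)
        self._baseline = None

    def snapshot(self):
        # Identity is (name, module name): stable across processes, and
        # same-named globals in different modules stay distinct.
        return tuple(
            ((name, mod.__name__), getattr(mod, name)) for name, mod in self._refs
        )

    def key_segment(self):
        current = self.snapshot()
        if self._baseline is None:
            self._baseline = current  # first call fixes the baseline
        elif current != self._baseline:
            changed = [
                ident
                for (ident, old), (_, new) in zip(self._baseline, current)
                if old != new
            ]
            # Raise instead of recompiling silently.
            raise DriftError(f"globals changed since first call: {changed}")
        return current
```

In this sketch the snapshot is re-read on every call while the reference list is computed once, which mirrors the split described in the performance bullet.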

jit_argument.py:

  • Construct JitArgument-annotated params (e.g. Stream) via the annotation; remove the int+Stream special-case and the type fallback.
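
A rough sketch of annotation-driven construction follows. The classes here are stand-ins, not the real `JitArgument` or `Stream`, and `convert_args` is a hypothetical helper; the assumption is that a raw value is wrapped by calling the annotated class, and that a value with no usable annotation is rejected rather than falling back to its Python type.

```python
import inspect


class JitArgument:
    """Stand-in base class; the real protocol lives in jit_argument.py."""

    def __cache_signature__(self):
        raise NotImplementedError


class Stream(JitArgument):
    """Stand-in for a stream argument built from a raw integer handle."""

    def __init__(self, handle):
        self.handle = int(handle)

    def __cache_signature__(self):
        return ("Stream",)


def convert_args(fn, *args):
    """Wrap each positional argument using the parameter's annotation."""
    sig = inspect.signature(fn)
    bound = sig.bind(*args)
    converted = {}
    for name, value in bound.arguments.items():
        ann = sig.parameters[name].annotation
        if isinstance(value, JitArgument):
            converted[name] = value  # already wrapped by the caller
        elif isinstance(ann, type) and issubclass(ann, JitArgument):
            converted[name] = ann(value)  # the annotation decides, not type(value)
        else:
            raise TypeError(f"cannot convert argument {name!r} of type {type(value).__name__}")
    return converted
```

With this shape, an int passed to a `Stream`-annotated parameter needs no special case: the annotation alone selects the wrapper.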

Tests:

  • test_jit_cache_key_completeness.py: env drift, globals snapshot in key, globals drift raises, required __cache_signature__ protocol.
  • test_compile_hints.py: _target_ entry is now (GPUTarget, device_id).


Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Contributor

Copilot AI left a comment

Pull request overview

This PR hardens JIT cache-key construction so generated artifacts better reflect codegen-affecting inputs such as environment, target/device, and referenced globals.

Changes:

  • Adds env-var, GPU target/device, and globals snapshot segments to JIT cache keys.
  • Refactors JitArgument handling to require explicit cache signatures and annotation-driven construction.
  • Adds/updates unit tests for cache-key completeness and target key shape.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| python/flydsl/compiler/jit_function.py | Expands cache-key construction and dependency/global tracking. |
| python/flydsl/compiler/jit_argument.py | Adjusts JitArgument conversion dispatch to prefer annotated argument types. |
| tests/unit/test_jit_cache_key_completeness.py | Adds regression coverage for env drift, globals drift, device id, and cache signature requirements. |
| tests/unit/test_compile_hints.py | Updates target cache-key assertions for (GPUTarget, device_id). |

Comment thread python/flydsl/compiler/jit_function.py Outdated
Comment thread tests/unit/test_jit_cache_key_completeness.py Outdated
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

Comment thread python/flydsl/compiler/jit_function.py Outdated
Collaborator

coderfeli left a comment

I found two issues that seem worth addressing before merging:

  1. This may regress the launch hot path. _build_full_cache_key() runs before checking _call_state_cache, and _resolve_and_make_cache_key() now constructs the registered JitArgument for raw torch.Tensor arguments in order to compute __cache_signature__(). That means even a fully warmed CallState cache hit still creates a TensorAdaptor and goes through DLPack/adaptor setup just to build the key.

The CallState fast path still avoids DLPack for argument packing/execution, but it no longer avoids this cache-key construction cost. Since this PR is also intended to speed up key computation, can we restore a lightweight tensor metadata signature path, e.g. TensorAdaptor.cache_signature_from_tensor() / the old raw-signature behavior, and only build the full adaptor on miss/compile?

  2. There seems to be a drift-detection hole for class-bound/inherited JIT methods. _global_refs_cache and _globals_prefix_cache are keyed by owner_cls, but _used_global_vals is shared across all owners. If the same inherited JitFunction is first called on Base and later on Sub, and Sub overrides a helper that references an additional global, that new global is not in the original baseline. _check_globals_drift() then skips it via _NOT_IN_BASELINE, while _globals_prefix_cache[Sub] can still be memoized from the first Sub call.

After that, mutating the Sub-specific global may neither raise nor update the cached _globals_ key segment, allowing stale reuse. I think _used_global_vals should be keyed by owner_cls as well, matching the refs/prefix caches, or missing/new refs should invalidate the owner-specific globals prefix / raise.
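
The first suggested fix, keying the baseline by owner class, can be sketched as follows. `PerOwnerDriftCheck` and `DriftError` are hypothetical names, and globals are modelled as a plain dict from (name, module) to value; the real _used_global_vals structure may differ.

```python
class DriftError(RuntimeError):
    """Raised when a tracked global differs from its per-owner baseline."""


_NOT_IN_BASELINE = object()


class PerOwnerDriftCheck:
    """Keeps one baseline per owner class instead of one shared baseline."""

    def __init__(self):
        self._baselines = {}  # owner_cls -> {(name, module): value}

    def check(self, owner_cls, current):
        baseline = self._baselines.get(owner_cls)
        if baseline is None:
            # First call for this owner: record everything it references,
            # including globals that only a subclass override pulls in.
            self._baselines[owner_cls] = dict(current)
            return
        for ident, value in current.items():
            old = baseline.get(ident, _NOT_IN_BASELINE)
            # A ref missing from this owner's baseline is treated as drift
            # rather than skipped, so it cannot go stale unnoticed.
            if old is _NOT_IN_BASELINE or old != value:
                raise DriftError(f"global {ident} changed or is untracked for {owner_cls.__name__}")
```

Because Sub gets its own baseline on its first call, a later mutation of the Sub-only global is caught even though Base was called first.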

Co-authored-by: Cursor <cursoragent@cursor.com>
3 participants