[WIP][DO NOT MERGE] XNNPACK BYOC backend for Relax CPU inference #19580

Draft
mshr-h wants to merge 18 commits into apache:main from mshr-h:xnnpack-byoc

mshr-h commented May 17, 2026

Just experimenting.

Benchmark on NVIDIA DGX Spark:

| Model | Partitions | Baseline mean (ms) | XNNPACK mean (ms) | Speedup |
|---|---|---|---|---|
| xnnpack_tiny_cnn | 4 | 0.002500 | 0.003049 | 0.820x |
| xnnpack_static_qs8_tiny_cnn | 2 | 0.002587 | 0.001711 | 1.512x |
| xnnpack_large_cnn_fp32 | 5 | 0.412256 | 0.044388 | 9.288x |
| xnnpack_large_mlp_fp32 | 4 | 0.148401 | 0.018079 | 8.209x |
| xnnpack_large_qs8_cnn | 2 | 0.003716 | 0.002309 | 1.609x |
| torchvision:mobilenet_v2 | 17 | 90.415 | 91.823 | 0.985x |
| torchvision:mobilenet_v3_small | 62 | 16.420 | 17.165 | 0.957x |
| torchvision:resnet18 | 9 | 845.077 | 845.838 | 0.999x |
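
The Speedup column appears to be the baseline mean divided by the XNNPACK mean, so values below 1.0 mean the XNNPACK path was slower. The PR does not state this definition, so treat it as an assumption. A minimal sketch that recomputes the ratio from the rounded means in the table (the `ROWS` dictionary and `speedup` helper are illustrative names, not part of the PR):

```python
# Rounded means copied from the benchmark table:
# model -> (baseline ms, XNNPACK ms, reported speedup).
ROWS = {
    "xnnpack_tiny_cnn": (0.002500, 0.003049, 0.820),
    "xnnpack_static_qs8_tiny_cnn": (0.002587, 0.001711, 1.512),
    "xnnpack_large_cnn_fp32": (0.412256, 0.044388, 9.288),
    "xnnpack_large_mlp_fp32": (0.148401, 0.018079, 8.209),
    "xnnpack_large_qs8_cnn": (0.003716, 0.002309, 1.609),
    "torchvision:mobilenet_v2": (90.415, 91.823, 0.985),
    "torchvision:mobilenet_v3_small": (16.420, 17.165, 0.957),
    "torchvision:resnet18": (845.077, 845.838, 0.999),
}


def speedup(baseline_ms: float, xnnpack_ms: float) -> float:
    """Assumed definition: values above 1.0 mean the XNNPACK run is faster."""
    return baseline_ms / xnnpack_ms


for name, (baseline_ms, xnnpack_ms, reported) in ROWS.items():
    recomputed = speedup(baseline_ms, xnnpack_ms)
    print(f"{name:32s} recomputed {recomputed:.3f}x  reported {reported:.3f}x")
```

Because the inputs are already rounded, a recomputed ratio can differ from the reported one in the last digit. Read this way, the large fp32 models show the clear gains, while the three torchvision models stay within about 4% of the baseline.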

gemini-code-assist (bot) left a comment

Code Review

This pull request integrates XNNPACK as a Relax BYOC backend, enabling support for static-shape float32 and quantized CNN subgraphs. The changes encompass CMake build support, Relax pattern registration with a cost-based partitioning policy, TFLite frontend updates for QDQ models, and a JSON-based external codegen. Reviewers identified multiple typos in the Python implementation where tvm.tirx was incorrectly used instead of tvm.tir, and suggested replacing a hardcoded float literal in the C++ codegen with std::numeric_limits<float>::max() for improved robustness.


shape = []
for dim in sinfo.shape.values:
    if not isinstance(dim, (tvm.tirx.expr.IntImm, int)):

high

The submodule tvm.tirx does not exist in the standard TVM library. This appears to be a typo for tvm.tir.

Suggested change
- if not isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+ if not isinstance(dim, (tvm.tir.IntImm, int)):



def _symbol_name(dim) -> str | None:
    if isinstance(dim, (tvm.tirx.expr.IntImm, int)):

high

The submodule tvm.tirx does not exist in the standard TVM library. This appears to be a typo for tvm.tir.

Suggested change
- if isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+ if isinstance(dim, (tvm.tir.IntImm, int)):

result: list[int | None] = []
symbol: str | None = None
for index, dim in enumerate(dims):
    if isinstance(dim, (tvm.tirx.expr.IntImm, int)):

high

The submodule tvm.tirx does not exist in the standard TVM library. This appears to be a typo for tvm.tir.

Suggested change
- if isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+ if isinstance(dim, (tvm.tir.IntImm, int)):

if not isinstance(expr, relax.PrimValue):
    return None
value = expr.value
if isinstance(value, tvm.tirx.expr.FloatImm):

high

The submodule tvm.tirx does not exist in the standard TVM library. This appears to be a typo for tvm.tir.

Suggested change
- if isinstance(value, tvm.tirx.expr.FloatImm):
+ if isinstance(value, tvm.tir.FloatImm):

value = expr.value
if isinstance(value, tvm.tirx.expr.FloatImm):
    return float(value.value)
if isinstance(value, tvm.tirx.expr.IntImm):

high

The submodule tvm.tirx does not exist in the standard TVM library. This appears to be a typo for tvm.tir.

Suggested change
- if isinstance(value, tvm.tirx.expr.IntImm):
+ if isinstance(value, tvm.tir.IntImm):
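
The comments above all rest on the claim that `tvm.tirx` is not importable. Whether that holds may depend on the TVM revision the branch is built against, so it is worth checking locally before applying the suggestions. A minimal sketch using only the standard library (`module_exists` is a hypothetical helper name, not part of the PR or of TVM):

```python
import importlib.util


def module_exists(dotted_name: str) -> bool:
    """Return True if `dotted_name` can be located as an importable module."""
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package (for example `tvm` itself) is not
        # installed, or when a module has no usable spec.
        return False


# Report which of the two candidate namespaces the local install provides.
for name in ("tvm.tir", "tvm.tirx"):
    print(f"{name}: {'found' if module_exists(name) else 'not found'}")
```

Note that `find_spec` imports the parent package (`tvm`) as a side effect when given a dotted name.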

}

private:
static constexpr double kXNNPACKInfinity = 3.4028234663852886e38;

medium

Using a hardcoded literal for the maximum float value is less robust than using standard library constants. Consider using std::numeric_limits<float>::max().

Suggested change
- static constexpr double kXNNPACKInfinity = 3.4028234663852886e38;
+ static constexpr double kXNNPACKInfinity = std::numeric_limits<float>::max();
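
The suggested form would also need the `<limits>` header to be included in the C++ file. As a quick sanity check that the two spellings denote the same value, the sketch below (in Python, to avoid a compiler dependency) decodes the largest finite IEEE 754 single-precision bit pattern and compares it with the literal from the PR. The constant name `K_XNNPACK_INFINITY` is only an illustrative mirror of the C++ identifier:

```python
import struct

# Literal copied from the PR's C++ codegen.
K_XNNPACK_INFINITY = 3.4028234663852886e38

# 0x7F7FFFFF is the bit pattern of the largest finite binary32 value,
# i.e. what std::numeric_limits<float>::max() returns on IEEE 754 platforms.
flt_max = struct.unpack("<f", struct.pack("<I", 0x7F7FFFFF))[0]

# Compare against the literal and against the closed form (2 - 2^-23) * 2^127.
print(flt_max == K_XNNPACK_INFINITY)
print(flt_max == (2 - 2**-23) * 2.0**127)
```

Both comparisons should print `True`, which suggests the change is behavior-preserving on IEEE 754 targets and mainly improves readability.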
