Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 224 additions & 0 deletions cmd/cli/search/backend_resolution.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
package search

import (
"context"
"errors"
"strings"

"github.com/docker/model-runner/pkg/distribution/files"
distributionhf "github.com/docker/model-runner/pkg/distribution/huggingface"
"github.com/docker/model-runner/pkg/distribution/oci"
"github.com/docker/model-runner/pkg/distribution/registry"
disttypes "github.com/docker/model-runner/pkg/distribution/types"
"golang.org/x/sync/errgroup"
)

const (
// backendUnknown is the backend name reported when no backend could be
// determined for a model, including when resolution fails.
backendUnknown = "unknown"

// Backend names reported in search results. In this file GGUF content maps
// to llama.cpp, safetensors content to vllm, and diffusers/DDUF content to
// diffusers.
backendLlamaCpp = "llama.cpp"
backendVLLM = "vllm"
backendDiffusers = "diffusers"

// defaultBackendResolveConcurrency is the number of concurrent resolve calls
// used by resolveSearchResultBackends when the caller passes a non-positive
// concurrency.
defaultBackendResolveConcurrency = 4
)

// backendResolver maps a search target to the name of the backend inferred
// for it.
type backendResolver interface {
// Resolve returns the backend name for target. The implementations in this
// file return backendUnknown alongside any non-nil error.
Resolve(ctx context.Context, target string) (string, error)
}

// registryBackendResolver resolves a backend by fetching a model artifact
// through lookup and inspecting its config and manifest.
type registryBackendResolver struct {
// lookup fetches the model artifact for a reference. It is a field so that
// the source of artifacts can be swapped out; newRegistryBackendResolver
// sets it to the registry client's Model method.
lookup func(ctx context.Context, reference string) (disttypes.ModelArtifact, error)
}

// newRegistryBackendResolver builds a resolver whose lookups go through a
// freshly created registry client.
func newRegistryBackendResolver() *registryBackendResolver {
	registryClient := registry.NewClient()

	resolver := &registryBackendResolver{}
	resolver.lookup = registryClient.Model
	return resolver
}

// Resolve looks up target (with ":latest" appended when it carries no tag or
// digest) and derives the backend name from the artifact.
//
// The config format is consulted first; when it is unavailable or maps to no
// known backend, the manifest layer media types are used instead. When neither
// source yields a backend, backendUnknown is returned together with whichever
// of the config/manifest errors occurred (joined if both failed), or a nil
// error if both were read successfully.
func (r *registryBackendResolver) Resolve(ctx context.Context, target string) (string, error) {
	artifact, lookupErr := r.lookup(ctx, withDefaultTag(target))
	if lookupErr != nil {
		return backendUnknown, lookupErr
	}

	// Preferred source: the declared format in the model config.
	cfg, cfgErr := artifact.Config()
	if cfgErr == nil {
		if name := backendFromFormat(cfg.GetFormat()); name != backendUnknown {
			return name, nil
		}
	}

	// Fallback source: the media types of the manifest layers.
	layers, layersErr := artifact.Manifest()
	switch {
	case layersErr != nil && cfgErr != nil:
		return backendUnknown, errors.Join(cfgErr, layersErr)
	case layersErr != nil:
		return backendUnknown, layersErr
	}

	if name := backendFromManifestLayers(layers); name != backendUnknown {
		return name, nil
	}

	// Nothing matched; still surface a config failure if there was one.
	if cfgErr != nil {
		return backendUnknown, cfgErr
	}
	return backendUnknown, nil
}

// huggingFaceRepoBackendResolver resolves a backend by listing the files of a
// Hugging Face repository and classifying them by file type.
type huggingFaceRepoBackendResolver struct {
// listFiles returns the files of repo at revision. It is a field so that the
// listing source can be swapped out; newHuggingFaceRepoBackendResolver sets
// it to the Hugging Face client's ListFiles method.
listFiles func(ctx context.Context, repo, revision string) ([]distributionhf.RepoFile, error)
}

// newHuggingFaceRepoBackendResolver builds a resolver that lists repository
// files through a Hugging Face client configured with the registry's default
// user agent.
func newHuggingFaceRepoBackendResolver() *huggingFaceRepoBackendResolver {
	userAgent := distributionhf.WithUserAgent(registry.DefaultUserAgent)
	hfClient := distributionhf.NewClient(userAgent)

	resolver := &huggingFaceRepoBackendResolver{}
	resolver.listFiles = hfClient.ListFiles
	return resolver
}

// Resolve lists the files of the repository named by target at the "main"
// revision and derives the backend name from their file types. A listing
// failure yields backendUnknown and the underlying error.
func (r *huggingFaceRepoBackendResolver) Resolve(ctx context.Context, target string) (string, error) {
	listing, listErr := r.listFiles(ctx, target, "main")
	if listErr == nil {
		return backendFromRepoFiles(listing), nil
	}
	return backendUnknown, listErr
}

// backendFromFormat translates a model format into a backend name, returning
// backendUnknown for any format other than GGUF, safetensors or diffusers.
func backendFromFormat(format disttypes.Format) string {
	if format == disttypes.FormatGGUF {
		return backendLlamaCpp
	}
	if format == disttypes.FormatSafetensors {
		return backendVLLM
	}
	if format == disttypes.FormatDiffusers {
		return backendDiffusers
	}
	return backendUnknown
}

// backendFromManifestLayers derives the backend name(s) from the media types
// of the manifest's layers. Layers with other media types are ignored. A nil
// manifest, or one without any recognised layer, yields backendUnknown.
func backendFromManifestLayers(manifest *oci.Manifest) string {
	if manifest == nil {
		return backendUnknown
	}

	var detected []string
	for _, descriptor := range manifest.Layers {
		mediaType := descriptor.MediaType
		if mediaType == disttypes.MediaTypeGGUF {
			detected = append(detected, backendLlamaCpp)
		} else if mediaType == disttypes.MediaTypeSafetensors {
			detected = append(detected, backendVLLM)
		} else if mediaType == disttypes.MediaTypeDDUF {
			detected = append(detected, backendDiffusers)
		}
	}

	// joinBackends de-duplicates and orders the collected names.
	return joinBackends(detected...)
}

// backendFromRepoFiles derives the backend name(s) from a repository file
// listing. Only entries whose Type is "file" are considered, and only GGUF,
// safetensors and DDUF files contribute a backend.
func backendFromRepoFiles(repoFiles []distributionhf.RepoFile) string {
	var detected []string
	for _, entry := range repoFiles {
		// Skip anything that is not a regular file entry.
		if entry.Type == "file" {
			kind := files.Classify(entry.Filename())
			if kind == files.FileTypeGGUF {
				detected = append(detected, backendLlamaCpp)
			} else if kind == files.FileTypeSafetensors {
				detected = append(detected, backendVLLM)
			} else if kind == files.FileTypeDDUF {
				detected = append(detected, backendDiffusers)
			}
		}
	}

	// joinBackends de-duplicates and orders the collected names.
	return joinBackends(detected...)
}

// resolveSearchResultBackends returns a copy of results in which the Backend
// field of every entry has been set from resolve, running at most
// resolveConcurrency resolve calls at a time.
//
// A non-positive resolveConcurrency falls back to
// defaultBackendResolveConcurrency. When resolve returns an error or an empty
// string for an entry, that entry's Backend is set to backendUnknown; resolve
// errors are never propagated to the caller. An empty results slice is
// returned as-is without copying.
func resolveSearchResultBackends(
ctx context.Context,
results []SearchResult,
resolveConcurrency int,
resolve func(context.Context, SearchResult) (string, error),
) []SearchResult {
if len(results) == 0 {
return results
}

if resolveConcurrency <= 0 {
resolveConcurrency = defaultBackendResolveConcurrency
}

// Write backends into a copy rather than into the caller's slice.
resolved := append([]SearchResult(nil), results...)
group, workerCtx := errgroup.WithContext(ctx)
group.SetLimit(resolveConcurrency)

for i := range resolved {
// The closure captures i and touches only resolved[i]; this relies on the
// per-iteration loop variable semantics of Go 1.22 and later.
group.Go(func() error {
Comment on lines +169 to +170
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Loop variable i is captured by the goroutine closure, which can lead to races and incorrect backend assignments.

Because Go reuses the loop variable across iterations, the closure may see a different i than the one intended when the goroutine runs, so it can read/write the wrong resolved entry. Capture i in a new variable inside the loop, e.g.

for i := range resolved {
    i := i // capture
    group.Go(func() error {
        backend, err := resolve(workerCtx, resolved[i])
        if err != nil || backend == "" {
            resolved[i].Backend = backendUnknown
            return nil
        }
        resolved[i].Backend = backend
        return nil
    })
}

Alternatively, take the pointer once and close over that: for i := range resolved { res := &resolved[i]; group.Go(func() error { /* use res */ }) }.

backend, err := resolve(workerCtx, resolved[i])
// A failed or empty resolution is recorded as backendUnknown instead of
// being returned as an error.
if err != nil || backend == "" {
resolved[i].Backend = backendUnknown
return nil
}
resolved[i].Backend = backend
return nil
})
}
Comment on lines +169 to +179
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There is a race condition in this loop due to the capture of the loop variable i in the goroutine's closure. When resolveConcurrency is greater than 1, multiple goroutines will likely operate on the same, incorrect index of the resolved slice, leading to incorrect results and some search results not having their backends resolved.

This is a classic Go pitfall for for loops with goroutines. To fix this, you should create a new variable within the loop scope to capture the correct value of i for each iteration.

The associated tests in dockerhub_test.go and huggingface_test.go use resolveConcurrency: 1, which prevents this race condition from being triggered. It would be beneficial to add a test case with a higher concurrency level to catch such issues in the future.

Suggested change
for i := range resolved {
group.Go(func() error {
backend, err := resolve(workerCtx, resolved[i])
if err != nil || backend == "" {
resolved[i].Backend = backendUnknown
return nil
}
resolved[i].Backend = backend
return nil
})
}
for i := range resolved {
i := i // Capture loop variable
group.Go(func() error {
backend, err := resolve(workerCtx, resolved[i])
if err != nil || backend == "" {
resolved[i].Backend = backendUnknown
return nil
}
resolved[i].Backend = backend
return nil
})
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this project uses Go 1.22+ (go 1.26.0 in go.mod), loop variables are scoped per-iteration, so the for i := range resolved closure is safe and there's no race condition here.
In fact, adding i := i would trigger the copyloopvar linter rule, which flags unnecessary copies of loop variables in Go 1.22+.
That said, I've added a TestResolveSearchResultBackendsConcurrent test that exercises resolveSearchResultBackends with high concurrency (20 results, full parallelism) to explicitly verify correct concurrent behavior and provide better test coverage.


// Every worker returns nil, so the error from Wait is deliberately discarded;
// the call only blocks until all workers have finished.
_ = group.Wait()
return resolved
}

func joinBackends(backends ...string) string {
seen := map[string]bool{}
for _, backend := range backends {
if backend == "" || backend == backendUnknown {
continue
}
seen[backend] = true
}

ordered := []string{
backendLlamaCpp,
backendVLLM,
backendDiffusers,
}

var unique []string
for _, backend := range ordered {
if seen[backend] {
unique = append(unique, backend)
}
}

if len(unique) == 0 {
return backendUnknown
}

return strings.Join(unique, ", ")
}

// withDefaultTag appends ":latest" to reference unless its final path segment
// already carries a tag (":") or a digest ("@").
func withDefaultTag(reference string) string {
	// Only the part after the last "/" is inspected, so a colon earlier in the
	// reference (for example in "host:port/") does not count as a tag.
	lastSegment := reference[strings.LastIndex(reference, "/")+1:]
	if strings.ContainsAny(lastSegment, ":@") {
		return reference
	}
	return reference + ":latest"
}
Loading
Loading