diff --git a/registry/model.go b/registry/model.go
index 160a13a1..d68e5487 100644
--- a/registry/model.go
+++ b/registry/model.go
@@ -17,6 +17,7 @@ import (
 	"github.com/meshery/meshkit/errors"
 	"github.com/meshery/meshkit/files"
 	"github.com/meshery/meshkit/generators"
+	"github.com/meshery/meshkit/generators/github"
 	"github.com/meshery/meshkit/generators/models"
 	"github.com/meshery/meshkit/models/meshmodel/entity"
 	"github.com/meshery/meshkit/utils"
@@ -31,11 +32,119 @@ import (
 	"github.com/meshery/schemas/models/v1beta1/subcategory"
 	log "github.com/sirupsen/logrus"
 	"golang.org/x/sync/semaphore"
+	"golang.org/x/sync/singleflight"
 	"google.golang.org/api/sheets/v4"
 )
 
 var modelToCompGenerateTracker = store.NewGenericThreadSafeStore[compGenerateTracker]()
 
+// generatorFactory creates the package manager used to fetch a model's source
+// package. generators.NewGenerator is the production implementation; tests pass
+// a stub.
+type generatorFactory func(registrant, url, packageName string) (models.PackageManager, error)
+
+// packageFetcher fetches source packages through newGenerator. Successful
+// results are kept in cache, and fetchGroup makes concurrent requests for the
+// same cache key wait on one in-flight fetch instead of each starting their own.
+type packageFetcher struct {
+	newGenerator generatorFactory
+	cache        sync.Map
+	fetchGroup   singleflight.Group
+}
+
+// newPackageFetcher returns a packageFetcher with an empty cache that creates
+// its generators with newGenerator.
+func newPackageFetcher(newGenerator generatorFactory) *packageFetcher {
+	return &packageFetcher{
+		newGenerator: newGenerator,
+	}
+}
+
+// packageCacheKey builds the key under which a fetched package is cached. Keys
+// for the artifactHub registrant include the normalized model name, so such a
+// package is only reused for the same model; every other registrant is keyed by
+// normalized registrant and source URL alone.
+func packageCacheKey(registrant, sourceURL, modelName string) string {
+	normalizedRegistrant := utils.ReplaceSpacesAndConvertToLowercase(registrant)
+	if normalizedRegistrant == artifactHub {
+		return fmt.Sprintf("%s\x00%s\x00%s", normalizedRegistrant, sourceURL, utils.ReplaceSpacesAndConvertToLowercase(modelName))
+	}
+
+	return fmt.Sprintf("%s\x00%s", normalizedRegistrant, sourceURL)
+}
+
+// packageForModel returns the package a model should use, given a fetch result
+// that may be shared with other models. GitHub packages derive generated
+// component metadata from the model name, so for the gitHub registrant the
+// fetched content is reused but a per-model copy named modelName is returned.
+// Any other registrant or package type gets pkg back unchanged.
+func packageForModel(registrant, modelName string, pkg models.Package) models.Package {
+	if utils.ReplaceSpacesAndConvertToLowercase(registrant) != gitHub {
+		return pkg
+	}
+
+	switch typedPkg := pkg.(type) {
+	case github.GitHubPackage:
+		// typedPkg is a copy of the value stored in pkg, so pkg itself keeps its name.
+		typedPkg.Name = modelName
+		return typedPkg
+	case *github.GitHubPackage:
+		clonedPkg := *typedPkg
+		clonedPkg.Name = modelName
+		return &clonedPkg
+	default:
+		return pkg
+	}
+}
+
+// getPackage returns the package for modelName from sourceURL. A package cached
+// under the same key is reused; otherwise it is fetched with a generator from
+// newGenerator, and concurrent callers for the same key share one fetch.
+// RateLimitArtifactHub is called before ArtifactHub fetches. Failures are
+// returned to the caller and never cached, so a later call fetches again.
+func (pf *packageFetcher) getPackage(registrant, sourceURL, modelName string) (models.Package, error) {
+	cacheKey := packageCacheKey(registrant, sourceURL, modelName)
+	if cachedPkg, ok := pf.cache.Load(cacheKey); ok {
+		return packageForModel(registrant, modelName, cachedPkg.(models.Package)), nil
+	}
+
+	fetchedPkg, err, _ := pf.fetchGroup.Do(cacheKey, func() (any, error) {
+		// A caller that missed the cache above may only get here after another
+		// caller's fetch has finished and left the singleflight group. Check the
+		// cache again so that late caller does not fetch a second time.
+		if cachedPkg, ok := pf.cache.Load(cacheKey); ok {
+			return cachedPkg, nil
+		}
+
+		generator, err := pf.newGenerator(registrant, sourceURL, modelName)
+		if err != nil {
+			return nil, err
+		}
+
+		if utils.ReplaceSpacesAndConvertToLowercase(registrant) == artifactHub {
+			RateLimitArtifactHub()
+		}
+
+		pkg, err := generator.GetPackage()
+		if err != nil {
+			return nil, err
+		}
+		if pkg == nil {
+			// A cached nil package would make the type assertions in this method
+			// panic, so treat it as a failed fetch.
+			return nil, fmt.Errorf("no package returned for source %q", sourceURL)
+		}
+
+		pf.cache.Store(cacheKey, pkg)
+		return pkg, nil
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	return packageForModel(registrant, modelName, fetchedPkg.(models.Package)), nil
+}
+
 type compGenerateTracker struct {
 	totalComps int
 	version    string
@@ -800,6 +909,7 @@ func InvokeGenerationFromSheet(wg *sync.WaitGroup, path string, modelsheetID, co
 // - Latest version only filtering
 func InvokeGenerationFromSheetWithOptions(wg *sync.WaitGroup, path string, modelsheetID, componentSheetID int64, spreadsheeetID string, modelName string, modelCSVFilePath, componentCSVFilePath, spreadsheeetCred, relationshipCSVFilePath string, relationshipSheetID int64, srv *sheets.Service, opts GenerationOptions) error {
 	weightedSem := semaphore.NewWeighted(20)
+	pkgFetcher := newPackageFetcher(generators.NewGenerator)
 	url := GoogleSpreadSheetURL + spreadsheeetID
 	totalAvailableModels := 0
 	spreadsheeetChan := make(chan SpreadsheetData)
@@ -924,19 +1034,8 @@ func InvokeGenerationFromSheetWithOptions(wg *sync.WaitGroup, path string, model
 				}
 
 				Log.Debug(fmt.Sprintf("Model %s: Creating generator for registrant: %s, source: %s", model.Model, model.Registrant, model.SourceURL))
-
-				generator, genErr := generators.NewGenerator(model.Registrant, model.SourceURL, model.Model)
-				if genErr != nil {
-					done <- ErrGenerateModel(genErr, model.Model)
-					return
-				}
-
-				if utils.ReplaceSpacesAndConvertToLowercase(model.Registrant) == "artifacthub" {
-					RateLimitArtifactHub()
-				}
-
 				Log.Debug(fmt.Sprintf("Model %s: Fetching package from source", model.Model))
-				pkg, genErr := generator.GetPackage()
+				pkg, genErr := pkgFetcher.getPackage(model.Registrant, model.SourceURL, model.Model)
 				if genErr != nil {
 					done <- ErrGenerateModel(genErr, model.Model)
 					return
diff --git a/registry/model_generation_test.go b/registry/model_generation_test.go
index b9d36cf6..9b863273 100644
--- a/registry/model_generation_test.go
+++ b/registry/model_generation_test.go
@@ -4,12 +4,46 @@ import (
 	"context"
 	"fmt"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
+
+	artifacthubgen "github.com/meshery/meshkit/generators/artifacthub"
+	githubgen "github.com/meshery/meshkit/generators/github"
+	"github.com/meshery/meshkit/generators/models"
 	"github.com/stretchr/testify/assert"
 )
 
+type stubPackageManager struct {
+	pkg       models.Package
+	callCount *atomic.Int32
+	delay     time.Duration
+}
+
+func (spm stubPackageManager) GetPackage() (models.Package, error) {
+	spm.callCount.Add(1)
+	if spm.delay > 0 {
+		time.Sleep(spm.delay)
+	}
+	return spm.pkg, nil
+}
+
+func stubPackageForRegistrant(registrant, url, packageName string) models.Package {
+	switch registrant {
+	case "artifacthub":
+		return artifacthubgen.AhPackage{
+			Name:     fmt.Sprintf("%s:%s", registrant, packageName),
+			ChartUrl: url,
+			Version:  "v1.0.0",
+		}
+	default:
+		return githubgen.GitHubPackage{
+			Name:      packageName,
+			SourceURL: url,
+		}
+	}
+}
+
 func
TestGenerationOptionsTimeoutBehavior(t *testing.T) {
 	// Test that timeout value is respected when set
 	tests := []struct {
@@ -44,6 +78,103 @@ func TestGenerationOptionsTimeoutBehavior(t *testing.T) {
 	}
 }
 
+// newCountingFetcher returns a packageFetcher whose stub generators add one to
+// calls on every GetPackage and, when delay is positive, sleep for delay before
+// returning the stub package for the requested registrant, URL and name.
+func newCountingFetcher(calls *atomic.Int32, delay time.Duration) *packageFetcher {
+	return newPackageFetcher(func(registrant, url, packageName string) (models.PackageManager, error) {
+		return stubPackageManager{
+			pkg:       stubPackageForRegistrant(registrant, url, packageName),
+			callCount: calls,
+			delay:     delay,
+		}, nil
+	})
+}
+
+// Two GitHub models that share a source URL are served by a single fetch, and
+// each one receives a package carrying its own model name.
+func TestPackageFetcherCachesGitHubPackagesByRegistrantAndSourceURL(t *testing.T) {
+	t.Parallel()
+
+	var calls atomic.Int32
+	fetcher := newCountingFetcher(&calls, 0)
+	const sourceURL = "https://example.com/aso.yaml"
+
+	networkPkg, err := fetcher.getPackage("github", sourceURL, "azure-network")
+	assert.NoError(t, err)
+
+	computePkg, err := fetcher.getPackage("github", sourceURL, "azure-compute")
+	assert.NoError(t, err)
+
+	assert.EqualValues(t, 1, calls.Load())
+	assert.Equal(t, "azure-network", networkPkg.GetName())
+	assert.Equal(t, "azure-compute", computePkg.GetName())
+}
+
+// ArtifactHub packages are cached per model name, so two models that use the
+// same source URL are fetched separately.
+func TestPackageFetcherDoesNotShareArtifactHubPackagesAcrossModelNames(t *testing.T) {
+	t.Parallel()
+
+	var calls atomic.Int32
+	fetcher := newCountingFetcher(&calls, 0)
+
+	for _, modelName := range []string{"azure-network", "azure-compute"} {
+		_, err := fetcher.getPackage("artifacthub", "https://example.com/shared.yaml", modelName)
+		assert.NoError(t, err)
+	}
+
+	assert.EqualValues(t, 2, calls.Load())
+}
+
+// The same source URL and model name requested through two different
+// registrants is fetched once per registrant.
+func TestPackageFetcherDoesNotShareAcrossRegistrants(t *testing.T) {
+	t.Parallel()
+
+	var calls atomic.Int32
+	fetcher := newCountingFetcher(&calls, 0)
+
+	for _, registrant := range []string{"github", "artifacthub"} {
+		_, err := fetcher.getPackage(registrant, "https://example.com/shared.yaml", "azure-network")
+		assert.NoError(t, err)
+	}
+
+	assert.EqualValues(t, 2, calls.Load())
+}
+
+// Concurrent GitHub requests for one source URL trigger exactly one fetch while
+// every caller still gets a package named after the model it asked for.
+func TestPackageFetcherDeduplicatesConcurrentGitHubRequests(t *testing.T) {
+	t.Parallel()
+
+	var calls atomic.Int32
+	fetcher := newCountingFetcher(&calls, 25*time.Millisecond)
+
+	requested := []string{
+		"azure-network",
+		"azure-compute",
+		"azure-storage",
+		"azure-network",
+		"azure-compute",
+		"azure-storage",
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(len(requested))
+	for _, name := range requested {
+		go func(want string) {
+			defer wg.Done()
+			pkg, err := fetcher.getPackage("github", "https://example.com/aso.yaml", want)
+			assert.NoError(t, err)
+			assert.Equal(t, want, pkg.GetName())
+		}(name)
+	}
+	wg.Wait()
+
+	assert.EqualValues(t, 1, calls.Load())
+}
+
 func TestProgressTrackerIntegration(t *testing.T) {
 	// Simulate a model generation workflow
 	totalModels := 50