diff --git a/app/common/billing.go b/app/common/billing.go index 678140f8cd..7d280df5a8 100644 --- a/app/common/billing.go +++ b/app/common/billing.go @@ -212,7 +212,7 @@ func NewBillingRegistry( if err != nil { return BillingRegistry{}, err } - subscriptionSyncService, err := NewBillingSubscriptionSyncService(logger, subscriptionServices, billingRegistry, subscriptionSyncAdapter, tracer, creditsConfig, featureGate) + subscriptionSyncService, err := NewBillingSubscriptionSyncService(logger, subscriptionServices, billingRegistry, subscriptionSyncAdapter, tracer, creditsConfig, fsConfig, featureGate) if err != nil { return BillingRegistry{}, err } @@ -271,6 +271,7 @@ func NewBillingCollector(logger *slog.Logger, billingRegistry BillingRegistry, f BillingService: billingRegistry.Billing, Logger: logger, LockedNamespaces: fs.NamespaceLockdown, + MaxLinesPerInvoice: fs.MaxLinesPerCollectedInvoice, }) } @@ -289,14 +290,15 @@ func NewBillingSubscriptionSyncAdapter(db *entdb.Client) (subscriptionsync.Adapt }) } -func NewBillingSubscriptionSyncService(logger *slog.Logger, subsServices SubscriptionServiceWithWorkflow, billingRegistry BillingRegistry, subscriptionSyncAdapter subscriptionsync.Adapter, tracer trace.Tracer, creditsConfig config.CreditsConfiguration, featureGate *featuregate.FeatureGateChecker) (subscriptionsync.Service, error) { +func NewBillingSubscriptionSyncService(logger *slog.Logger, subsServices SubscriptionServiceWithWorkflow, billingRegistry BillingRegistry, subscriptionSyncAdapter subscriptionsync.Adapter, tracer trace.Tracer, creditsConfig config.CreditsConfiguration, billingFsConfig config.BillingFeatureSwitchesConfiguration, featureGate *featuregate.FeatureGateChecker) (subscriptionsync.Service, error) { return subscriptionsyncservice.New(subscriptionsyncservice.Config{ SubscriptionService: subsServices.Service, BillingService: billingRegistry.Billing, ChargesService: billingRegistry.ChargesServiceOrNil(), SubscriptionSyncAdapter: subscriptionSyncAdapter, FeatureFlags: subscriptionsyncservice.FeatureFlags{ - EnableCreditThenInvoice: creditsConfig.EnableCreditThenInvoice, + EnableCreditThenInvoice: creditsConfig.EnableCreditThenInvoice, + MaxLinesPerCollectedInvoice: billingFsConfig.MaxLinesPerCollectedInvoice, }, Logger: logger, Tracer: tracer, diff --git a/app/config/billing.go b/app/config/billing.go index 008e5feaf8..3f36f6dd83 100644 --- a/app/config/billing.go +++ b/app/config/billing.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/viper" "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/pkg/models" ) type BillingConfiguration struct { @@ -35,10 +36,18 @@ func (c BillingConfiguration) Validate() error { type BillingFeatureSwitchesConfiguration struct { NamespaceLockdown []string + // MaxLinesPerCollectedInvoice is the maximum number of lines that can be collected for a single invoice, 0 means no limit. + MaxLinesPerCollectedInvoice int } func (c BillingFeatureSwitchesConfiguration) Validate() error { - return nil + var errs []error + + if c.MaxLinesPerCollectedInvoice < 0 { + errs = append(errs, errors.New("maxLinesPerCollectedInvoice must not be negative")) + } + + return models.NewNillableGenericValidationError(errors.Join(errs...)) } func ConfigureBilling(v *viper.Viper, flags *pflag.FlagSet) { @@ -50,4 +59,5 @@ func ConfigureBilling(v *viper.Viper, flags *pflag.FlagSet) { _ = v.BindPFlag("billing.advancementStrategy", flags.Lookup("billing-advancement-strategy")) v.SetDefault("billing.advancementStrategy", billing.ForegroundAdvancementStrategy) v.SetDefault("billing.maxParallelQuantitySnapshots", 4) + v.SetDefault("billing.featureSwitches.maxLinesPerCollectedInvoice", 0) } diff --git a/app/config/billing_test.go b/app/config/billing_test.go new file mode 100644 index 0000000000..3dc76f084c --- /dev/null +++ b/app/config/billing_test.go @@ -0,0 +1,29 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBillingFeatureSwitchesConfigurationValidate(t *testing.T) { + t.Run("zero max collected invoice lines is valid", func(t *testing.T) { + require.NoError(t, BillingFeatureSwitchesConfiguration{ + MaxLinesPerCollectedInvoice: 0, + }.Validate()) + }) + + t.Run("positive max collected invoice lines is valid", func(t *testing.T) { + require.NoError(t, BillingFeatureSwitchesConfiguration{ + MaxLinesPerCollectedInvoice: 10, + }.Validate()) + }) + + t.Run("negative max collected invoice lines is invalid", func(t *testing.T) { + err := BillingFeatureSwitchesConfiguration{ + MaxLinesPerCollectedInvoice: -1, + }.Validate() + + require.ErrorContains(t, err, "maxLinesPerCollectedInvoice must not be negative") + }) +} diff --git a/cmd/billing-worker/wire_gen.go b/cmd/billing-worker/wire_gen.go index 801af8541c..dcd11d4390 100644 --- a/cmd/billing-worker/wire_gen.go +++ b/cmd/billing-worker/wire_gen.go @@ -374,7 +374,7 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl cleanup() return Application{}, nil, err } - subscriptionsyncService, err := common.NewBillingSubscriptionSyncService(logger, subscriptionServiceWithWorkflow, billingRegistry, subscriptionsyncAdapter, tracer, creditsConfiguration, featureGateChecker) + subscriptionsyncService, err := common.NewBillingSubscriptionSyncService(logger, subscriptionServiceWithWorkflow, billingRegistry, subscriptionsyncAdapter, tracer, creditsConfiguration, billingFeatureSwitchesConfiguration, featureGateChecker) if err != nil { cleanup7() cleanup6() diff --git a/cmd/jobs/internal/wire_gen.go b/cmd/jobs/internal/wire_gen.go index 96e178bc04..c5b19bf0af 100644 --- a/cmd/jobs/internal/wire_gen.go +++ b/cmd/jobs/internal/wire_gen.go @@ -449,7 +449,7 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl cleanup() return Application{}, nil, err } - subscriptionsyncService, err := common.NewBillingSubscriptionSyncService(logger, subscriptionServiceWithWorkflow, billingRegistry, subscriptionsyncAdapter, tracer, creditsConfiguration, featureGateChecker) + subscriptionsyncService, err := common.NewBillingSubscriptionSyncService(logger, subscriptionServiceWithWorkflow, billingRegistry, subscriptionsyncAdapter, tracer, creditsConfiguration, billingFeatureSwitchesConfiguration, featureGateChecker) if err != nil { cleanup7() cleanup6() diff --git a/config.example.yaml b/config.example.yaml index a3dfc1349a..fe31f5469d 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -96,6 +96,10 @@ server: billing: # for production deployments it's recommended to use queued for server only # advancementStrategy: foreground + # featureSwitches: + # # 0 collects all eligible pending invoice lines. Positive values cap each + # # collected standard invoice to the earliest N lines by service period start. + # maxLinesPerCollectedInvoice: 0 credits: enabled: true diff --git a/openmeter/billing/invoice.go b/openmeter/billing/invoice.go index 6b86df3cfe..b0dc572210 100644 --- a/openmeter/billing/invoice.go +++ b/openmeter/billing/invoice.go @@ -481,6 +481,9 @@ func (i InvoicePendingLinesInput) Validate() error { type InvoicePendingLinesOptions struct { BypassCollectionAlignment bool + // MaxLinesPerInvoice caps the number of pending lines collected into a single invoice. + // 0 means no limit. + MaxLinesPerInvoice int // PartialInvoiceLinesEnabled overrides the billing profile's progressive billing setting // for this invocation: @@ -508,6 +511,12 @@ func WithBypassCollectionAlignment() InvoicePendingLinesOption { } } +func WithMaxLinesPerInvoice(maxLines int) InvoicePendingLinesOption { + return func(o *InvoicePendingLinesOptions) { + o.MaxLinesPerInvoice = maxLines + } +} + func WithPartialInvoiceLinesDisabled() InvoicePendingLinesOption { return func(o *InvoicePendingLinesOptions) { o.PartialInvoiceLinesEnabled = lo.ToPtr(false) diff --git a/openmeter/billing/service/gatheringinvoicependinglines.go b/openmeter/billing/service/gatheringinvoicependinglines.go index 1b4894228e..6dcd120f9e 100644 --- a/openmeter/billing/service/gatheringinvoicependinglines.go +++ b/openmeter/billing/service/gatheringinvoicependinglines.go @@ -17,6 +17,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/cmpx" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/slicesx" "github.com/openmeterio/openmeter/pkg/timeutil" @@ -137,6 +138,12 @@ func (s *Service) prepareBillableLines(ctx context.Context, input billing.Prepar return nil, fmt.Errorf("resolving collection asOf: %w", err) } + if options.MaxLinesPerInvoice < 0 { + return nil, billing.ValidationError{ + Err: errors.New("max lines per invoice must not be negative"), + } + } + // let's fetch the existing gathering invoices for the customer existingGatheringInvoices, err := s.adapter.ListGatheringInvoices(ctx, billing.ListGatheringInvoicesInput{ Namespaces: []string{input.Customer.Namespace}, @@ -204,6 +211,16 @@ func (s *Service) prepareBillableLines(ctx context.Context, input billing.Prepar continue } + if input.IncludePendingLines.IsPresent() && options.MaxLinesPerInvoice > 0 && len(inScopeLines) > options.MaxLinesPerInvoice { + return nil, billing.ValidationError{ + Err: fmt.Errorf("include pending lines exceeds max lines per invoice: requested %d, limit %d", len(inScopeLines), options.MaxLinesPerInvoice), + } + } + + if !input.IncludePendingLines.IsPresent() { + inScopeLines = limitGatheringLinesForInvoice(inScopeLines, options.MaxLinesPerInvoice) + } + // Step 1: Let's make sure we have lines properly split on the gathering invoice. // Invariant: the gathering invoice is updated to contain the new lines if any were split. prepareResults, err := s.prepareLinesToBill(ctx, prepareLinesToBillInput{ @@ -425,6 +442,27 @@ func (s *Service) gatherInScopeLines(ctx context.Context, in gatherInScopeLineIn return res, nil } +func limitGatheringLinesForInvoice(lines []gatheringLineWithBillablePeriod, maxLines int) []gatheringLineWithBillablePeriod { + if maxLines <= 0 || len(lines) <= maxLines { + return lines + } + + out := slices.Clone(lines) + slices.SortFunc(out, func(a, b gatheringLineWithBillablePeriod) int { + if result := cmpx.Compare(a.Line.ServicePeriod.From, b.Line.ServicePeriod.From); result != 0 { + return result + } + + if result := cmpx.Compare(a.Line.ServicePeriod.To, b.Line.ServicePeriod.To); result != 0 { + return result + } + + return strings.Compare(a.Line.ID, b.Line.ID) + }) + + return out[:maxLines] +} + type hasInvoicableLinesInput struct { Invoice billing.GatheringInvoice AsOf time.Time diff --git a/openmeter/billing/service/gatheringinvoicependinglines_test.go b/openmeter/billing/service/gatheringinvoicependinglines_test.go index 898f774c7c..7c0f13bacb 100644 --- a/openmeter/billing/service/gatheringinvoicependinglines_test.go +++ b/openmeter/billing/service/gatheringinvoicependinglines_test.go @@ -8,8 +8,44 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/pkg/datetime" + "github.com/openmeterio/openmeter/pkg/timeutil" ) +func TestLimitGatheringLinesForInvoice(t *testing.T) { + line := func(id, from, to string) gatheringLineWithBillablePeriod { + gatheringLine := billing.GatheringLine{} + gatheringLine.ID = id + gatheringLine.ServicePeriod = timeutil.ClosedPeriod{ + From: mustTime(t, from), + To: mustTime(t, to), + } + + return gatheringLineWithBillablePeriod{ + Line: gatheringLine, + BillablePeriod: gatheringLine.ServicePeriod, + } + } + + lines := []gatheringLineWithBillablePeriod{ + line("later", "2025-03-01T00:00:00Z", "2025-04-01T00:00:00Z"), + line("tie-b", "2025-01-01T00:00:00Z", "2025-02-01T00:00:00Z"), + line("earliest", "2024-12-01T00:00:00Z", "2025-01-01T00:00:00Z"), + line("tie-a", "2025-01-01T00:00:00Z", "2025-02-01T00:00:00Z"), + } + + t.Run("zero keeps all lines without reordering", func(t *testing.T) { + got := limitGatheringLinesForInvoice(lines, 0) + + require.Equal(t, lines, got) + }) + + t.Run("positive limit keeps earliest service periods", func(t *testing.T) { + got := limitGatheringLinesForInvoice(lines, 3) + + require.Equal(t, []string{"earliest", "tie-a", "tie-b"}, gatheringLineIDsForLimitTest(got)) + }) +} + func TestResolvePendingLineCollectionCutoff(t *testing.T) { asOf := mustTime(t, "2025-06-15T12:00:00Z") anchor := mustTime(t, "2025-06-01T00:00:00Z") @@ -157,3 +193,12 @@ func mustTime(t *testing.T, value string) time.Time { return parsed } + +func gatheringLineIDsForLimitTest(lines []gatheringLineWithBillablePeriod) []string { + ids := make([]string, 0, len(lines)) + for _, line := range lines { + ids = append(ids, line.Line.ID) + } + + return ids +} diff --git a/openmeter/billing/worker/collect/collect.go b/openmeter/billing/worker/collect/collect.go index a237b84eff..6bb5e564fb 100644 --- a/openmeter/billing/worker/collect/collect.go +++ b/openmeter/billing/worker/collect/collect.go @@ -17,9 +17,10 @@ import ( ) type InvoiceCollector struct { - gatheringInvoices billing.GatheringInvoiceService - billingService billing.Service - lockedNamespaces []string + gatheringInvoices billing.GatheringInvoiceService + billingService billing.Service + lockedNamespaces []string + maxLinesPerInvoice int logger *slog.Logger } @@ -111,6 +112,7 @@ func (a *InvoiceCollector) CollectCustomerInvoice(ctx context.Context, params Co }, // We want to make sure that system collection does not use progressive billing. billing.WithPartialInvoiceLinesDisabled(), + billing.WithMaxLinesPerInvoice(a.maxLinesPerInvoice), ) if err != nil { if errors.Is(err, billing.ErrNamespaceLocked) { @@ -218,6 +220,7 @@ type Config struct { BillingService billing.Service Logger *slog.Logger LockedNamespaces []string + MaxLinesPerInvoice int } func NewInvoiceCollector(config Config) (*InvoiceCollector, error) { @@ -234,9 +237,10 @@ func NewInvoiceCollector(config Config) (*InvoiceCollector, error) { } return &InvoiceCollector{ - gatheringInvoices: config.GatheringInvoiceService, - billingService: config.BillingService, - logger: config.Logger, - lockedNamespaces: config.LockedNamespaces, + gatheringInvoices: config.GatheringInvoiceService, + billingService: config.BillingService, + logger: config.Logger, + lockedNamespaces: config.LockedNamespaces, + maxLinesPerInvoice: config.MaxLinesPerInvoice, }, nil } diff --git a/openmeter/billing/worker/subscriptionsync/service/service.go b/openmeter/billing/worker/subscriptionsync/service/service.go index 7ecd3e45a0..a5ebf4b32a 100644 --- a/openmeter/billing/worker/subscriptionsync/service/service.go +++ b/openmeter/billing/worker/subscriptionsync/service/service.go @@ -22,6 +22,7 @@ type FeatureFlags struct { EnableFlatFeeInAdvanceProrating bool EnableFlatFeeInArrearsProrating bool EnableCreditThenInvoice bool + MaxLinesPerCollectedInvoice int } type Config struct { diff --git a/openmeter/billing/worker/subscriptionsync/service/sync.go b/openmeter/billing/worker/subscriptionsync/service/sync.go index 7261f9b054..26ca7bf457 100644 --- a/openmeter/billing/worker/subscriptionsync/service/sync.go +++ b/openmeter/billing/worker/subscriptionsync/service/sync.go @@ -39,6 +39,7 @@ func (s *Service) invoicePendingLines(ctx context.Context, customer customer.Cus Customer: customer, }, billing.WithPartialInvoiceLinesDisabled(), + billing.WithMaxLinesPerInvoice(s.featureFlags.MaxLinesPerCollectedInvoice), ) if err != nil { if errors.Is(err, billing.ErrInvoiceCreateNoLines) {