diff --git a/.agents/skills/billing/SKILL.md b/.agents/skills/billing/SKILL.md index fcb41084dc..9e197ee0d9 100644 --- a/.agents/skills/billing/SKILL.md +++ b/.agents/skills/billing/SKILL.md @@ -29,6 +29,10 @@ openmeter/billing/worker/asyncadvance/ # Event-driven advance handler test/billing/ # Shared test suite base (BaseSuite, SubscriptionMixin) ``` +## Currency Boundary + +Billing invoices, invoice lines, split-line groups, and standard detailed lines use fiat invoice currencies only. Do not widen billing invoice currency columns or treat custom/non-fiat credit units as invoice currency. Convert or materialize custom-unit economics before creating billing invoice artifacts; billing should only persist the fiat money-of-account as `currency`. + ## Core Type Patterns ### Union Types (Invoice, InvoiceLine) diff --git a/AGENTS.md b/AGENTS.md index 75321dcdb2..11df97324f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -227,6 +227,7 @@ For TypeSpec-specific coding constraints, update `api/spec/AGENTS.md` instead of - Do not extract helper functions only to hide a couple of simple operations or short guard checks. If the helper would only wrap 2-4 lines and its name does not add meaningful domain or business intent, keep the code inline even when there is some duplication. Readers can inspect the function body to see what the code does; prefer function names that explain the domain reason for the call over names that merely restate the implementation steps. When you encounter a leftover pass-through wrapper that only calls another function without adding behavior, remove it and call the underlying function directly, even if it is outside the immediate change area. - Do not hide non-trivial branching or domain translation inside local inline functions. If a closure performs type switching, validation, persistence mapping, or meaningful domain conversion, make it a named helper near the code that uses it so it is discoverable, testable, and grep-friendly. Reserve inline closures for tiny callbacks where the surrounding API requires a function literal and the logic is obvious at the call site. - For `Validate() error` methods, prefer collecting all validation issues into `var errs []error` and returning `models.NewNillableGenericValidationError(errors.Join(errs...))` instead of returning on the first invalid field. Preserve field context with wrapped errors like `fmt.Errorf("field: %w", err)` and use plain `errors.New(...)` for simple local checks. +- Credit-purchase charges can carry custom ledger currency codes even though standard invoice and settlement currencies stay fiat-only. When persisting credit-purchase charge metadata through shared helpers such as `chargemeta.Create` or `chargemeta.Update`, use the credit-purchase-specific intent validator instead of falling back to generic `meta.Intent.Validate`, which intentionally rejects non-fiat codes. - Do not introduce `context.Background()` or `context.TODO()` to sidestep missing context propagation in application code. Either propagate the caller's context through the full call path, or remove the unused `context.Context` parameter from the API if the operation is purely local and does not need cancellation, deadlines, or request-scoped values. - Never use `panic` in non-test code paths. If a new failure mode is possible, change the function signature to return an error and propagate it explicitly. - In production constructors and initialization, do not use `slog.Default()` as a fallback dependency. Require a `*slog.Logger` in config/provider inputs and inject it explicitly. diff --git a/openmeter/app/stripe/calculator.go b/openmeter/app/stripe/calculator.go index 1a8d6c5f89..f3e4b57aab 100644 --- a/openmeter/app/stripe/calculator.go +++ b/openmeter/app/stripe/calculator.go @@ -17,6 +17,9 @@ func NewStripeCalculator(currency currencyx.Code) (StripeCalculator, error) { if err != nil { return StripeCalculator{}, fmt.Errorf("failed to get stripe calculator: %w", err) } + if calculator.CurrencyType() != currencyx.CurrencyTypeFiat { + return StripeCalculator{}, fmt.Errorf("stripe currency must be a known fiat currency: %s", currency) + } return StripeCalculator{ calculator: calculator, diff --git a/openmeter/billing/charges/creditpurchase/adapter/charge.go b/openmeter/billing/charges/creditpurchase/adapter/charge.go index 280679b7c1..c232354cdf 100644 --- a/openmeter/billing/charges/creditpurchase/adapter/charge.go +++ b/openmeter/billing/charges/creditpurchase/adapter/charge.go @@ -42,6 +42,7 @@ func (a *adapter) UpdateCharge(ctx context.Context, charge creditpurchase.Charge Intent: charge.Intent.Intent, IntentMutableFields: charge.Intent.IntentMutableFields.IntentMutableFields, Status: metaStatus, + ValidateIntent: charge.Intent.Validate, }) if err != nil { return creditpurchase.ChargeBase{}, err @@ -85,6 +86,7 @@ func (a *adapter) CreateCharge(ctx context.Context, in creditpurchase.CreateChar Intent: in.Intent.Intent, IntentMutableFields: in.Intent.IntentMutableFields.IntentMutableFields, Status: metaStatus, + ValidateIntent: in.Intent.Validate, }) if err != nil { return creditpurchase.Charge{}, err diff --git a/openmeter/billing/charges/creditpurchase/charge.go b/openmeter/billing/charges/creditpurchase/charge.go index 80541d9ef1..00e8f8f979 100644 --- a/openmeter/billing/charges/creditpurchase/charge.go +++ b/openmeter/billing/charges/creditpurchase/charge.go @@ -3,11 +3,13 @@ package creditpurchase import ( "errors" "fmt" + "slices" "time" "github.com/alpacahq/alpacadecimal" "github.com/samber/lo" + "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/ledgertransaction" "github.com/openmeterio/openmeter/openmeter/billing/charges/models/payment" @@ -217,8 +219,30 @@ func (i Intent) CalculateEffectiveAt() time.Time { func (i Intent) Validate() error { var errs []error - if err := i.Intent.Validate(); err != nil { - errs = append(errs, fmt.Errorf("intent meta: %w", err)) + if !slices.Contains(billing.InvoiceLineManagedBy("").Values(), string(i.ManagedBy)) { + errs = append(errs, fmt.Errorf("intent meta: invalid managed by %s", i.ManagedBy)) + } + + if i.CustomerID == "" { + errs = append(errs, fmt.Errorf("intent meta: customer ID is required")) + } + + if err := i.Currency.ValidateFormat(); err != nil { + errs = append(errs, fmt.Errorf("intent meta: currency: %w", err)) + } + + if err := i.TaxConfig.Validate(); err != nil { + errs = append(errs, fmt.Errorf("intent meta: tax config: %w", err)) + } + + if i.Subscription != nil { + if err := i.Subscription.Validate(); err != nil { + errs = append(errs, fmt.Errorf("intent meta: subscription: %w", err)) + } + } + + if i.UniqueReferenceID != nil && *i.UniqueReferenceID == "" { + errs = append(errs, fmt.Errorf("intent meta: unique reference ID cannot be empty")) } if err := i.IntentMutableFields.Validate(); err != nil { diff --git a/openmeter/billing/charges/creditpurchase/charge_test.go b/openmeter/billing/charges/creditpurchase/charge_test.go index d54c5b906e..035395013c 100644 --- a/openmeter/billing/charges/creditpurchase/charge_test.go +++ b/openmeter/billing/charges/creditpurchase/charge_test.go @@ -7,6 +7,8 @@ import ( "github.com/stretchr/testify/require" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/timeutil" ) @@ -71,3 +73,18 @@ func TestFeatureFiltersValidateAsFeatureFilter(t *testing.T) { require.Error(t, FeatureFilters([]string{""}).ValidateAsFeatureFilter()) }) } + +func TestListFundedCreditActivitiesInputValidateAllowsCustomCurrency(t *testing.T) { + currency := currencyx.Code("CREDITS") + + input := ListFundedCreditActivitiesInput{ + Customer: customer.CustomerID{ + Namespace: "ns", + ID: "customer-id", + }, + Limit: 1, + Currency: ¤cy, + } + + require.NoError(t, input.Validate()) +} diff --git a/openmeter/billing/charges/creditpurchase/funded_credit_activity.go b/openmeter/billing/charges/creditpurchase/funded_credit_activity.go index 2f0d603f8a..678f28e417 100644 --- a/openmeter/billing/charges/creditpurchase/funded_credit_activity.go +++ b/openmeter/billing/charges/creditpurchase/funded_credit_activity.go @@ -88,7 +88,7 @@ func (i ListFundedCreditActivitiesInput) Validate() error { } if i.Currency != nil { - if err := i.Currency.Validate(); err != nil { + if err := i.Currency.ValidateFormat(); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) } } diff --git a/openmeter/billing/charges/creditpurchase/settlement.go b/openmeter/billing/charges/creditpurchase/settlement.go index 8545c4f415..b2e570ab48 100644 --- a/openmeter/billing/charges/creditpurchase/settlement.go +++ b/openmeter/billing/charges/creditpurchase/settlement.go @@ -43,8 +43,10 @@ type GenericSettlement struct { func (s GenericSettlement) Validate() error { var errs []error - if err := s.Currency.Validate(); err != nil { + if err := s.Currency.ValidateFormat(); err != nil { errs = append(errs, fmt.Errorf("settlement currency: %w", err)) + } else if !s.Currency.IsKnownFiat() { + errs = append(errs, fmt.Errorf("settlement currency must be a known fiat currency")) } if !s.CostBasis.IsPositive() { diff --git a/openmeter/billing/charges/creditpurchase/settlement_test.go b/openmeter/billing/charges/creditpurchase/settlement_test.go index c7385b4b30..a94341696d 100644 --- a/openmeter/billing/charges/creditpurchase/settlement_test.go +++ b/openmeter/billing/charges/creditpurchase/settlement_test.go @@ -50,3 +50,39 @@ func TestGenericSettlementValidateRequiresPositiveCostBasis(t *testing.T) { }) } } + +func TestGenericSettlementValidateRequiresFiatCurrency(t *testing.T) { + for _, tc := range []struct { + name string + currency currencyx.Code + wantErr bool + }{ + { + name: "fiat", + currency: currencyx.Code("USD"), + }, + { + name: "custom", + currency: currencyx.Code("CREDITS"), + wantErr: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + settlement := GenericSettlement{ + Currency: tc.currency, + CostBasis: alpacadecimal.NewFromFloat(0.5), + } + + err := settlement.Validate() + + if tc.wantErr { + require.Error(t, err) + require.ErrorContains(t, err, "settlement currency must be a known fiat currency") + require.True(t, models.IsGenericValidationError(err)) + return + } + + require.NoError(t, err) + }) + } +} diff --git a/openmeter/billing/charges/flatfee/charge_test.go b/openmeter/billing/charges/flatfee/charge_test.go index 270f2fb99c..0e9e383877 100644 --- a/openmeter/billing/charges/flatfee/charge_test.go +++ b/openmeter/billing/charges/flatfee/charge_test.go @@ -166,7 +166,7 @@ func TestCalculateAmountAfterProration(t *testing.T) { t.Run("invalid currency returns error", func(t *testing.T) { intent := baseIntent() - intent.Currency = currencyx.Code("INVALID") + intent.Currency = currencyx.Code("BAD|CODE") _, err := intent.CalculateAmountAfterProration() require.Error(t, err) diff --git a/openmeter/billing/charges/lineage/service.go b/openmeter/billing/charges/lineage/service.go index 57c414f9ad..dab99e11ac 100644 --- a/openmeter/billing/charges/lineage/service.go +++ b/openmeter/billing/charges/lineage/service.go @@ -110,7 +110,7 @@ func (i BackfillAdvanceLineageSegmentsInput) Validate() error { if i.CustomerID == "" { errs = append(errs, errors.New("customer id is required")) } - if err := i.Currency.Validate(); err != nil { + if err := i.Currency.ValidateFormat(); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) } if !i.Amount.IsPositive() { @@ -138,7 +138,7 @@ func (i LoadLineagesByCustomerInput) Validate() error { if i.CustomerID == "" { errs = append(errs, errors.New("customer id is required")) } - if err := i.Currency.Validate(); err != nil { + if err := i.Currency.ValidateFormat(); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) } diff --git a/openmeter/billing/charges/models/chargemeta/mixin.go b/openmeter/billing/charges/models/chargemeta/mixin.go index 264daa548a..c6f9bfb50c 100644 --- a/openmeter/billing/charges/models/chargemeta/mixin.go +++ b/openmeter/billing/charges/models/chargemeta/mixin.go @@ -62,7 +62,7 @@ func (metaMixin) Fields() []ent.Field { NotEmpty(). Immutable(). SchemaType(map[string]string{ - dialect.Postgres: "varchar(3)", + dialect.Postgres: currencyx.PostgresCodeSchemaType, }), field.Enum("managed_by"). @@ -117,8 +117,9 @@ type CreateInput struct { Intent meta.Intent IntentMutableFields meta.IntentMutableFields - Status meta.ChargeStatus - AdvanceAfter *time.Time + Status meta.ChargeStatus + AdvanceAfter *time.Time + ValidateIntent func() error } type Creator[T any] interface { @@ -170,7 +171,11 @@ func Create[T Creator[T]](creator Creator[T], in CreateInput) (T, error) { in.IntentMutableFields = in.IntentMutableFields.Normalized() in.AdvanceAfter = meta.NormalizeOptionalTimestamp(in.AdvanceAfter) - if err := in.Intent.Validate(); err != nil { + validateIntent := in.Intent.Validate + if in.ValidateIntent != nil { + validateIntent = in.ValidateIntent + } + if err := validateIntent(); err != nil { var empty T return empty, err } @@ -223,8 +228,9 @@ type UpdateInput struct { Intent meta.Intent IntentMutableFields meta.IntentMutableFields - Status meta.ChargeStatus - AdvanceAfter *time.Time + Status meta.ChargeStatus + AdvanceAfter *time.Time + ValidateIntent func() error } func Update[T Updater[T]](updater Updater[T], in UpdateInput) (T, error) { @@ -236,7 +242,11 @@ func Update[T Updater[T]](updater Updater[T], in UpdateInput) (T, error) { return empty, err } - if err := in.Intent.Validate(); err != nil { + validateIntent := in.Intent.Validate + if in.ValidateIntent != nil { + validateIntent = in.ValidateIntent + } + if err := validateIntent(); err != nil { var empty T return empty, err } diff --git a/openmeter/billing/charges/service/creditpurchase_test.go b/openmeter/billing/charges/service/creditpurchase_test.go index 623834ada8..c2598f688a 100644 --- a/openmeter/billing/charges/service/creditpurchase_test.go +++ b/openmeter/billing/charges/service/creditpurchase_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/alpacahq/alpacadecimal" - "github.com/invopop/gobl/currency" "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -97,9 +96,9 @@ func (s *CreditPurchaseTestSuite) TestPromotionalCreditPurchase() { s.Equal(creditpurchase.StatusFinal, updatedCPCharge.Status) } -func (s *CreditPurchaseTestSuite) TestCreditPurchaseRejectsMismatchedSettlementCurrency() { +func (s *CreditPurchaseTestSuite) TestCreditPurchaseRejectsCustomSettlementCurrency() { ctx := context.Background() - ns := s.GetUniqueNamespace("charges-service-credit-purchase-mismatched-settlement-currency") + ns := s.GetUniqueNamespace("charges-service-credit-purchase-custom-settlement-currency") s.ProvisionDefaultTaxCodes(ctx, ns) cust := s.CreateTestCustomer(ns, "test-subject") @@ -119,7 +118,7 @@ func (s *CreditPurchaseTestSuite) TestCreditPurchaseRejectsMismatchedSettlementC settlement: creditpurchase.NewSettlement(creditpurchase.ExternalSettlement{ InitialStatus: creditpurchase.CreatedInitialPaymentSettlementStatus, GenericSettlement: creditpurchase.GenericSettlement{ - Currency: currencyx.Code(currency.EUR), + Currency: currencyx.Code("CREDITS"), CostBasis: alpacadecimal.NewFromFloat(0.5), }, }), @@ -128,19 +127,29 @@ func (s *CreditPurchaseTestSuite) TestCreditPurchaseRejectsMismatchedSettlementC name: "invoice", settlement: creditpurchase.NewSettlement(creditpurchase.InvoiceSettlement{ GenericSettlement: creditpurchase.GenericSettlement{ - Currency: currencyx.Code(currency.EUR), + Currency: currencyx.Code("CREDITS"), CostBasis: alpacadecimal.NewFromFloat(0.5), }, }), }, } { s.Run(tc.name, func() { - intent := CreateCreditPurchaseIntent(s.T(), createCreditPurchaseIntentInput{ - customer: cust.GetID(), - currency: USD, - amount: alpacadecimal.NewFromFloat(100), - servicePeriod: servicePeriod, - settlement: tc.settlement, + intent := charges.NewChargeIntent(creditpurchase.Intent{ + Intent: meta.Intent{ + ManagedBy: billing.ManuallyManagedLine, + CustomerID: cust.ID, + Currency: USD, + }, + IntentMutableFields: creditpurchase.IntentMutableFields{ + IntentMutableFields: meta.IntentMutableFields{ + Name: "Credit Purchase", + ServicePeriod: servicePeriod, + BillingPeriod: servicePeriod, + FullServicePeriod: servicePeriod, + }, + CreditAmount: alpacadecimal.NewFromFloat(100), + Settlement: tc.settlement, + }, }) res, err := s.Charges.Create(ctx, charges.CreateInput{ @@ -150,7 +159,7 @@ func (s *CreditPurchaseTestSuite) TestCreditPurchaseRejectsMismatchedSettlementC }, }) s.Error(err) - s.ErrorContains(err, `settlement currency "EUR" must match credit currency "USD"`) + s.ErrorContains(err, "settlement currency must be a known fiat currency") s.Empty(res) }) } diff --git a/openmeter/billing/charges/service/lineage_test.go b/openmeter/billing/charges/service/lineage_test.go index 2a2c5ae23f..7b950c79c7 100644 --- a/openmeter/billing/charges/service/lineage_test.go +++ b/openmeter/billing/charges/service/lineage_test.go @@ -238,14 +238,15 @@ func (s *CreditRealizationLineageTestSuite) TestBackfillAdvanceLineageSegmentsFi ns := s.GetUniqueNamespace("charges-service-lineage-feature-backfill") customerID := ulid.Make().String() - apiLineageID := s.createAdvanceLineageForBackfill(ctx, ns, customerID, []string{"api-calls"}, alpacadecimal.NewFromInt(40)) - storageLineageID := s.createAdvanceLineageForBackfill(ctx, ns, customerID, []string{"storage"}, alpacadecimal.NewFromInt(30)) + customCurrency := currencyx.Code("CREDITS") + apiLineageID := s.createAdvanceLineageForBackfill(ctx, ns, customerID, customCurrency, []string{"api-calls"}, alpacadecimal.NewFromInt(40)) + storageLineageID := s.createAdvanceLineageForBackfill(ctx, ns, customerID, customCurrency, []string{"storage"}, alpacadecimal.NewFromInt(30)) backingTransactionGroupID := ulid.Make().String() err = service.BackfillAdvanceLineageSegments(ctx, lineage.BackfillAdvanceLineageSegmentsInput{ Namespace: ns, CustomerID: customerID, - Currency: currencyx.Code(currency.USD), + Currency: customCurrency, Amount: alpacadecimal.NewFromInt(50), BackingTransactionGroupID: backingTransactionGroupID, FeatureFilters: []string{"api-calls"}, @@ -439,7 +440,7 @@ func (s *CreditRealizationLineageTestSuite) mustListLineages(namespace string, r return out } -func (s *CreditRealizationLineageTestSuite) createAdvanceLineageForBackfill(ctx context.Context, namespace string, customerID string, advanceFeatures []string, amount alpacadecimal.Decimal) string { +func (s *CreditRealizationLineageTestSuite) createAdvanceLineageForBackfill(ctx context.Context, namespace string, customerID string, currencyCode currencyx.Code, advanceFeatures []string, amount alpacadecimal.Decimal) string { s.T().Helper() chargeID := ulid.Make().String() @@ -458,7 +459,7 @@ func (s *CreditRealizationLineageTestSuite) createAdvanceLineageForBackfill(ctx SetChargeID(chargeID). SetRootRealizationID(ulid.Make().String()). SetCustomerID(customerID). - SetCurrency(currencyx.Code(currency.USD)). + SetCurrency(currencyCode). SetOriginKind(creditrealization.LineageOriginKindAdvance). SetAdvanceFeatures(pq.StringArray(advanceFeatures)). Save(ctx) diff --git a/openmeter/billing/creditgrant/service.go b/openmeter/billing/creditgrant/service.go index ff8ec77598..93813928f1 100644 --- a/openmeter/billing/creditgrant/service.go +++ b/openmeter/billing/creditgrant/service.go @@ -110,6 +110,8 @@ func (i CreateInput) Validate() error { if i.Purchase != nil { if err := i.Purchase.Currency.Validate(); err != nil { errs = append(errs, fmt.Errorf("purchase currency: %w", err)) + } else if i.Purchase.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + errs = append(errs, errors.New("purchase currency must be a known fiat currency")) } if i.Purchase.PerUnitCostBasis != nil && !i.Purchase.PerUnitCostBasis.IsPositive() { diff --git a/openmeter/billing/gatheringinvoice.go b/openmeter/billing/gatheringinvoice.go index 32ace8506f..ea06566ab7 100644 --- a/openmeter/billing/gatheringinvoice.go +++ b/openmeter/billing/gatheringinvoice.go @@ -51,6 +51,8 @@ func (g GatheringInvoiceBase) Validate() error { if err := g.Currency.Validate(); err != nil { errs = append(errs, err) + } else if g.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + errs = append(errs, errors.New("currency must be a known fiat currency")) } if err := g.ServicePeriod.Validate(); err != nil { @@ -469,6 +471,8 @@ func (i GatheringLineBase) Validate() error { if err := i.Currency.Validate(); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) + } else if i.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + errs = append(errs, errors.New("currency must be a known fiat currency")) } if !slices.Contains(InvoiceLineManagedBy("").Values(), string(i.ManagedBy)) { @@ -867,6 +871,8 @@ func (c CreatePendingInvoiceLinesInput) Validate() error { if err := c.Currency.Validate(); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) + } else if c.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + errs = append(errs, errors.New("currency must be a known fiat currency")) } for id, line := range c.Lines { @@ -918,6 +924,8 @@ func (c CreateGatheringInvoiceAdapterInput) Validate() error { if err := c.Currency.Validate(); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) + } else if c.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + errs = append(errs, errors.New("currency must be a known fiat currency")) } if c.Number == "" { diff --git a/openmeter/billing/invoicelinesplitgroup.go b/openmeter/billing/invoicelinesplitgroup.go index 028eb452f0..054ff519a8 100644 --- a/openmeter/billing/invoicelinesplitgroup.go +++ b/openmeter/billing/invoicelinesplitgroup.go @@ -90,6 +90,10 @@ func (i SplitLineGroupCreate) Validate() error { if i.Currency == "" { errs = append(errs, errors.New("currency is required")) + } else if err := i.Currency.Validate(); err != nil { + errs = append(errs, fmt.Errorf("currency: %w", err)) + } else if i.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + errs = append(errs, errors.New("currency must be a known fiat currency")) } if i.UniqueReferenceID != nil && *i.UniqueReferenceID == "" { @@ -146,6 +150,10 @@ func (i SplitLineGroup) Validate() error { if i.Currency == "" { errs = append(errs, errors.New("currency is required")) + } else if err := i.Currency.Validate(); err != nil { + errs = append(errs, fmt.Errorf("currency: %w", err)) + } else if i.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + errs = append(errs, errors.New("currency must be a known fiat currency")) } return errors.Join(errs...) diff --git a/openmeter/billing/models/stddetailedline/mixin.go b/openmeter/billing/models/stddetailedline/mixin.go index c75c46899c..e2a7feef3f 100644 --- a/openmeter/billing/models/stddetailedline/mixin.go +++ b/openmeter/billing/models/stddetailedline/mixin.go @@ -17,11 +17,55 @@ import ( ) type Mixin struct { - entutils.RecursiveMixin[mixinBase] + mixin.Schema + + CurrencyPostgresSchemaType string +} + +func (m Mixin) base() mixinBase { + return mixinBase{ + currencyPostgresSchemaType: m.CurrencyPostgresSchemaType, + } +} + +func (m Mixin) Fields() []ent.Field { + base := m.base() + fields := base.Fields() + + for _, mixin := range base.Mixin() { + fields = append(fields, mixin.Fields()...) + } + + return fields +} + +func (m Mixin) Indexes() []ent.Index { + base := m.base() + indexes := base.Indexes() + + for _, mixin := range base.Mixin() { + indexes = append(indexes, mixin.Indexes()...) + } + + return indexes +} + +func (m Mixin) Annotations() []schema.Annotation { + return m.base().Annotations() } type mixinBase struct { mixin.Schema + + currencyPostgresSchemaType string +} + +func (m mixinBase) currencySchemaType() string { + if m.currencyPostgresSchemaType != "" { + return m.currencyPostgresSchemaType + } + + return currencyx.PostgresCodeSchemaType } func (mixinBase) Mixin() []ent.Mixin { @@ -32,14 +76,14 @@ func (mixinBase) Mixin() []ent.Mixin { } } -func (mixinBase) Fields() []ent.Field { +func (m mixinBase) Fields() []ent.Field { return []ent.Field{ field.String("currency"). GoType(currencyx.Code("")). NotEmpty(). Immutable(). SchemaType(map[string]string{ - dialect.Postgres: "varchar(3)", + dialect.Postgres: m.currencySchemaType(), }), field.Time("service_period_start"), diff --git a/openmeter/billing/seq.go b/openmeter/billing/seq.go index c95890ff94..edfe22a79a 100644 --- a/openmeter/billing/seq.go +++ b/openmeter/billing/seq.go @@ -76,6 +76,10 @@ func (i SequenceGenerationInput) Validate() error { if i.Currency == "" { return fmt.Errorf("currency is required") + } else if err := i.Currency.Validate(); err != nil { + return fmt.Errorf("currency: %w", err) + } else if i.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + return fmt.Errorf("currency must be a known fiat currency") } if i.Namespace == "" { diff --git a/openmeter/billing/stdinvoice.go b/openmeter/billing/stdinvoice.go index 2bf7cf36fc..080ff6f2a5 100644 --- a/openmeter/billing/stdinvoice.go +++ b/openmeter/billing/stdinvoice.go @@ -293,6 +293,8 @@ func (i StandardInvoiceBase) Validate() error { if err := i.Currency.Validate(); err != nil { outErr = errors.Join(outErr, ValidationWithFieldPrefix("currency", err)) + } else if i.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + outErr = errors.Join(outErr, ValidationWithFieldPrefix("currency", errors.New("currency must be a known fiat currency"))) } if err := i.Status.Validate(); err != nil { @@ -775,6 +777,8 @@ func (c CreateInvoiceAdapterInput) Validate() error { if err := c.Currency.Validate(); err != nil { return fmt.Errorf("currency: %w", err) + } else if c.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + return errors.New("currency must be a known fiat currency") } if err := c.Status.Validate(); err != nil { @@ -958,6 +962,10 @@ func (i SimulateInvoiceInput) Validate() error { if i.Currency == "" { return errors.New("currency is required") + } else if err := i.Currency.Validate(); err != nil { + return fmt.Errorf("currency: %w", err) + } else if i.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + return errors.New("currency must be a known fiat currency") } if len(i.Lines.OrEmpty()) == 0 { @@ -1174,6 +1182,8 @@ func (i CreateStandardInvoiceFromGatheringLinesInput) Validate() error { if err := i.Currency.Validate(); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) + } else if i.Currency.CurrencyType() != currencyx.CurrencyTypeFiat { + errs = append(errs, errors.New("currency must be a known fiat currency")) } if len(i.Lines) == 0 { diff --git a/openmeter/billing/stdinvoiceline.go b/openmeter/billing/stdinvoiceline.go index b0e803e00d..97b210fc7b 100644 --- a/openmeter/billing/stdinvoiceline.go +++ b/openmeter/billing/stdinvoiceline.go @@ -111,8 +111,10 @@ func (i StandardLineBase) Validate() error { errs = append(errs, errors.New("name is required")) } - if err := i.Currency.Validate(); err != nil { + if err := i.Currency.ValidateFormat(); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) + } else if !i.Currency.IsKnownFiat() { + errs = append(errs, errors.New("currency must be a known fiat currency")) } if !slices.Contains(InvoiceLineManagedBy("").Values(), string(i.ManagedBy)) { diff --git a/openmeter/billing/stdinvoiceline_test.go b/openmeter/billing/stdinvoiceline_test.go index 6ae7b3a840..9a624e85d7 100644 --- a/openmeter/billing/stdinvoiceline_test.go +++ b/openmeter/billing/stdinvoiceline_test.go @@ -34,6 +34,13 @@ func TestStandardLineValidateRejectsNegativeTotals(t *testing.T) { require.ErrorContains(t, line.Validate(), "totals: total is negative") } +func TestStandardLineValidateRejectsCustomCurrency(t *testing.T) { + line := validStandardLineForValidation() + line.Currency = currencyx.Code("CREDITS") + + require.ErrorContains(t, line.Validate(), "currency must be a known fiat currency") +} + func TestStandardLineValidateAllowsNegativeDetailedLineQuantityWithPositiveTotal(t *testing.T) { line := validStandardLineForValidation() line.Totals.Total = alpacadecimal.NewFromInt(1) diff --git a/openmeter/currencies/models.go b/openmeter/currencies/models.go index d2c6b04d30..9c587405ec 100644 --- a/openmeter/currencies/models.go +++ b/openmeter/currencies/models.go @@ -7,6 +7,7 @@ import ( "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/filter" "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/pagination" @@ -79,21 +80,11 @@ func (i ListCurrenciesInput) Validate() error { return errors.Join(errs...) } -// CurrencyType distinguishes custom currencies from ISO/fiat ones. -type CurrencyType string - -func (t CurrencyType) Validate() error { - switch t { - case CurrencyTypeCustom, CurrencyTypeFiat: - return nil - default: - return fmt.Errorf("currency type: %s", t) - } -} +type CurrencyType = currencyx.CurrencyType const ( - CurrencyTypeCustom CurrencyType = "custom" - CurrencyTypeFiat CurrencyType = "fiat" + CurrencyTypeCustom = currencyx.CurrencyTypeCustom + CurrencyTypeFiat = currencyx.CurrencyTypeFiat ) var _ models.Validator = (*CreateCurrencyInput)(nil) @@ -114,6 +105,8 @@ func (i CreateCurrencyInput) Validate() error { if i.Code == "" { errs = append(errs, errors.New("code is required")) + } else if err := currencyx.Code(i.Code).ValidateCustom(); err != nil { + errs = append(errs, fmt.Errorf("code: %w", err)) } if i.Name == "" { @@ -161,6 +154,8 @@ func (i CreateCostBasisInput) Validate() error { if i.FiatCode == "" { errs = append(errs, errors.New("fiat_code is required")) + } else if err := currencyx.Code(i.FiatCode).Validate(); err != nil { + errs = append(errs, fmt.Errorf("fiat_code: %w", err)) } if !i.Rate.IsPositive() { @@ -191,5 +186,13 @@ func (i ListCostBasesInput) Validate() error { errs = append(errs, errors.New("currency_id is required")) } + if i.FilterFiatCode != nil { + if *i.FilterFiatCode == "" { + errs = append(errs, errors.New("filter_fiat_code is required")) + } else if err := currencyx.Code(*i.FilterFiatCode).Validate(); err != nil { + errs = append(errs, fmt.Errorf("filter_fiat_code: %w", err)) + } + } + return errors.Join(errs...) } diff --git a/openmeter/currencies/models_test.go b/openmeter/currencies/models_test.go new file mode 100644 index 0000000000..8f5101cd7f --- /dev/null +++ b/openmeter/currencies/models_test.go @@ -0,0 +1,70 @@ +package currencies_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/openmeterio/openmeter/openmeter/currencies" +) + +func TestCreateCurrencyInputValidate(t *testing.T) { + valid := currencies.CreateCurrencyInput{ + Namespace: "ns", + Code: "CREDITS", + Name: "Credits", + Symbol: "cr", + } + + tests := []struct { + name string + input currencies.CreateCurrencyInput + wantErr string + }{ + { + name: "valid", + input: valid, + }, + { + name: "fiat code collision", + input: currencies.CreateCurrencyInput{ + Namespace: "ns", + Code: "USD", + Name: "Credits", + Symbol: "cr", + }, + wantErr: "conflicts with fiat currency code", + }, + { + name: "invalid structural code", + input: currencies.CreateCurrencyInput{ + Namespace: "ns", + Code: "BAD|CODE", + Name: "Credits", + Symbol: "cr", + }, + wantErr: "currency code cannot contain route delimiter", + }, + { + name: "missing code", + input: currencies.CreateCurrencyInput{ + Namespace: "ns", + Name: "Credits", + Symbol: "cr", + }, + wantErr: "code is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.input.Validate() + if tt.wantErr == "" { + require.NoError(t, err) + return + } + + require.ErrorContains(t, err, tt.wantErr) + }) + } +} diff --git a/openmeter/ent/db/ledgersubaccountroute.go b/openmeter/ent/db/ledgersubaccountroute.go index 399fe2f08c..1fc43364fc 100644 --- a/openmeter/ent/db/ledgersubaccountroute.go +++ b/openmeter/ent/db/ledgersubaccountroute.go @@ -14,6 +14,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db/ledgeraccount" "github.com/openmeterio/openmeter/openmeter/ent/db/ledgersubaccountroute" "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/pkg/currencyx" ) // LedgerSubAccountRoute is the model entity for the LedgerSubAccountRoute schema. @@ -37,6 +38,8 @@ type LedgerSubAccountRoute struct { RoutingKey string `json:"routing_key,omitempty"` // Currency holds the value of the "currency" field. Currency string `json:"currency,omitempty"` + // Source holds the value of the "source" field. + Source *currencyx.Code `json:"source,omitempty"` // TaxCode holds the value of the "tax_code" field. TaxCode *string `json:"tax_code,omitempty"` // TaxBehavior holds the value of the "tax_behavior" field. @@ -97,7 +100,7 @@ func (*LedgerSubAccountRoute) scanValues(columns []string) ([]any, error) { values[i] = new(pq.StringArray) case ledgersubaccountroute.FieldCreditPriority: values[i] = new(sql.NullInt64) - case ledgersubaccountroute.FieldID, ledgersubaccountroute.FieldNamespace, ledgersubaccountroute.FieldAccountID, ledgersubaccountroute.FieldRoutingKeyVersion, ledgersubaccountroute.FieldRoutingKey, ledgersubaccountroute.FieldCurrency, ledgersubaccountroute.FieldTaxCode, ledgersubaccountroute.FieldTaxBehavior, ledgersubaccountroute.FieldTransactionAuthorizationStatus: + case ledgersubaccountroute.FieldID, ledgersubaccountroute.FieldNamespace, ledgersubaccountroute.FieldAccountID, ledgersubaccountroute.FieldRoutingKeyVersion, ledgersubaccountroute.FieldRoutingKey, ledgersubaccountroute.FieldCurrency, ledgersubaccountroute.FieldSource, ledgersubaccountroute.FieldTaxCode, ledgersubaccountroute.FieldTaxBehavior, ledgersubaccountroute.FieldTransactionAuthorizationStatus: values[i] = new(sql.NullString) case ledgersubaccountroute.FieldCreatedAt, ledgersubaccountroute.FieldUpdatedAt, ledgersubaccountroute.FieldDeletedAt: values[i] = new(sql.NullTime) @@ -171,6 +174,13 @@ func (_m *LedgerSubAccountRoute) assignValues(columns []string, values []any) er } else if value.Valid { _m.Currency = value.String } + case ledgersubaccountroute.FieldSource: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field source", values[i]) + } else if value.Valid { + _m.Source = new(currencyx.Code) + *_m.Source = currencyx.Code(value.String) + } case ledgersubaccountroute.FieldTaxCode: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field tax_code", values[i]) @@ -284,6 +294,11 @@ func (_m *LedgerSubAccountRoute) String() string { builder.WriteString("currency=") builder.WriteString(_m.Currency) builder.WriteString(", ") + if v := _m.Source; v != nil { + builder.WriteString("source=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") if v := _m.TaxCode; v != nil { builder.WriteString("tax_code=") builder.WriteString(*v) diff --git a/openmeter/ent/db/ledgersubaccountroute/ledgersubaccountroute.go b/openmeter/ent/db/ledgersubaccountroute/ledgersubaccountroute.go index d92cd6881b..cd17171d1f 100644 --- a/openmeter/ent/db/ledgersubaccountroute/ledgersubaccountroute.go +++ b/openmeter/ent/db/ledgersubaccountroute/ledgersubaccountroute.go @@ -30,6 +30,8 @@ const ( FieldRoutingKey = "routing_key" // FieldCurrency holds the string denoting the currency field in the database. FieldCurrency = "currency" + // FieldSource holds the string denoting the source field in the database. + FieldSource = "source" // FieldTaxCode holds the string denoting the tax_code field in the database. FieldTaxCode = "tax_code" // FieldTaxBehavior holds the string denoting the tax_behavior field in the database. @@ -75,6 +77,7 @@ var Columns = []string{ FieldRoutingKeyVersion, FieldRoutingKey, FieldCurrency, + FieldSource, FieldTaxCode, FieldTaxBehavior, FieldFeatures, @@ -154,6 +157,11 @@ func ByCurrency(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCurrency, opts...).ToFunc() } +// BySource orders the results by the source field. +func BySource(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSource, opts...).ToFunc() +} + // ByTaxCode orders the results by the tax_code field. func ByTaxCode(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTaxCode, opts...).ToFunc() diff --git a/openmeter/ent/db/ledgersubaccountroute/where.go b/openmeter/ent/db/ledgersubaccountroute/where.go index 30bc3b4362..e5b6ab731c 100644 --- a/openmeter/ent/db/ledgersubaccountroute/where.go +++ b/openmeter/ent/db/ledgersubaccountroute/where.go @@ -11,6 +11,7 @@ import ( "github.com/lib/pq" "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/pkg/currencyx" ) // ID filters vertices based on their ID field. @@ -109,6 +110,12 @@ func Currency(v string) predicate.LedgerSubAccountRoute { return predicate.LedgerSubAccountRoute(sql.FieldEQ(FieldCurrency, v)) } +// Source applies equality check predicate on the "source" field. It's identical to SourceEQ. +func Source(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldEQ(FieldSource, vc)) +} + // TaxCode applies equality check predicate on the "tax_code" field. It's identical to TaxCodeEQ. func TaxCode(v string) predicate.LedgerSubAccountRoute { return predicate.LedgerSubAccountRoute(sql.FieldEQ(FieldTaxCode, v)) @@ -615,6 +622,100 @@ func CurrencyContainsFold(v string) predicate.LedgerSubAccountRoute { return predicate.LedgerSubAccountRoute(sql.FieldContainsFold(FieldCurrency, v)) } +// SourceEQ applies the EQ predicate on the "source" field. +func SourceEQ(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldEQ(FieldSource, vc)) +} + +// SourceNEQ applies the NEQ predicate on the "source" field. +func SourceNEQ(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldNEQ(FieldSource, vc)) +} + +// SourceIn applies the In predicate on the "source" field. +func SourceIn(vs ...currencyx.Code) predicate.LedgerSubAccountRoute { + v := make([]any, len(vs)) + for i := range v { + v[i] = string(vs[i]) + } + return predicate.LedgerSubAccountRoute(sql.FieldIn(FieldSource, v...)) +} + +// SourceNotIn applies the NotIn predicate on the "source" field. +func SourceNotIn(vs ...currencyx.Code) predicate.LedgerSubAccountRoute { + v := make([]any, len(vs)) + for i := range v { + v[i] = string(vs[i]) + } + return predicate.LedgerSubAccountRoute(sql.FieldNotIn(FieldSource, v...)) +} + +// SourceGT applies the GT predicate on the "source" field. +func SourceGT(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldGT(FieldSource, vc)) +} + +// SourceGTE applies the GTE predicate on the "source" field. +func SourceGTE(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldGTE(FieldSource, vc)) +} + +// SourceLT applies the LT predicate on the "source" field. +func SourceLT(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldLT(FieldSource, vc)) +} + +// SourceLTE applies the LTE predicate on the "source" field. +func SourceLTE(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldLTE(FieldSource, vc)) +} + +// SourceContains applies the Contains predicate on the "source" field. +func SourceContains(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldContains(FieldSource, vc)) +} + +// SourceHasPrefix applies the HasPrefix predicate on the "source" field. +func SourceHasPrefix(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldHasPrefix(FieldSource, vc)) +} + +// SourceHasSuffix applies the HasSuffix predicate on the "source" field. +func SourceHasSuffix(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldHasSuffix(FieldSource, vc)) +} + +// SourceIsNil applies the IsNil predicate on the "source" field. +func SourceIsNil() predicate.LedgerSubAccountRoute { + return predicate.LedgerSubAccountRoute(sql.FieldIsNull(FieldSource)) +} + +// SourceNotNil applies the NotNil predicate on the "source" field. +func SourceNotNil() predicate.LedgerSubAccountRoute { + return predicate.LedgerSubAccountRoute(sql.FieldNotNull(FieldSource)) +} + +// SourceEqualFold applies the EqualFold predicate on the "source" field. +func SourceEqualFold(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldEqualFold(FieldSource, vc)) +} + +// SourceContainsFold applies the ContainsFold predicate on the "source" field. +func SourceContainsFold(v currencyx.Code) predicate.LedgerSubAccountRoute { + vc := string(v) + return predicate.LedgerSubAccountRoute(sql.FieldContainsFold(FieldSource, vc)) +} + // TaxCodeEQ applies the EQ predicate on the "tax_code" field. func TaxCodeEQ(v string) predicate.LedgerSubAccountRoute { return predicate.LedgerSubAccountRoute(sql.FieldEQ(FieldTaxCode, v)) diff --git a/openmeter/ent/db/ledgersubaccountroute_create.go b/openmeter/ent/db/ledgersubaccountroute_create.go index 4ebc10b0c3..ffb060ffdc 100644 --- a/openmeter/ent/db/ledgersubaccountroute_create.go +++ b/openmeter/ent/db/ledgersubaccountroute_create.go @@ -18,6 +18,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db/ledgersubaccount" "github.com/openmeterio/openmeter/openmeter/ent/db/ledgersubaccountroute" "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/pkg/currencyx" ) // LedgerSubAccountRouteCreate is the builder for creating a LedgerSubAccountRoute entity. @@ -100,6 +101,20 @@ func (_c *LedgerSubAccountRouteCreate) SetCurrency(v string) *LedgerSubAccountRo return _c } +// SetSource sets the "source" field. +func (_c *LedgerSubAccountRouteCreate) SetSource(v currencyx.Code) *LedgerSubAccountRouteCreate { + _c.mutation.SetSource(v) + return _c +} + +// SetNillableSource sets the "source" field if the given value is not nil. +func (_c *LedgerSubAccountRouteCreate) SetNillableSource(v *currencyx.Code) *LedgerSubAccountRouteCreate { + if v != nil { + _c.SetSource(*v) + } + return _c +} + // SetTaxCode sets the "tax_code" field. func (_c *LedgerSubAccountRouteCreate) SetTaxCode(v string) *LedgerSubAccountRouteCreate { _c.mutation.SetTaxCode(v) @@ -292,6 +307,11 @@ func (_c *LedgerSubAccountRouteCreate) check() error { if _, ok := _c.mutation.Currency(); !ok { return &ValidationError{Name: "currency", err: errors.New(`db: missing required field "LedgerSubAccountRoute.currency"`)} } + if v, ok := _c.mutation.Source(); ok { + if err := v.Validate(); err != nil { + return &ValidationError{Name: "source", err: fmt.Errorf(`db: validator failed for field "LedgerSubAccountRoute.source": %w`, err)} + } + } if v, ok := _c.mutation.TaxBehavior(); ok { if err := v.Validate(); err != nil { return &ValidationError{Name: "tax_behavior", err: fmt.Errorf(`db: validator failed for field "LedgerSubAccountRoute.tax_behavior": %w`, err)} @@ -369,6 +389,10 @@ func (_c *LedgerSubAccountRouteCreate) createSpec() (*LedgerSubAccountRoute, *sq _spec.SetField(ledgersubaccountroute.FieldCurrency, field.TypeString, value) _node.Currency = value } + if value, ok := _c.mutation.Source(); ok { + _spec.SetField(ledgersubaccountroute.FieldSource, field.TypeString, value) + _node.Source = &value + } if value, ok := _c.mutation.TaxCode(); ok { _spec.SetField(ledgersubaccountroute.FieldTaxCode, field.TypeString, value) _node.TaxCode = &value @@ -543,6 +567,9 @@ func (u *LedgerSubAccountRouteUpsertOne) UpdateNewValues() *LedgerSubAccountRout if _, exists := u.create.mutation.Currency(); exists { s.SetIgnore(ledgersubaccountroute.FieldCurrency) } + if _, exists := u.create.mutation.Source(); exists { + s.SetIgnore(ledgersubaccountroute.FieldSource) + } if _, exists := u.create.mutation.TaxCode(); exists { s.SetIgnore(ledgersubaccountroute.FieldTaxCode) } @@ -828,6 +855,9 @@ func (u *LedgerSubAccountRouteUpsertBulk) UpdateNewValues() *LedgerSubAccountRou if _, exists := b.mutation.Currency(); exists { s.SetIgnore(ledgersubaccountroute.FieldCurrency) } + if _, exists := b.mutation.Source(); exists { + s.SetIgnore(ledgersubaccountroute.FieldSource) + } if _, exists := b.mutation.TaxCode(); exists { s.SetIgnore(ledgersubaccountroute.FieldTaxCode) } diff --git a/openmeter/ent/db/ledgersubaccountroute_update.go b/openmeter/ent/db/ledgersubaccountroute_update.go index b99d73d6bf..e9cadc7b49 100644 --- a/openmeter/ent/db/ledgersubaccountroute_update.go +++ b/openmeter/ent/db/ledgersubaccountroute_update.go @@ -161,6 +161,9 @@ func (_u *LedgerSubAccountRouteUpdate) sqlSave(ctx context.Context) (_node int, if _u.mutation.DeletedAtCleared() { _spec.ClearField(ledgersubaccountroute.FieldDeletedAt, field.TypeTime) } + if _u.mutation.SourceCleared() { + _spec.ClearField(ledgersubaccountroute.FieldSource, field.TypeString) + } if _u.mutation.TaxCodeCleared() { _spec.ClearField(ledgersubaccountroute.FieldTaxCode, field.TypeString) } @@ -406,6 +409,9 @@ func (_u *LedgerSubAccountRouteUpdateOne) sqlSave(ctx context.Context) (_node *L if _u.mutation.DeletedAtCleared() { _spec.ClearField(ledgersubaccountroute.FieldDeletedAt, field.TypeTime) } + if _u.mutation.SourceCleared() { + _spec.ClearField(ledgersubaccountroute.FieldSource, field.TypeString) + } if _u.mutation.TaxCodeCleared() { _spec.ClearField(ledgersubaccountroute.FieldTaxCode, field.TypeString) } diff --git a/openmeter/ent/db/migrate/schema.go b/openmeter/ent/db/migrate/schema.go index 4ca0592974..e732b771d7 100644 --- a/openmeter/ent/db/migrate/schema.go +++ b/openmeter/ent/db/migrate/schema.go @@ -1629,7 +1629,7 @@ var ( {Name: "full_service_period_to", Type: field.TypeTime}, {Name: "status", Type: field.TypeEnum, Enums: []string{"created", "active", "final", "deleted"}}, {Name: "unique_reference_id", Type: field.TypeString, Nullable: true}, - {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, + {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(24)"}}, {Name: "managed_by", Type: field.TypeEnum, Enums: []string{"subscription", "system", "manual"}}, {Name: "advance_after", Type: field.TypeTime, Nullable: true}, {Name: "tax_behavior", Type: field.TypeEnum, Nullable: true, Enums: []string{"inclusive", "exclusive"}}, @@ -1915,7 +1915,7 @@ var ( {Name: "full_service_period_to", Type: field.TypeTime}, {Name: "status", Type: field.TypeEnum, Enums: []string{"created", "active", "final", "deleted"}}, {Name: "unique_reference_id", Type: field.TypeString, Nullable: true}, - {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, + {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(24)"}}, {Name: "managed_by", Type: field.TypeEnum, Enums: []string{"subscription", "system", "manual"}}, {Name: "advance_after", Type: field.TypeTime, Nullable: true}, {Name: "tax_behavior", Type: field.TypeEnum, Nullable: true, Enums: []string{"inclusive", "exclusive"}}, @@ -2240,7 +2240,7 @@ var ( // ChargeFlatFeeRunDetailedLinesColumns holds the columns for the "charge_flat_fee_run_detailed_lines" table. ChargeFlatFeeRunDetailedLinesColumns = []*schema.Column{ {Name: "id", Type: field.TypeString, Unique: true, SchemaType: map[string]string{"postgres": "char(26)"}}, - {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, + {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(24)"}}, {Name: "service_period_start", Type: field.TypeTime}, {Name: "service_period_end", Type: field.TypeTime}, {Name: "quantity", Type: field.TypeOther, SchemaType: map[string]string{"postgres": "numeric"}}, @@ -2459,7 +2459,7 @@ var ( {Name: "full_service_period_to", Type: field.TypeTime}, {Name: "status", Type: field.TypeEnum, Enums: []string{"created", "active", "final", "deleted"}}, {Name: "unique_reference_id", Type: field.TypeString, Nullable: true}, - {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, + {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(24)"}}, {Name: "managed_by", Type: field.TypeEnum, Enums: []string{"subscription", "system", "manual"}}, {Name: "advance_after", Type: field.TypeTime, Nullable: true}, {Name: "tax_behavior", Type: field.TypeEnum, Nullable: true, Enums: []string{"inclusive", "exclusive"}}, @@ -2706,7 +2706,7 @@ var ( // ChargeUsageBasedRunDetailedLineColumns holds the columns for the "charge_usage_based_run_detailed_line" table. ChargeUsageBasedRunDetailedLineColumns = []*schema.Column{ {Name: "id", Type: field.TypeString, Unique: true, SchemaType: map[string]string{"postgres": "char(26)"}}, - {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, + {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(24)"}}, {Name: "service_period_start", Type: field.TypeTime}, {Name: "service_period_end", Type: field.TypeTime}, {Name: "quantity", Type: field.TypeOther, SchemaType: map[string]string{"postgres": "numeric"}}, @@ -3004,7 +3004,7 @@ var ( {Name: "namespace", Type: field.TypeString}, {Name: "root_realization_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "char(26)"}}, {Name: "customer_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "char(26)"}}, - {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, + {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(24)"}}, {Name: "origin_kind", Type: field.TypeEnum, Enums: []string{"real_credit", "advance"}}, {Name: "advance_features", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"postgres": "text[]"}}, {Name: "created_at", Type: field.TypeTime}, @@ -3203,7 +3203,7 @@ var ( {Name: "annotations", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "key", Type: field.TypeString, Nullable: true}, {Name: "primary_email", Type: field.TypeString, Nullable: true}, - {Name: "currency", Type: field.TypeString, Nullable: true, Size: 3}, + {Name: "currency", Type: field.TypeString, Nullable: true, Size: 24}, } // CustomersTable holds the schema information for the "customers" table. CustomersTable = &schema.Table{ @@ -3677,7 +3677,7 @@ var ( {Name: "kind", Type: field.TypeEnum, Enums: []string{"plan", "release", "reopen"}}, {Name: "amount", Type: field.TypeOther, SchemaType: map[string]string{"postgres": "numeric"}}, {Name: "customer_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "char(26)"}}, - {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, + {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(24)"}}, {Name: "credit_priority", Type: field.TypeInt}, {Name: "expires_at", Type: field.TypeTime}, {Name: "source_kind", Type: field.TypeEnum, Enums: []string{"credit_purchase", "usage", "usage_correction", "credit_purchase_correction", "advance_backfill"}}, @@ -3958,6 +3958,7 @@ var ( {Name: "routing_key_version", Type: field.TypeString}, {Name: "routing_key", Type: field.TypeString}, {Name: "currency", Type: field.TypeString}, + {Name: "source", Type: field.TypeString, Nullable: true}, {Name: "tax_code", Type: field.TypeString, Nullable: true}, {Name: "tax_behavior", Type: field.TypeString, Nullable: true}, {Name: "features", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"postgres": "text[]"}}, @@ -3974,7 +3975,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "ledger_sub_account_routes_ledger_accounts_sub_account_routes", - Columns: []*schema.Column{LedgerSubAccountRoutesColumns[14]}, + Columns: []*schema.Column{LedgerSubAccountRoutesColumns[15]}, RefColumns: []*schema.Column{LedgerAccountsColumns[0]}, OnDelete: schema.NoAction, }, @@ -3993,7 +3994,7 @@ var ( { Name: "ledgersubaccountroute_namespace_account_id_routing_key_version_routing_key", Unique: true, - Columns: []*schema.Column{LedgerSubAccountRoutesColumns[1], LedgerSubAccountRoutesColumns[14], LedgerSubAccountRoutesColumns[5], LedgerSubAccountRoutesColumns[6]}, + Columns: []*schema.Column{LedgerSubAccountRoutesColumns[1], LedgerSubAccountRoutesColumns[15], LedgerSubAccountRoutesColumns[5], LedgerSubAccountRoutesColumns[6]}, }, }, } @@ -4800,7 +4801,7 @@ var ( {Name: "active_to", Type: field.TypeTime, Nullable: true}, {Name: "name", Type: field.TypeString, Default: "Subscription"}, {Name: "description", Type: field.TypeString, Nullable: true}, - {Name: "currency", Type: field.TypeString, Size: 3}, + {Name: "currency", Type: field.TypeString, Size: 24}, {Name: "billing_anchor", Type: field.TypeTime}, {Name: "billing_cadence", Type: field.TypeString}, {Name: "pro_rating_config", Type: field.TypeString, SchemaType: map[string]string{"postgres": "jsonb"}}, diff --git a/openmeter/ent/db/mutation.go b/openmeter/ent/db/mutation.go index 54d58e8819..1a302ed0c8 100644 --- a/openmeter/ent/db/mutation.go +++ b/openmeter/ent/db/mutation.go @@ -87376,6 +87376,7 @@ type LedgerSubAccountRouteMutation struct { routing_key_version *ledger.RoutingKeyVersion routing_key *string currency *string + source *currencyx.Code tax_code *string tax_behavior *ledger.TaxBehavior features *pq.StringArray @@ -87799,6 +87800,55 @@ func (m *LedgerSubAccountRouteMutation) ResetCurrency() { m.currency = nil } +// SetSource sets the "source" field. +func (m *LedgerSubAccountRouteMutation) SetSource(c currencyx.Code) { + m.source = &c +} + +// Source returns the value of the "source" field in the mutation. +func (m *LedgerSubAccountRouteMutation) Source() (r currencyx.Code, exists bool) { + v := m.source + if v == nil { + return + } + return *v, true +} + +// OldSource returns the old "source" field's value of the LedgerSubAccountRoute entity. +// If the LedgerSubAccountRoute object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *LedgerSubAccountRouteMutation) OldSource(ctx context.Context) (v *currencyx.Code, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSource is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSource requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSource: %w", err) + } + return oldValue.Source, nil +} + +// ClearSource clears the value of the "source" field. +func (m *LedgerSubAccountRouteMutation) ClearSource() { + m.source = nil + m.clearedFields[ledgersubaccountroute.FieldSource] = struct{}{} +} + +// SourceCleared returns if the "source" field was cleared in this mutation. +func (m *LedgerSubAccountRouteMutation) SourceCleared() bool { + _, ok := m.clearedFields[ledgersubaccountroute.FieldSource] + return ok +} + +// ResetSource resets all changes to the "source" field. +func (m *LedgerSubAccountRouteMutation) ResetSource() { + m.source = nil + delete(m.clearedFields, ledgersubaccountroute.FieldSource) +} + // SetTaxCode sets the "tax_code" field. func (m *LedgerSubAccountRouteMutation) SetTaxCode(s string) { m.tax_code = &s @@ -88229,7 +88279,7 @@ func (m *LedgerSubAccountRouteMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *LedgerSubAccountRouteMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 15) if m.namespace != nil { fields = append(fields, ledgersubaccountroute.FieldNamespace) } @@ -88254,6 +88304,9 @@ func (m *LedgerSubAccountRouteMutation) Fields() []string { if m.currency != nil { fields = append(fields, ledgersubaccountroute.FieldCurrency) } + if m.source != nil { + fields = append(fields, ledgersubaccountroute.FieldSource) + } if m.tax_code != nil { fields = append(fields, ledgersubaccountroute.FieldTaxCode) } @@ -88296,6 +88349,8 @@ func (m *LedgerSubAccountRouteMutation) Field(name string) (ent.Value, bool) { return m.RoutingKey() case ledgersubaccountroute.FieldCurrency: return m.Currency() + case ledgersubaccountroute.FieldSource: + return m.Source() case ledgersubaccountroute.FieldTaxCode: return m.TaxCode() case ledgersubaccountroute.FieldTaxBehavior: @@ -88333,6 +88388,8 @@ func (m *LedgerSubAccountRouteMutation) OldField(ctx context.Context, name strin return m.OldRoutingKey(ctx) case ledgersubaccountroute.FieldCurrency: return m.OldCurrency(ctx) + case ledgersubaccountroute.FieldSource: + return m.OldSource(ctx) case ledgersubaccountroute.FieldTaxCode: return m.OldTaxCode(ctx) case ledgersubaccountroute.FieldTaxBehavior: @@ -88410,6 +88467,13 @@ func (m *LedgerSubAccountRouteMutation) SetField(name string, value ent.Value) e } m.SetCurrency(v) return nil + case ledgersubaccountroute.FieldSource: + v, ok := value.(currencyx.Code) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSource(v) + return nil case ledgersubaccountroute.FieldTaxCode: v, ok := value.(string) if !ok { @@ -88500,6 +88564,9 @@ func (m *LedgerSubAccountRouteMutation) ClearedFields() []string { if m.FieldCleared(ledgersubaccountroute.FieldDeletedAt) { fields = append(fields, ledgersubaccountroute.FieldDeletedAt) } + if m.FieldCleared(ledgersubaccountroute.FieldSource) { + fields = append(fields, ledgersubaccountroute.FieldSource) + } if m.FieldCleared(ledgersubaccountroute.FieldTaxCode) { fields = append(fields, ledgersubaccountroute.FieldTaxCode) } @@ -88535,6 +88602,9 @@ func (m *LedgerSubAccountRouteMutation) ClearField(name string) error { case ledgersubaccountroute.FieldDeletedAt: m.ClearDeletedAt() return nil + case ledgersubaccountroute.FieldSource: + m.ClearSource() + return nil case ledgersubaccountroute.FieldTaxCode: m.ClearTaxCode() return nil @@ -88585,6 +88655,9 @@ func (m *LedgerSubAccountRouteMutation) ResetField(name string) error { case ledgersubaccountroute.FieldCurrency: m.ResetCurrency() return nil + case ledgersubaccountroute.FieldSource: + m.ResetSource() + return nil case ledgersubaccountroute.FieldTaxCode: m.ResetTaxCode() return nil diff --git a/openmeter/ent/schema/billing.go b/openmeter/ent/schema/billing.go index f1418d8dcf..0cd2b04138 100644 --- a/openmeter/ent/schema/billing.go +++ b/openmeter/ent/schema/billing.go @@ -846,7 +846,9 @@ type BillingStandardInvoiceDetailedLine struct { func (BillingStandardInvoiceDetailedLine) Mixin() []ent.Mixin { return []ent.Mixin{ - stddetailedline.Mixin{}, + stddetailedline.Mixin{ + CurrencyPostgresSchemaType: "varchar(3)", + }, } } diff --git a/openmeter/ent/schema/creditrealizationlineage.go b/openmeter/ent/schema/creditrealizationlineage.go index 6e4baaadc3..cbc32fb8cf 100644 --- a/openmeter/ent/schema/creditrealizationlineage.go +++ b/openmeter/ent/schema/creditrealizationlineage.go @@ -57,7 +57,7 @@ func (CreditRealizationLineage) Fields() []ent.Field { NotEmpty(). Immutable(). SchemaType(map[string]string{ - dialect.Postgres: "varchar(3)", + dialect.Postgres: currencyx.PostgresCodeSchemaType, }), field.Enum("origin_kind"). GoType(creditrealization.LineageOriginKind("")). diff --git a/openmeter/ent/schema/custom_currencies.go b/openmeter/ent/schema/custom_currencies.go index 006129bd73..d5a46f973f 100644 --- a/openmeter/ent/schema/custom_currencies.go +++ b/openmeter/ent/schema/custom_currencies.go @@ -29,8 +29,8 @@ func (CustomCurrency) Fields() []ent.Field { return []ent.Field{ field.String("code"). NotEmpty(). - MinLen(3). - MaxLen(24). + MinLen(currencyx.MinCodeLength). + MaxLen(currencyx.MaxCodeLength). Immutable(), field.String("name"). NotEmpty(), diff --git a/openmeter/ent/schema/customer.go b/openmeter/ent/schema/customer.go index e458dcf9bf..20a6ad244e 100644 --- a/openmeter/ent/schema/customer.go +++ b/openmeter/ent/schema/customer.go @@ -34,7 +34,12 @@ func (Customer) Fields() []ent.Field { // because we can only add unique indexes on fields that are not nullable. field.String("key").Optional(), field.String("primary_email").Optional().Nillable(), - field.String("currency").GoType(currencyx.Code("")).MinLen(3).MaxLen(3).Optional().Nillable(), + field.String("currency"). + GoType(currencyx.Code("")). + MinLen(currencyx.MinCodeLength). + MaxLen(currencyx.MaxCodeLength). + Optional(). + Nillable(), } } diff --git a/openmeter/ent/schema/ledger_account.go b/openmeter/ent/schema/ledger_account.go index e311048cc1..80bf278627 100644 --- a/openmeter/ent/schema/ledger_account.go +++ b/openmeter/ent/schema/ledger_account.go @@ -10,6 +10,7 @@ import ( "github.com/lib/pq" "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/framework/entutils" ) @@ -116,6 +117,9 @@ func (LedgerSubAccountRoute) Fields() []ent.Field { field.String("routing_key").Immutable(), // Literal routing values (denormalized from routing_key for query filtering; not FKs). field.String("currency").Immutable(), + field.String("source"). + GoType(currencyx.Code("")). + Optional().Nillable().Immutable(), // tax_code stores the TaxCode.Key string used as a routing dimension, not a FK to the tax_codes table. field.String("tax_code").Optional().Nillable().Immutable(), field.String("tax_behavior"). diff --git a/openmeter/ent/schema/ledger_breakage_record.go b/openmeter/ent/schema/ledger_breakage_record.go index f5fa8c6d0b..39765359ba 100644 --- a/openmeter/ent/schema/ledger_breakage_record.go +++ b/openmeter/ent/schema/ledger_breakage_record.go @@ -47,7 +47,7 @@ func (LedgerBreakageRecord) Fields() []ent.Field { NotEmpty(). Immutable(). SchemaType(map[string]string{ - dialect.Postgres: "varchar(3)", + dialect.Postgres: currencyx.PostgresCodeSchemaType, }), field.Int("credit_priority"). Immutable(), diff --git a/openmeter/ent/schema/subscription.go b/openmeter/ent/schema/subscription.go index 7666fc1cbc..d2f49ebc27 100644 --- a/openmeter/ent/schema/subscription.go +++ b/openmeter/ent/schema/subscription.go @@ -36,7 +36,12 @@ func (Subscription) Fields() []ent.Field { field.String("description").Optional().Nillable(), field.String("plan_id").Optional().Nillable(), field.String("customer_id").NotEmpty().Immutable(), - field.String("currency").GoType(currencyx.Code("")).MinLen(3).MaxLen(3).NotEmpty().Immutable(), + field.String("currency"). + GoType(currencyx.Code("")). + MinLen(currencyx.MinCodeLength). + MaxLen(currencyx.MaxCodeLength). + NotEmpty(). + Immutable(), field.Time("billing_anchor"), field.String("billing_cadence"). GoType(datetime.ISODurationString("")). diff --git a/openmeter/ledger/account/adapter/repo_test.go b/openmeter/ledger/account/adapter/repo_test.go index c58915cac7..656f3248f2 100644 --- a/openmeter/ledger/account/adapter/repo_test.go +++ b/openmeter/ledger/account/adapter/repo_test.go @@ -12,6 +12,7 @@ import ( entdb "github.com/openmeterio/openmeter/openmeter/ent/db" ledgeraccountdb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgeraccount" + ledgersubaccountdb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgersubaccount" ledgersubaccountroutedb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgersubaccountroute" "github.com/openmeterio/openmeter/openmeter/ledger" ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" @@ -142,6 +143,18 @@ func TestRepo_ListSubAccounts(t *testing.T) { }) require.NoError(t, err) + usdSource := currencyx.Code("USD") + eurSource := currencyx.Code("EUR") + subA6CustomSource, err := env.repo.EnsureSubAccount(ctx, ledgeraccount.CreateSubAccountInput{ + Namespace: namespace, + AccountID: accountA.ID.ID, + Route: ledger.Route{ + Currency: currencyx.Code("CREDITS"), + Source: &usdSource, + }, + }) + require.NoError(t, err) + _, err = env.repo.EnsureSubAccount(ctx, ledgeraccount.CreateSubAccountInput{ Namespace: namespace, AccountID: accountB.ID.ID, @@ -155,7 +168,7 @@ func TestRepo_ListSubAccounts(t *testing.T) { AccountID: accountA.ID.ID, }) require.NoError(t, err) - require.Len(t, items, 5) + require.Len(t, items, 6) }) t.Run("filters by route", func(t *testing.T) { @@ -204,6 +217,31 @@ func TestRepo_ListSubAccounts(t *testing.T) { require.Equal(t, authorizedStatus, *items[0].Route.TransactionAuthorizationStatus) }) + t.Run("maps and filters by source", func(t *testing.T) { + items, err := env.repo.ListSubAccounts(ctx, ledgeraccount.ListSubAccountsInput{ + Namespace: namespace, + AccountID: accountA.ID.ID, + Route: ledger.RouteFilter{ + Currency: currencyx.Code("CREDITS"), + Source: mo.Some(&usdSource), + }, + }) + require.NoError(t, err) + require.Len(t, items, 1) + require.Equal(t, subA6CustomSource.ID, items[0].ID) + require.Equal(t, usdSource, *items[0].Route.Source) + + entity, err := env.client.LedgerSubAccount.Query(). + Where( + ledgersubaccountdb.Namespace(namespace), + ledgersubaccountdb.ID(subA6CustomSource.ID), + ). + WithRoute(). + Only(ctx) + require.NoError(t, err) + require.Equal(t, usdSource, *entity.Edges.Route.Source) + }) + t.Run("create uses route uniqueness", func(t *testing.T) { dup, err := env.repo.EnsureSubAccount(ctx, ledgeraccount.CreateSubAccountInput{ Namespace: namespace, @@ -214,6 +252,30 @@ func TestRepo_ListSubAccounts(t *testing.T) { require.Equal(t, subA1.ID, dup.ID) }) + t.Run("create uses source in route uniqueness", func(t *testing.T) { + dup, err := env.repo.EnsureSubAccount(ctx, ledgeraccount.CreateSubAccountInput{ + Namespace: namespace, + AccountID: accountA.ID.ID, + Route: ledger.Route{ + Currency: currencyx.Code("CREDITS"), + Source: &usdSource, + }, + }) + require.NoError(t, err) + require.Equal(t, subA6CustomSource.ID, dup.ID) + + otherSource, err := env.repo.EnsureSubAccount(ctx, ledgeraccount.CreateSubAccountInput{ + Namespace: namespace, + AccountID: accountA.ID.ID, + Route: ledger.Route{ + Currency: currencyx.Code("CREDITS"), + Source: &eurSource, + }, + }) + require.NoError(t, err) + require.NotEqual(t, subA6CustomSource.ID, otherSource.ID) + }) + t.Run("create canonicalizes cost basis uniqueness", func(t *testing.T) { dup, err := env.repo.EnsureSubAccount(ctx, ledgeraccount.CreateSubAccountInput{ Namespace: namespace, diff --git a/openmeter/ledger/account/adapter/subaccount.go b/openmeter/ledger/account/adapter/subaccount.go index 18431fdd7c..60efdfadee 100644 --- a/openmeter/ledger/account/adapter/subaccount.go +++ b/openmeter/ledger/account/adapter/subaccount.go @@ -85,6 +85,7 @@ func (r *repo) resolveOrCreateRoute(ctx context.Context, input ledgeraccount.Cre SetRoutingKeyVersion(routeKey.Version()). SetRoutingKey(routeKey.Value()). SetCurrency(string(normalizedRoute.Currency)). + SetNillableSource(normalizedRoute.Source). SetNillableTaxCode(normalizedRoute.TaxCode). SetNillableTaxBehavior(normalizedRoute.TaxBehavior). SetFeatures(pq.StringArray(normalizedRoute.Features)). @@ -155,10 +156,18 @@ func (r *repo) ListSubAccounts(ctx context.Context, input ledgeraccount.ListSubA return nil, fmt.Errorf("failed to normalize route filter: %w", err) } - routePredicates := make([]predicate.LedgerSubAccountRoute, 0, 7) + routePredicates := make([]predicate.LedgerSubAccountRoute, 0, 8) if normalizedRoute.Currency != "" { routePredicates = append(routePredicates, dbledgersubaccountroute.Currency(string(normalizedRoute.Currency))) } + if normalizedRoute.Source.IsPresent() { + source, _ := normalizedRoute.Source.Get() + if source != nil { + routePredicates = append(routePredicates, dbledgersubaccountroute.Source(*source)) + } else { + routePredicates = append(routePredicates, dbledgersubaccountroute.SourceIsNil()) + } + } if normalizedRoute.CreditPriority != nil { routePredicates = append(routePredicates, dbledgersubaccountroute.CreditPriority(*normalizedRoute.CreditPriority), @@ -258,6 +267,7 @@ func MapSubAccountData(entity *db.LedgerSubAccount) (ledgeraccount.SubAccountDat AccountType: entity.Edges.Account.AccountType, Route: ledger.Route{ Currency: currencyx.Code(dbRoute.Currency), + Source: dbRoute.Source, TaxCode: dbRoute.TaxCode, TaxBehavior: dbRoute.TaxBehavior, Features: []string(dbRoute.Features), diff --git a/openmeter/ledger/accounts.go b/openmeter/ledger/accounts.go index 0b232beec1..0dd8e2f820 100644 --- a/openmeter/ledger/accounts.go +++ b/openmeter/ledger/accounts.go @@ -29,6 +29,7 @@ type CustomerFBOAccount interface { // CreditPriority is required (non-pointer) — the type system enforces its presence. type CustomerFBORouteParams struct { Currency currencyx.Code + Source *currencyx.Code CreditPriority int Features []string CostBasis *alpacadecimal.Decimal @@ -49,6 +50,7 @@ func (p CustomerFBORouteParams) Validate() error { func (p CustomerFBORouteParams) Route() Route { return Route{ Currency: p.Currency, + Source: p.Source, Features: p.Features, CostBasis: p.CostBasis, CreditPriority: &p.CreditPriority, @@ -68,6 +70,7 @@ type CustomerReceivableAccount interface { // TransactionAuthorizationStatus is required; callers must explicitly select the open or authorized route. type CustomerReceivableRouteParams struct { Currency currencyx.Code + Source *currencyx.Code TaxCode *string Features []string CostBasis *alpacadecimal.Decimal @@ -85,6 +88,7 @@ func (p CustomerReceivableRouteParams) Validate() error { func (p CustomerReceivableRouteParams) Route() Route { return Route{ Currency: p.Currency, + Source: p.Source, TaxCode: p.TaxCode, Features: p.Features, CostBasis: p.CostBasis, @@ -103,6 +107,7 @@ type CustomerAccruedAccount interface { // CustomerAccruedRouteParams are routing parameters specific to customer accrued sub-accounts. type CustomerAccruedRouteParams struct { Currency currencyx.Code + Source *currencyx.Code TaxCode *string TaxBehavior *TaxBehavior CostBasis *alpacadecimal.Decimal @@ -115,6 +120,7 @@ func (p CustomerAccruedRouteParams) Validate() error { func (p CustomerAccruedRouteParams) Route() Route { return Route{ Currency: p.Currency, + Source: p.Source, TaxCode: p.TaxCode, TaxBehavior: p.TaxBehavior, CostBasis: p.CostBasis, @@ -134,6 +140,7 @@ type BusinessAccount interface { type BusinessRouteParams struct { Currency currencyx.Code + Source *currencyx.Code TaxCode *string TaxBehavior *TaxBehavior CostBasis *alpacadecimal.Decimal @@ -146,6 +153,7 @@ func (p BusinessRouteParams) Validate() error { func (p BusinessRouteParams) Route() Route { return Route{ Currency: p.Currency, + Source: p.Source, TaxCode: p.TaxCode, TaxBehavior: p.TaxBehavior, CostBasis: p.CostBasis, diff --git a/openmeter/ledger/accounts_test.go b/openmeter/ledger/accounts_test.go new file mode 100644 index 0000000000..ba80db6517 --- /dev/null +++ b/openmeter/ledger/accounts_test.go @@ -0,0 +1,56 @@ +package ledger + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/openmeterio/openmeter/pkg/currencyx" +) + +func TestAccountRouteParamsPreserveSource(t *testing.T) { + source := currencyx.Code("USD") + status := TransactionAuthorizationStatusOpen + + tests := []struct { + name string + route Route + }{ + { + name: "customer fbo", + route: CustomerFBORouteParams{ + Currency: currencyx.Code("CREDITS"), + Source: &source, + CreditPriority: DefaultCustomerFBOPriority, + }.Route(), + }, + { + name: "customer receivable", + route: CustomerReceivableRouteParams{ + Currency: currencyx.Code("CREDITS"), + Source: &source, + TransactionAuthorizationStatus: status, + }.Route(), + }, + { + name: "customer accrued", + route: CustomerAccruedRouteParams{ + Currency: currencyx.Code("CREDITS"), + Source: &source, + }.Route(), + }, + { + name: "business", + route: BusinessRouteParams{ + Currency: currencyx.Code("CREDITS"), + Source: &source, + }.Route(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, &source, tt.route.Source) + }) + } +} diff --git a/openmeter/ledger/chargeadapter/creditpurchase.go b/openmeter/ledger/chargeadapter/creditpurchase.go index e510f2ee91..b2caf6e2e2 100644 --- a/openmeter/ledger/chargeadapter/creditpurchase.go +++ b/openmeter/ledger/chargeadapter/creditpurchase.go @@ -3,6 +3,7 @@ package chargeadapter import ( "cmp" "context" + "errors" "fmt" "slices" @@ -460,11 +461,6 @@ func (h *creditPurchaseHandler) advanceAttributions( return nil, err } - calculator, err := currency.Calculator() - if err != nil { - return nil, fmt.Errorf("get currency calculator: %w", err) - } - receivableBuckets := newAdvanceReceivableBuckets(advanceReceivables, creditFeatures) // The purchase can only attribute as much as both the purchase and matching @@ -490,7 +486,7 @@ func (h *creditPurchaseHandler) advanceAttributions( attributions := make([]advanceAttribution, 0, len(unattributedAccrued)+len(advanceReceivables)) if accruedAttributable.IsPositive() { - accruedAttributions, err := allocateAccruedAttribution(calculator, accruedAttributable, unattributedAccrued, receivableBuckets.remainingBySpendKey) + accruedAttributions, err := allocateAccruedAttribution(currency, accruedAttributable, unattributedAccrued, receivableBuckets.remainingBySpendKey) if err != nil { return nil, err } @@ -769,7 +765,7 @@ func (h *creditPurchaseHandler) unattributedAccruedBalances(ctx context.Context, // This keeps the old proportional tax-bucket behavior while preserving spend // provenance on each generated attribution leg. func allocateAccruedAttribution( - calculator currencyx.Calculator, + currency currencyx.Code, amount alpacadecimal.Decimal, unattributedAccrued []unattributedAccruedBalance, advanceRemainingBySpendKey map[string]alpacadecimal.Decimal, @@ -790,11 +786,25 @@ func allocateAccruedAttribution( }) } - allocations, err := currencyx.AllocateByAmount(calculator, currencyx.AmountAllocationInput[accruedBackfillBucketKey]{ - Amount: amount, - Items: items, - CompareKey: cmpx.Compare[accruedBackfillBucketKey], - }) + calculator, err := currency.Calculator() + if err == nil { + allocations, err := currencyx.AllocateByAmount(calculator, currencyx.AmountAllocationInput[accruedBackfillBucketKey]{ + Amount: amount, + Items: items, + CompareKey: cmpx.Compare[accruedBackfillBucketKey], + }) + if err != nil { + return nil, fmt.Errorf("allocate accrued attribution: %w", err) + } + + return allocations, nil + } + + if err := ledger.ValidateCurrency(currency); err != nil { + return nil, fmt.Errorf("currency: %w", err) + } + + allocations, err := allocateAccruedAttributionExactly(amount, items) if err != nil { return nil, fmt.Errorf("allocate accrued attribution: %w", err) } @@ -802,6 +812,70 @@ func allocateAccruedAttribution( return allocations, nil } +func allocateAccruedAttributionExactly( + amount alpacadecimal.Decimal, + items []currencyx.AmountAllocationItem[accruedBackfillBucketKey], +) ([]currencyx.AmountAllocation[accruedBackfillBucketKey], error) { + if amount.Sign() < 0 { + return nil, errors.New("amount must be non-negative") + } + + if amount.IsZero() { + return nil, nil + } + + if len(items) == 0 { + return nil, errors.New("items are required for a non-zero amount") + } + + totalAmount := alpacadecimal.Zero + for i, item := range items { + if item.Amount.Sign() <= 0 { + return nil, fmt.Errorf("items[%d].amount must be positive", i) + } + + totalAmount = totalAmount.Add(item.Amount) + } + + if amount.GreaterThan(totalAmount) { + return nil, errors.New("amount must not exceed total item amount") + } + + remainingAmount := amount + remainingTotal := totalAmount + allocations := make([]currencyx.AmountAllocation[accruedBackfillBucketKey], 0, len(items)) + + for _, item := range items { + if !remainingAmount.IsPositive() { + break + } + + allocated := remainingAmount + if item.Amount.LessThan(remainingTotal) { + allocated = remainingAmount.Mul(item.Amount).Div(remainingTotal) + } + if allocated.GreaterThan(item.Amount) { + allocated = item.Amount + } + + if allocated.IsPositive() { + allocations = append(allocations, currencyx.AmountAllocation[accruedBackfillBucketKey]{ + Key: item.Key, + Amount: allocated, + }) + } + + remainingAmount = remainingAmount.Sub(allocated) + remainingTotal = remainingTotal.Sub(item.Amount) + } + + if remainingAmount.IsPositive() { + return nil, errors.New("cannot distribute remaining allocation without exceeding item amounts") + } + + return allocations, nil +} + // totalUnattributedAccruedBalance returns accrued capacity that has matching // open advance receivable. This caps receivable attribution so backfill does // not translate more accrued value than exists for eligible spend provenance. diff --git a/openmeter/ledger/chargeadapter/creditpurchase_test.go b/openmeter/ledger/chargeadapter/creditpurchase_test.go index d84e3183ca..2ce3ec16d8 100644 --- a/openmeter/ledger/chargeadapter/creditpurchase_test.go +++ b/openmeter/ledger/chargeadapter/creditpurchase_test.go @@ -73,6 +73,32 @@ func TestOnPromotionalCreditPurchase_BacksAdvanceBeforeTopUp(t *testing.T) { require.True(t, env.sumBalance(t, env.washSubAccount(t, alpacadecimal.Zero)).Equal(alpacadecimal.NewFromInt(-100))) } +func TestOnPromotionalCreditPurchase_CustomCurrencyBacksAdvanceBeforeTopUp(t *testing.T) { + env := newCreditPurchaseHandlerTestEnv(t) + env.Currency = currencyx.Code("CREDITS") + env.createAdvanceExposure(t, alpacadecimal.NewFromInt(40)) + + charge := env.newPromotionalCharge(alpacadecimal.NewFromInt(100)) + ref, err := env.handler.OnPromotionalCreditPurchase(t.Context(), charge) + require.NoError(t, err) + require.NotEmpty(t, ref.TransactionGroupID) + require.ElementsMatch(t, []string{ + transactions.TemplateCode(transactions.AttributeCustomerAdvanceReceivableCostBasisTemplate{}), + transactions.TemplateCode(transactions.TranslateCustomerAccruedCostBasisTemplate{}), + transactions.TemplateCode(transactions.IssueCustomerReceivableTemplate{}), + transactions.TemplateCode(transactions.AuthorizeCustomerReceivablePaymentTemplate{}), + transactions.TemplateCode(transactions.SettleCustomerReceivableFromPaymentTemplate{}), + }, env.transactionTemplateCodes(t, ref.TransactionGroupID)) + + require.True(t, env.sumBalance(t, env.receivableSubAccount(t, alpacadecimal.Zero)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.authorizedReceivableSubAccount(t, alpacadecimal.Zero)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownReceivableSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.unknownAccruedSubAccount(t)).Equal(alpacadecimal.Zero)) + require.True(t, env.sumBalance(t, env.accruedSubAccount(t, alpacadecimal.Zero)).Equal(alpacadecimal.NewFromInt(40))) + require.True(t, env.sumBalance(t, env.fboSubAccount(t, alpacadecimal.Zero)).Equal(alpacadecimal.NewFromInt(60))) + require.True(t, env.sumBalance(t, env.washSubAccount(t, alpacadecimal.Zero)).Equal(alpacadecimal.NewFromInt(-100))) +} + func TestOnCreditPurchaseInitiated_BackfillsOnlyMatchingFeatureAdvances(t *testing.T) { env := newCreditPurchaseHandlerTestEnv(t) env.createAdvanceExposureWithFeatures(t, alpacadecimal.NewFromInt(40), []string{"api-calls"}) @@ -551,7 +577,7 @@ func (e *creditPurchaseHandlerTestEnv) newPromotionalCharge(amount alpacadecimal Intent: meta.Intent{ ManagedBy: billing.SystemManagedLine, CustomerID: e.CustomerID.ID, - Currency: currencyx.Code("USD"), + Currency: e.Currency, TaxConfig: productcatalog.TaxCodeConfig{ TaxCodeID: "tax-code-id", }, @@ -595,7 +621,7 @@ func (e *creditPurchaseHandlerTestEnv) newExternalCharge(amount, costBasis alpac Intent: meta.Intent{ ManagedBy: billing.SystemManagedLine, CustomerID: e.CustomerID.ID, - Currency: currencyx.Code("USD"), + Currency: e.Currency, TaxConfig: productcatalog.TaxCodeConfig{ TaxCodeID: "tax-code-id", }, @@ -611,7 +637,7 @@ func (e *creditPurchaseHandlerTestEnv) newExternalCharge(amount, costBasis alpac Settlement: chargecreditpurchase.NewSettlement(chargecreditpurchase.ExternalSettlement{ InitialStatus: chargecreditpurchase.CreatedInitialPaymentSettlementStatus, GenericSettlement: chargecreditpurchase.GenericSettlement{ - Currency: currencyx.Code("USD"), + Currency: e.Currency, CostBasis: costBasis, }, }), diff --git a/openmeter/ledger/customerbalance/facade.go b/openmeter/ledger/customerbalance/facade.go index 9af8ec0bb7..14c16e6509 100644 --- a/openmeter/ledger/customerbalance/facade.go +++ b/openmeter/ledger/customerbalance/facade.go @@ -7,6 +7,7 @@ import ( "time" "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" "github.com/samber/mo" "github.com/openmeterio/openmeter/openmeter/billing/charges/creditpurchase" @@ -21,13 +22,15 @@ type CurrencyFilter struct { } func (f CurrencyFilter) Validate() error { - for _, code := range f.Codes { - if code == "" { - return errors.New("currency code is required") + errs := lo.Map(f.Codes, func(code currencyx.Code, i int) error { + if err := ledger.ValidateCurrency(code); err != nil { + return fmt.Errorf("code %d: %w", i, err) } - } - return nil + return nil + }) + + return models.NewNillableGenericValidationError(errors.Join(errs...)) } type GetBalancesInput struct { @@ -74,7 +77,7 @@ func (i GetBalanceInput) Validate() error { errs = append(errs, fmt.Errorf("customer ID: %w", err)) } - if err := i.Currency.Validate(); err != nil { + if err := ledger.ValidateCurrency(i.Currency); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) } @@ -132,7 +135,7 @@ func (f *Facade) GetBalances(ctx context.Context, input GetBalancesInput) ([]Bal codes = dedupeCurrencies(input.Currencies.Codes) for _, code := range codes { - if err := code.Validate(); err != nil { + if err := ledger.ValidateCurrency(code); err != nil { return nil, fmt.Errorf("currency %q is not supported by ledger: %w", code, err) } } diff --git a/openmeter/ledger/customerbalance/facade_test.go b/openmeter/ledger/customerbalance/facade_test.go index f9caa6cd2d..ff9ec20b96 100644 --- a/openmeter/ledger/customerbalance/facade_test.go +++ b/openmeter/ledger/customerbalance/facade_test.go @@ -100,21 +100,23 @@ func TestFacadeGetBalancesDiscoversPendingGrantCurrencies(t *testing.T) { require.True(t, balances[0].Balance.Pending().Equal(alpacadecimal.NewFromInt(40))) } -func TestFacadeGetBalancesWithUnsupportedExplicitCurrency(t *testing.T) { +func TestFacadeGetBalancesWithCustomExplicitCurrency(t *testing.T) { env := newTestEnv(t) facade, err := NewFacade(env.Service) require.NoError(t, err) - _, err = facade.GetBalances(t.Context(), GetBalancesInput{ + balances, err := facade.GetBalances(t.Context(), GetBalancesInput{ CustomerID: env.CustomerID, Currencies: CurrencyFilter{ Codes: []currencyx.Code{"CUSTOM"}, }, }) - require.Error(t, err) - require.ErrorContains(t, err, "CUSTOM") - require.ErrorContains(t, err, "not supported by ledger") + require.NoError(t, err) + require.Len(t, balances, 1) + require.Equal(t, currencyx.Code("CUSTOM"), balances[0].Currency) + require.True(t, balances[0].Balance.Settled().IsZero()) + require.True(t, balances[0].Balance.Pending().IsZero()) } func TestFacadeGetBalanceAfterTransactionCursor(t *testing.T) { diff --git a/openmeter/ledger/customerbalance/service.go b/openmeter/ledger/customerbalance/service.go index db5a6816a0..a008c16d49 100644 --- a/openmeter/ledger/customerbalance/service.go +++ b/openmeter/ledger/customerbalance/service.go @@ -115,7 +115,7 @@ func (i GetBalanceServiceInput) Validate() error { errs = append(errs, fmt.Errorf("customer ID: %w", err)) } - if err := i.Currency.Validate(); err != nil { + if err := ledger.ValidateCurrency(i.Currency); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) } diff --git a/openmeter/ledger/customerbalance/service_test.go b/openmeter/ledger/customerbalance/service_test.go index 9dbdf7caba..6a326015e5 100644 --- a/openmeter/ledger/customerbalance/service_test.go +++ b/openmeter/ledger/customerbalance/service_test.go @@ -63,7 +63,7 @@ func TestGetBalanceServiceInputValidate(t *testing.T) { name: "invalid currency", input: GetBalanceServiceInput{ CustomerID: valid.CustomerID, - Currency: currencyx.Code("not-a-currency"), + Currency: currencyx.Code("NO"), }, wantErr: true, }, diff --git a/openmeter/ledger/customerbalance/transactions.go b/openmeter/ledger/customerbalance/transactions.go index f5231f5d96..519c8db6ac 100644 --- a/openmeter/ledger/customerbalance/transactions.go +++ b/openmeter/ledger/customerbalance/transactions.go @@ -84,7 +84,7 @@ func (i ListCreditTransactionsInput) Validate() error { } if i.Currency != nil { - if err := i.Currency.Validate(); err != nil { + if err := ledger.ValidateCurrency(*i.Currency); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) } } diff --git a/openmeter/ledger/historical/adapter/ledger_test.go b/openmeter/ledger/historical/adapter/ledger_test.go index 953e107700..2ae25a5028 100644 --- a/openmeter/ledger/historical/adapter/ledger_test.go +++ b/openmeter/ledger/historical/adapter/ledger_test.go @@ -998,6 +998,12 @@ func TestRepo_SumEntries_Filters(t *testing.T) { CreditPriority: lo.ToPtr(1), CostBasis: lo.ToPtr(mustDecimal(t, "0.7")), }) + usdSource := currencyx.Code("USD") + eurSource := currencyx.Code("EUR") + subAccountE := env.createSubAccount(t, namespace, ledger.Route{ + Currency: currencyx.Code("CREDITS"), + Source: &usdSource, + }) group, err := env.repo.CreateTransactionGroup(ctx, ledgerhistorical.CreateTransactionGroupInput{Namespace: namespace}) require.NoError(t, err) @@ -1043,6 +1049,15 @@ func TestRepo_SumEntries_Filters(t *testing.T) { txCostBasis, err := env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, txInputCostBasis) require.NoError(t, err) + txInputSource := mustSetUpHistoricalTransactionInput(t, time.Now().UTC().Add(-10*time.Minute), []*transactionstestutils.AnyEntryInput{ + { + Address: testAddress(t, subAccountE), + AmountValue: alpacadecimal.NewFromInt(10), + }, + }) + _, err = env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, txInputSource) + require.NoError(t, err) + // Sum by currency sumUSD, err := env.repo.SumEntries(ctx, ledger.Query{ Namespace: namespace, @@ -1108,6 +1123,30 @@ func TestRepo_SumEntries_Filters(t *testing.T) { require.NoError(t, err) require.True(t, sumCostBasis.Equal(alpacadecimal.NewFromInt(25))) + sumSource, err := env.repo.SumEntries(ctx, ledger.Query{ + Namespace: namespace, + Filters: ledger.Filters{ + Route: ledger.RouteFilter{ + Currency: currencyx.Code("CREDITS"), + Source: mo.Some(&usdSource), + }, + }, + }) + require.NoError(t, err) + require.True(t, sumSource.Equal(alpacadecimal.NewFromInt(10))) + + sumOtherSource, err := env.repo.SumEntries(ctx, ledger.Query{ + Namespace: namespace, + Filters: ledger.Filters{ + Route: ledger.RouteFilter{ + Currency: currencyx.Code("CREDITS"), + Source: mo.Some(&eurSource), + }, + }, + }) + require.NoError(t, err) + require.True(t, sumOtherSource.IsZero()) + sumAfterLate, err := env.repo.SumEntries(ctx, ledger.Query{ Namespace: namespace, Filters: ledger.Filters{ diff --git a/openmeter/ledger/historical/adapter/sumentries_query.go b/openmeter/ledger/historical/adapter/sumentries_query.go index 327fd1635f..aea70f8ead 100644 --- a/openmeter/ledger/historical/adapter/sumentries_query.go +++ b/openmeter/ledger/historical/adapter/sumentries_query.go @@ -136,10 +136,18 @@ func (b *sumEntriesQuery) subAccountPredicates() ([]predicate.LedgerSubAccount, }) } - routePredicates := make([]predicate.LedgerSubAccountRoute, 0, 7) + routePredicates := make([]predicate.LedgerSubAccountRoute, 0, 8) if normalizedRoute.Currency != "" { routePredicates = append(routePredicates, ledgersubaccountroutedb.Currency(string(normalizedRoute.Currency))) } + if normalizedRoute.Source.IsPresent() { + source, _ := normalizedRoute.Source.Get() + if source != nil { + routePredicates = append(routePredicates, ledgersubaccountroutedb.Source(*source)) + } else { + routePredicates = append(routePredicates, ledgersubaccountroutedb.SourceIsNil()) + } + } if normalizedRoute.CreditPriority != nil { routePredicates = append(routePredicates, ledgersubaccountroutedb.CreditPriority(*normalizedRoute.CreditPriority), diff --git a/openmeter/ledger/ledger_fx_test.go b/openmeter/ledger/ledger_fx_test.go index f6f5e47493..e0a0b223f1 100644 --- a/openmeter/ledger/ledger_fx_test.go +++ b/openmeter/ledger/ledger_fx_test.go @@ -144,4 +144,50 @@ func TestFXOnInvoiceIssued(t *testing.T) { _, err = histLedger.CommitGroup(ctx, transactions.GroupInputs(namespace, nil, inputs...)) require.NoError(t, err) }) + + t.Run("CustomFundingSourceRoutes", func(t *testing.T) { + sourceCurrency := currencyx.Code("USD") + targetCurrency := currencyx.Code("CREDITS") + inputs, err := transactions.ResolveTransactions( + ctx, + transactions.ResolverDependencies{ + AccountService: resolversSvc, + AccountCatalog: deps.AccountService, + BalanceQuerier: deps.HistoricalLedger, + }, + transactions.ResolutionScope{ + CustomerID: customerID, + Namespace: namespace, + }, + transactions.ConvertCurrencyTemplate{ + At: time.Now(), + TargetAmount: alpacadecimal.NewFromInt(200), + CostBasis: alpacadecimal.NewFromFloat(0.5), + SourceCurrency: sourceCurrency, + TargetCurrency: targetCurrency, + }, + ) + require.NoError(t, err) + require.Len(t, inputs, 1) + + totals := map[currencyx.Code]alpacadecimal.Decimal{} + customEntries := 0 + for _, entry := range inputs[0].EntryInputs() { + route := entry.PostingAddress().Route().Route() + totals[route.Currency] = totals[route.Currency].Add(entry.Amount()) + + switch route.Currency { + case sourceCurrency: + require.Nil(t, route.Source) + case targetCurrency: + customEntries++ + require.NotNil(t, route.Source) + require.Equal(t, sourceCurrency, *route.Source) + } + } + + require.Equal(t, 2, customEntries) + require.True(t, totals[sourceCurrency].IsZero(), "source currency total: %s", totals[sourceCurrency]) + require.True(t, totals[targetCurrency].IsZero(), "target currency total: %s", totals[targetCurrency]) + }) } diff --git a/openmeter/ledger/primitives.go b/openmeter/ledger/primitives.go index 4eac7f8ee7..a2ee8538c0 100644 --- a/openmeter/ledger/primitives.go +++ b/openmeter/ledger/primitives.go @@ -45,6 +45,7 @@ type SubAccount interface { // RouteFilter is the set of route fields that can be used to filter sub-accounts and query balances. type RouteFilter struct { Currency currencyx.Code + Source mo.Option[*currencyx.Code] // Non-currency fields are retained for near-future expansion. TaxCode mo.Option[*string] diff --git a/openmeter/ledger/recognizer/service.go b/openmeter/ledger/recognizer/service.go index d8266e2abe..ce4debfc6a 100644 --- a/openmeter/ledger/recognizer/service.go +++ b/openmeter/ledger/recognizer/service.go @@ -97,7 +97,7 @@ func (i RecognizeEarningsInput) Validate() error { if i.At.IsZero() { errs = append(errs, errors.New("at is required")) } - if err := i.Currency.Validate(); err != nil { + if err := i.Currency.ValidateFormat(); err != nil { errs = append(errs, fmt.Errorf("currency: %w", err)) } diff --git a/openmeter/ledger/recognizer/service_test.go b/openmeter/ledger/recognizer/service_test.go index 2d244746de..192933daa8 100644 --- a/openmeter/ledger/recognizer/service_test.go +++ b/openmeter/ledger/recognizer/service_test.go @@ -19,6 +19,7 @@ import ( ledgertestutils "github.com/openmeterio/openmeter/openmeter/ledger/testutils" "github.com/openmeterio/openmeter/openmeter/ledger/transactions" "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/timeutil" ) @@ -67,6 +68,18 @@ func testID() string { return ulid.Make().String() } +func TestRecognizeEarnings_AllowsCustomCurrencyWithoutLineages(t *testing.T) { + env := newRecognizerTestEnv(t) + + result, err := env.recognizer.RecognizeEarnings(t.Context(), recognizer.RecognizeEarningsInput{ + CustomerID: env.CustomerID, + At: clock.Now(), + Currency: currencyx.Code("CREDITS"), + }) + require.NoError(t, err) + require.True(t, result.RecognizedAmount.IsZero()) +} + func (e *recognizerTestEnv) resolverDeps() transactions.ResolverDependencies { return transactions.ResolverDependencies{ AccountService: e.Deps.ResolversService, diff --git a/openmeter/ledger/routing.go b/openmeter/ledger/routing.go index 75ea30b5a7..b83e4218ff 100644 --- a/openmeter/ledger/routing.go +++ b/openmeter/ledger/routing.go @@ -24,6 +24,8 @@ const ( // Use V2 when a route has a non-nil TaxBehavior; otherwise use V1 for // backward compatibility with sub-accounts created before tax_behavior existed. RoutingKeyVersionV2 RoutingKeyVersion = "v2" + // RoutingKeyVersionV3 extends V2 by adding the source segment. + RoutingKeyVersionV3 RoutingKeyVersion = "v3" ) type TransactionAuthorizationStatus string @@ -55,7 +57,7 @@ func (s TransactionAuthorizationStatus) Validate() error { func (v RoutingKeyVersion) Validate() error { switch v { - case RoutingKeyVersionV1, RoutingKeyVersionV2: + case RoutingKeyVersionV1, RoutingKeyVersionV2, RoutingKeyVersionV3: return nil default: return ErrRoutingKeyVersionInvalid.WithAttrs(models.Attributes{ @@ -159,7 +161,10 @@ type Route struct { // manually except for testing edge cases. Version RoutingKeyVersion Currency currencyx.Code - TaxCode *string + // Source identifies the fiat currency that funded a custom currency bucket. + // Fiat currency routes must keep Source nil. + Source *currencyx.Code + TaxCode *string // TaxBehavior distinguishes taxable accrued and earnings buckets. // Customer FBO routes do not carry tax dimensions; credit sources are // attributed to charge tax configuration when they accrue. @@ -174,6 +179,9 @@ func (r Route) Validate() error { if err := ValidateCurrency(r.Currency); err != nil { return err } + if err := ValidateCurrencySource(r.Currency, r.Source); err != nil { + return err + } if r.CreditPriority != nil { if err := ValidateCreditPriority(*r.CreditPriority); err != nil { @@ -211,6 +219,7 @@ func (r Route) Validate() error { func (r Route) Filter() RouteFilter { return RouteFilter{ Currency: r.Currency, + Source: mo.Some(r.Source), TaxCode: mo.Some(r.TaxCode), TaxBehavior: mo.Some(r.TaxBehavior), Features: mo.Some(r.Features), @@ -224,6 +233,17 @@ func (r Route) Matches(filter RouteFilter) bool { if filter.Currency != "" && r.Currency != filter.Currency { return false } + if filter.Source.IsPresent() { + source, _ := filter.Source.Get() + switch { + case source == nil && r.Source != nil: + return false + case source != nil && r.Source == nil: + return false + case source != nil && r.Source != nil && *source != *r.Source: + return false + } + } if filter.TaxCode.IsPresent() { taxCode, _ := filter.TaxCode.Get() switch { @@ -295,7 +315,7 @@ func (r Route) Normalize() (Route, error) { // Normalize canonicalizes route filter values before querying. func (f RouteFilter) Normalize() (RouteFilter, error) { - if f.Currency == "" && f.TaxCode.IsAbsent() && f.Features.IsAbsent() && f.MatchFeature == "" && f.CostBasis.IsAbsent() && f.CreditPriority == nil && f.TransactionAuthorizationStatus == nil && f.TaxBehavior.IsAbsent() { + if f.Currency == "" && f.Source.IsAbsent() && f.TaxCode.IsAbsent() && f.Features.IsAbsent() && f.MatchFeature == "" && f.CostBasis.IsAbsent() && f.CreditPriority == nil && f.TransactionAuthorizationStatus == nil && f.TaxBehavior.IsAbsent() { return f, nil } if f.Features.IsPresent() && f.MatchFeature != "" { @@ -307,12 +327,14 @@ func (f RouteFilter) Normalize() (RouteFilter, error) { } } + source, _ := f.Source.Get() taxCode, _ := f.TaxCode.Get() taxBehavior, _ := f.TaxBehavior.Get() features, _ := f.Features.Get() costBasis, _ := f.CostBasis.Get() normalized, err := Route{ Currency: f.Currency, + Source: source, TaxCode: taxCode, TaxBehavior: taxBehavior, Features: features, @@ -329,6 +351,11 @@ func (f RouteFilter) Normalize() (RouteFilter, error) { normalizedCostBasis = mo.Some(normalized.CostBasis) } + normalizedSource := mo.None[*currencyx.Code]() + if f.Source.IsPresent() { + normalizedSource = mo.Some(normalized.Source) + } + normalizedTaxCode := mo.None[*string]() if f.TaxCode.IsPresent() { normalizedTaxCode = mo.Some(normalized.TaxCode) @@ -346,6 +373,7 @@ func (f RouteFilter) Normalize() (RouteFilter, error) { return RouteFilter{ Currency: normalized.Currency, + Source: normalizedSource, TaxCode: normalizedTaxCode, TaxBehavior: normalizedTaxBehavior, Features: normalizedFeatures, @@ -369,6 +397,7 @@ type routingVersionRequirement struct { // routingVersionRequirements lists versions above V1 with the conditions that trigger them. // Ordered highest to lowest; selectRoutingKeyVersion returns the first match, V1 otherwise. var routingVersionRequirements = []routingVersionRequirement{ + {version: RoutingKeyVersionV3, requires: func(r Route) bool { return r.Source != nil }}, {version: RoutingKeyVersionV2, requires: func(r Route) bool { return r.TaxBehavior != nil }}, } @@ -397,6 +426,8 @@ func BuildRoutingKey(route Route) (RoutingKey, error) { return buildRoutingKeyV1Normalized(normalizedRoute) case RoutingKeyVersionV2: return buildRoutingKeyV2Normalized(normalizedRoute) + case RoutingKeyVersionV3: + return buildRoutingKeyV3Normalized(normalizedRoute) default: return RoutingKey{}, ErrRoutingKeyVersionUnsupported.WithAttrs(models.Attributes{ "routing_key_version": normalizedRoute.Version, @@ -411,6 +442,9 @@ func BuildRoutingKeyV1(route Route) (RoutingKey, error) { if route.TaxBehavior != nil { return RoutingKey{}, fmt.Errorf("TaxBehavior requires a V2 routing key; use BuildRoutingKey to select the version automatically") } + if route.Source != nil { + return RoutingKey{}, fmt.Errorf("Source requires a V3 routing key; use BuildRoutingKey to select the version automatically") + } normalizedRoute, err := route.Normalize() if err != nil { return RoutingKey{}, err @@ -420,6 +454,9 @@ func BuildRoutingKeyV1(route Route) (RoutingKey, error) { // BuildRoutingKeyV2 encodes route as a V2 routing key. func BuildRoutingKeyV2(route Route) (RoutingKey, error) { + if route.Source != nil { + return RoutingKey{}, fmt.Errorf("Source requires a V3 routing key; use BuildRoutingKey to select the version automatically") + } normalizedRoute, err := route.Normalize() if err != nil { return RoutingKey{}, err @@ -427,6 +464,15 @@ func BuildRoutingKeyV2(route Route) (RoutingKey, error) { return buildRoutingKeyV2Normalized(normalizedRoute) } +// BuildRoutingKeyV3 encodes route as a V3 routing key. +func BuildRoutingKeyV3(route Route) (RoutingKey, error) { + normalizedRoute, err := route.Normalize() + if err != nil { + return RoutingKey{}, err + } + return buildRoutingKeyV3Normalized(normalizedRoute) +} + // buildRoutingKeyV1Normalized encodes an already-normalized route as a V1 key. func buildRoutingKeyV1Normalized(route Route) (RoutingKey, error) { value := strings.Join([]string{ @@ -456,6 +502,21 @@ func buildRoutingKeyV2Normalized(route Route) (RoutingKey, error) { return NewRoutingKey(RoutingKeyVersionV2, value) } +func buildRoutingKeyV3Normalized(route Route) (RoutingKey, error) { + value := strings.Join([]string{ + "currency:" + string(route.Currency), + "source:" + optionalCurrencyValue(route.Source), + "tax_code:" + optionalStringValue(route.TaxCode), + "tax_behavior:" + string(lo.FromPtrOr(route.TaxBehavior, "null")), + "features:" + canonicalFeatures(route.Features), + "cost_basis:" + optionalDecimalValue(route.CostBasis), + "credit_priority:" + optionalIntValue(route.CreditPriority), + "transaction_authorization_status:" + string(lo.FromPtrOr(route.TransactionAuthorizationStatus, "null")), + }, "|") + + return NewRoutingKey(RoutingKeyVersionV3, value) +} + // ---------------------------------------------------------------------------- // Validation helpers // ---------------------------------------------------------------------------- @@ -486,9 +547,11 @@ func ValidateCreditPriority(value int) error { return nil } -// ValidateCurrency validates a currency value. +// ValidateCurrency validates durable ledger currency code shape. It does not +// require fiat registry metadata so historical custom-currency facts stay +// readable even if product currency definitions change. func ValidateCurrency(value currencyx.Code) error { - if err := value.Validate(); err != nil { + if err := value.ValidateFormat(); err != nil { return ErrCurrencyInvalid.WithAttrs(models.Attributes{ "currency": value, }) @@ -497,6 +560,35 @@ func ValidateCurrency(value currencyx.Code) error { return nil } +func ValidateCurrencySource(currency currencyx.Code, source *currencyx.Code) error { + if source == nil { + return nil + } + + if err := source.Validate(); err != nil { + return ErrCurrencyInvalid.WithAttrs(models.Attributes{ + "source": *source, + }) + } + + if !source.IsKnownFiat() { + return ErrCurrencyInvalid.WithAttrs(models.Attributes{ + "source": *source, + "reason": "source_must_be_fiat", + }) + } + + if currency.IsKnownFiat() { + return ErrCurrencyInvalid.WithAttrs(models.Attributes{ + "currency": currency, + "source": *source, + "reason": "fiat_currency_source_must_be_null", + }) + } + + return nil +} + func ValidateCostBasis(value alpacadecimal.Decimal) error { if value.IsNegative() { return ErrCostBasisInvalid.WithAttrs(models.Attributes{ @@ -564,6 +656,13 @@ func optionalStringValue(s *string) string { return *s } +func optionalCurrencyValue(s *currencyx.Code) string { + if s == nil || *s == "" { + return "null" + } + return string(*s) +} + func optionalIntValue(v *int) string { if v == nil { return "null" diff --git a/openmeter/ledger/routing_test.go b/openmeter/ledger/routing_test.go index 3e63112be9..f4a8efa73b 100644 --- a/openmeter/ledger/routing_test.go +++ b/openmeter/ledger/routing_test.go @@ -35,6 +35,110 @@ func TestBuildRoutingKeyV1_Nulls(t *testing.T) { require.Equal(t, "currency:USD|tax_code:null|features:null|cost_basis:null|credit_priority:null|transaction_authorization_status:null", key.Value()) } +func TestRouteValidateAcceptsCustomCurrency(t *testing.T) { + source := currencyx.Code("USD") + require.NoError(t, Route{ + Currency: currencyx.Code("CUSTOM"), + Source: &source, + }.Validate()) +} + +func TestRouteValidateAcceptsCustomCurrencyWithoutSource(t *testing.T) { + require.NoError(t, Route{ + Currency: currencyx.Code("CUSTOM"), + }.Validate()) +} + +func TestBuildRoutingKeyV1_CustomCurrencyWithoutSource(t *testing.T) { + key, err := BuildRoutingKey(Route{ + Currency: currencyx.Code("CUSTOM"), + }) + require.NoError(t, err) + require.Equal(t, RoutingKeyVersionV1, key.Version()) + require.Equal(t, "currency:CUSTOM|tax_code:null|features:null|cost_basis:null|credit_priority:null|transaction_authorization_status:null", key.Value()) +} + +func TestBuildRoutingKeyV3_CustomCurrencySource(t *testing.T) { + usd := currencyx.Code("USD") + eur := currencyx.Code("EUR") + + usdKey, err := BuildRoutingKey(Route{ + Currency: currencyx.Code("CUSTOM"), + Source: &usd, + }) + require.NoError(t, err) + require.Equal(t, RoutingKeyVersionV3, usdKey.Version()) + require.Equal(t, "currency:CUSTOM|source:USD|tax_code:null|tax_behavior:null|features:null|cost_basis:null|credit_priority:null|transaction_authorization_status:null", usdKey.Value()) + + eurKey, err := BuildRoutingKey(Route{ + Currency: currencyx.Code("CUSTOM"), + Source: &eur, + }) + require.NoError(t, err) + require.NotEqual(t, usdKey.Value(), eurKey.Value()) +} + +func TestRouteValidateRejectsInvalidSource(t *testing.T) { + tests := []struct { + name string + currency currencyx.Code + source *currencyx.Code + }{ + { + name: "fiat route source must be null", + currency: currencyx.Code("USD"), + source: lo.ToPtr(currencyx.Code("EUR")), + }, + { + name: "source must be fiat", + currency: currencyx.Code("CUSTOM"), + source: lo.ToPtr(currencyx.Code("CREDITS")), + }, + { + name: "source must be structurally valid", + currency: currencyx.Code("CUSTOM"), + source: lo.ToPtr(currencyx.Code("BAD|CODE")), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := Route{ + Currency: tt.currency, + Source: tt.source, + }.Validate() + require.Error(t, err) + require.ErrorIs(t, err, ErrCurrencyInvalid) + }) + } +} + +func TestRouteValidateRejectsInvalidCurrency(t *testing.T) { + tests := []struct { + name string + currency currencyx.Code + }{ + { + name: "too short", + currency: currencyx.Code("XY"), + }, + { + name: "routing delimiter", + currency: currencyx.Code("BAD|CODE"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := Route{ + Currency: tt.currency, + }.Validate() + require.Error(t, err) + require.ErrorIs(t, err, ErrCurrencyInvalid) + }) + } +} + func TestBuildRoutingKeyV1_SameLiterals_SameKey(t *testing.T) { priority := 100 input := Route{ @@ -139,6 +243,15 @@ func TestBuildRoutingKeyV2_DifferentTaxBehavior_DifferentKey(t *testing.T) { require.NotEqual(t, k1.Value(), k2.Value()) } +func TestBuildRoutingKeyV2_RejectsSource(t *testing.T) { + source := currencyx.Code("USD") + _, err := BuildRoutingKeyV2(Route{ + Currency: currencyx.Code("CUSTOM"), + Source: &source, + }) + require.Error(t, err) +} + func TestTaxBehaviorValidate(t *testing.T) { require.NoError(t, TaxBehaviorInclusive.Validate()) require.NoError(t, TaxBehaviorExclusive.Validate()) @@ -167,6 +280,19 @@ func TestRouteFilter_NormalizePreservesTaxCode(t *testing.T) { require.Equal(t, &tc, got) } +func TestRouteFilter_NormalizePreservesSource(t *testing.T) { + source := currencyx.Code("USD") + f := RouteFilter{ + Currency: currencyx.Code("CUSTOM"), + Source: mo.Some(&source), + } + norm, err := f.Normalize() + require.NoError(t, err) + require.True(t, norm.Source.IsPresent()) + got, _ := norm.Source.Get() + require.Equal(t, &source, got) +} + func TestRouteFilter_NormalizePreservesTaxBehavior(t *testing.T) { b := TaxBehaviorExclusive f := RouteFilter{ @@ -248,6 +374,8 @@ func TestRouteMatches(t *testing.T) { otherTaxBehavior := TaxBehaviorExclusive authStatus := TransactionAuthorizationStatusOpen otherAuthStatus := TransactionAuthorizationStatusAuthorized + source := currencyx.Code("USD") + otherSource := currencyx.Code("EUR") route := Route{ Currency: currencyx.Code("USD"), @@ -262,6 +390,9 @@ func TestRouteMatches(t *testing.T) { Currency: currencyx.Code("USD"), CreditPriority: &priority, } + sourcedRoute := route + sourcedRoute.Currency = currencyx.Code("CUSTOM") + sourcedRoute.Source = &source tests := []struct { name string @@ -286,6 +417,36 @@ func TestRouteMatches(t *testing.T) { filter: RouteFilter{Currency: currencyx.Code("EUR")}, want: false, }, + { + name: "source absent ignores populated route source", + route: sourcedRoute, + filter: RouteFilter{}, + want: true, + }, + { + name: "source match", + route: sourcedRoute, + filter: RouteFilter{Source: mo.Some(&source)}, + want: true, + }, + { + name: "source mismatch", + route: sourcedRoute, + filter: RouteFilter{Source: mo.Some(&otherSource)}, + want: false, + }, + { + name: "nil source filter rejects populated route source", + route: sourcedRoute, + filter: RouteFilter{Source: mo.Some[*currencyx.Code](nil)}, + want: false, + }, + { + name: "nil source filter matches nil route source", + route: unrestrictedRoute, + filter: RouteFilter{Source: mo.Some[*currencyx.Code](nil)}, + want: true, + }, { name: "tax code absent ignores populated route tax code", route: route, diff --git a/openmeter/ledger/transactions/fx.go b/openmeter/ledger/transactions/fx.go index e7b4f8414e..9ee769438f 100644 --- a/openmeter/ledger/transactions/fx.go +++ b/openmeter/ledger/transactions/fx.go @@ -73,6 +73,10 @@ func (t ConvertCurrencyTemplate) resolve(ctx context.Context, customerID custome return nil, fmt.Errorf("failed to normalize cost basis: cost basis must be non-negative") } costBasis := t.CostBasis + var targetSource *currencyx.Code + if t.SourceCurrency.IsKnownFiat() && !t.TargetCurrency.IsKnownFiat() { + targetSource = &t.SourceCurrency + } customerAccounts, err := resolvers.AccountService.GetCustomerAccounts(ctx, customerID) if err != nil { @@ -90,6 +94,7 @@ func (t ConvertCurrencyTemplate) resolve(ctx context.Context, customerID custome targetAccount, err := customerAccounts.FBOAccount.GetSubAccountForRoute(ctx, ledger.CustomerFBORouteParams{ Currency: t.TargetCurrency, + Source: targetSource, CostBasis: &costBasis, CreditPriority: priority, }) @@ -112,6 +117,7 @@ func (t ConvertCurrencyTemplate) resolve(ctx context.Context, customerID custome brokerageTarget, err := businessAccounts.BrokerageAccount.GetSubAccountForRoute(ctx, ledger.BusinessRouteParams{ Currency: t.TargetCurrency, + Source: targetSource, CostBasis: &costBasis, }) if err != nil { diff --git a/openmeter/ledger/validations.go b/openmeter/ledger/validations.go index f6597ba7dd..97c450d1a5 100644 --- a/openmeter/ledger/validations.go +++ b/openmeter/ledger/validations.go @@ -74,10 +74,7 @@ func validateEntryAmountPrecision(entry EntryInput) error { currency := entry.PostingAddress().Route().Route().Currency calculator, err := currency.Calculator() if err != nil { - return ErrCurrencyInvalid.WithAttrs(models.Attributes{ - "currency": currency, - "error": err, - }) + return ValidateCurrency(currency) } amount := entry.Amount() diff --git a/test/credits/sanity_test.go b/test/credits/sanity_test.go index c313cb1258..2bf39d54c7 100644 --- a/test/credits/sanity_test.go +++ b/test/credits/sanity_test.go @@ -26,6 +26,7 @@ import ( entdb "github.com/openmeterio/openmeter/openmeter/ent/db" dbledgerbreakagerecord "github.com/openmeterio/openmeter/openmeter/ent/db/ledgerbreakagerecord" "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/openmeter/ledger/transactions" "github.com/openmeterio/openmeter/openmeter/meter" "github.com/openmeterio/openmeter/openmeter/productcatalog" @@ -89,6 +90,83 @@ func (s *SanitySuite) TestUsageBasedCreditOnlyDeleteCorrectionSanity() { s.assertUnfundedCreditOnlyDeleted(setup.customer.GetID()) } +func (s *SanitySuite) TestCustomCurrencyLedgerCreditBalanceSanity() { + ctx := s.T().Context() + ns := s.GetUniqueNamespace("charges-sanity-custom-currency-credit-balance") + s.ProvisionDefaultTaxCodes(ctx, ns) + + cust := s.CreateLedgerBackedCustomer(ns, "test-subject") + + customCurrency := currencyx.Code("CREDITS") + grantAt := datetime.MustParseTimeInLocation(s.T(), "2026-01-01T00:00:00Z", time.UTC).AsTime() + amount := alpacadecimal.NewFromInt(42) + costBasis := alpacadecimal.Zero + + clock.FreezeTime(grantAt) + defer clock.UnFreeze() + + // given: + // - a promotional credit purchase uses a long custom currency code + res, err := s.Charges.Create(ctx, charges.CreateInput{ + Namespace: ns, + Intents: charges.ChargeIntents{ + s.CreateCreditPurchaseIntent(CreateCreditPurchaseIntentInput{ + Customer: cust.GetID(), + Currency: customCurrency, + Amount: amount, + ServicePeriod: timeutil.ClosedPeriod{ + From: grantAt, + To: grantAt, + }, + Settlement: creditpurchase.NewSettlement(creditpurchase.PromotionalSettlement{}), + }), + }, + }) + s.Require().NoError(err) + s.Require().Len(res, 1) + + charge, err := res[0].AsCreditPurchaseCharge() + s.Require().NoError(err) + s.Equal(creditpurchase.StatusFinal, charge.Status) + s.Equal(customCurrency, charge.Intent.Currency) + s.Require().NotNil(charge.Realizations.CreditGrantRealization) + s.Require().NotEmpty(charge.Realizations.CreditGrantRealization.TransactionGroupID) + + // then: + // - ledger buckets and the customer balance facade accept and preserve the custom code + s.AssertDecimalEqual(amount, s.MustCustomerFBOBalance(cust.GetID(), customCurrency, mo.Some(&costBasis)), "custom currency FBO balance") + s.AssertDecimalEqual(alpacadecimal.Zero, s.MustCustomerReceivableBalance(cust.GetID(), customCurrency, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusOpen), "custom currency open receivable") + s.AssertDecimalEqual(alpacadecimal.Zero, s.MustCustomerReceivableBalance(cust.GetID(), customCurrency, mo.Some(&costBasis), ledger.TransactionAuthorizationStatusAuthorized), "custom currency authorized receivable") + s.AssertDecimalEqual(amount.Neg(), s.MustWashBalance(ns, customCurrency, mo.Some(&costBasis)), "custom currency wash balance") + + facade, err := customerbalance.NewFacade(s.CustomerBalanceSvc) + s.Require().NoError(err) + + explicitBalances, err := facade.GetBalances(ctx, customerbalance.GetBalancesInput{ + CustomerID: cust.GetID(), + Currencies: customerbalance.CurrencyFilter{ + Codes: []currencyx.Code{customCurrency}, + }, + }) + s.Require().NoError(err) + s.Require().Len(explicitBalances, 1) + s.Equal(customCurrency, explicitBalances[0].Currency) + s.AssertDecimalEqual(amount, explicitBalances[0].Balance.Settled(), "explicit custom currency settled balance") + s.AssertDecimalEqual(alpacadecimal.Zero, explicitBalances[0].Balance.Pending(), "explicit custom currency pending balance") + + discoveredBalances, err := facade.GetBalances(ctx, customerbalance.GetBalancesInput{ + CustomerID: cust.GetID(), + }) + s.Require().NoError(err) + + customBalance, ok := lo.Find(discoveredBalances, func(balance customerbalance.BalanceByCurrency) bool { + return balance.Currency == customCurrency + }) + s.Require().True(ok, "custom currency must be discovered from FBO routes") + s.AssertDecimalEqual(amount, customBalance.Balance.Settled(), "discovered custom currency settled balance") + s.AssertDecimalEqual(alpacadecimal.Zero, customBalance.Balance.Pending(), "discovered custom currency pending balance") +} + func (s *SanitySuite) TestFlatFeeFundedCreditOnlyRecognizedRevenueDeleteCorrectionSanity() { setup := s.setupFlatFeeCreditOnlyDeleteCorrection("charges-sanity-flatfee-funded-credit-only-recognized-delete") zeroCostBasis := alpacadecimal.Zero diff --git a/tools/migrate/migrations/20260702111040_widen_currency_codes.down.sql b/tools/migrate/migrations/20260702111040_widen_currency_codes.down.sql new file mode 100644 index 0000000000..c6387a9c88 --- /dev/null +++ b/tools/migrate/migrations/20260702111040_widen_currency_codes.down.sql @@ -0,0 +1,19 @@ +-- drop dependent search view before narrowing charge currency columns +DROP VIEW IF EXISTS "charges_search_v1s"; +-- reverse: modify "ledger_breakage_records" table +ALTER TABLE "ledger_breakage_records" ALTER COLUMN "currency" TYPE character varying(3); +-- reverse: modify "credit_realization_lineages" table +ALTER TABLE "credit_realization_lineages" ALTER COLUMN "currency" TYPE character varying(3); +-- reverse: modify "charge_usage_based_run_detailed_line" table +ALTER TABLE "charge_usage_based_run_detailed_line" ALTER COLUMN "currency" TYPE character varying(3); +-- reverse: modify "charge_usage_based" table +ALTER TABLE "charge_usage_based" ALTER COLUMN "currency" TYPE character varying(3); +-- reverse: modify "charge_flat_fees" table +ALTER TABLE "charge_flat_fees" ALTER COLUMN "currency" TYPE character varying(3); +-- reverse: modify "charge_flat_fee_run_detailed_lines" table +ALTER TABLE "charge_flat_fee_run_detailed_lines" ALTER COLUMN "currency" TYPE character varying(3); +-- reverse: modify "charge_credit_purchases" table +ALTER TABLE "charge_credit_purchases" ALTER COLUMN "currency" TYPE character varying(3); +-- recreate charges_search_v1s view after narrowing charge currency columns +CREATE VIEW "charges_search_v1s" AS +SELECT "id", "namespace", "metadata", "created_at", "updated_at", "deleted_at", "name", "description", "annotations", "customer_id", "service_period_from", "service_period_to", "billing_period_from", "billing_period_to", "full_service_period_from", "full_service_period_to", "status", "unique_reference_id", "currency", "managed_by", "subscription_id", "subscription_phase_id", "subscription_item_id", "advance_after", "tax_code_id", "tax_behavior", NULL::timestamptz AS "base_intent_deleted_at", 'credit_purchase' AS "type" FROM "charge_credit_purchases" UNION ALL SELECT "id", "namespace", "metadata", "created_at", "updated_at", "deleted_at", "name", "description", "annotations", "customer_id", "service_period_from", "service_period_to", "billing_period_from", "billing_period_to", "full_service_period_from", "full_service_period_to", "status", "unique_reference_id", "currency", "managed_by", "subscription_id", "subscription_phase_id", "subscription_item_id", "advance_after", "tax_code_id", "tax_behavior", "intent_deleted_at" AS "base_intent_deleted_at", 'flat_fee' AS "type" FROM "charge_flat_fees" UNION ALL SELECT "id", "namespace", "metadata", "created_at", "updated_at", "deleted_at", "name", "description", "annotations", "customer_id", "service_period_from", "service_period_to", "billing_period_from", "billing_period_to", "full_service_period_from", "full_service_period_to", "status", "unique_reference_id", "currency", "managed_by", "subscription_id", "subscription_phase_id", "subscription_item_id", "advance_after", "tax_code_id", "tax_behavior", "intent_deleted_at" AS "base_intent_deleted_at", 'usage_based' AS "type" FROM "charge_usage_based"; diff --git a/tools/migrate/migrations/20260702111040_widen_currency_codes.up.sql b/tools/migrate/migrations/20260702111040_widen_currency_codes.up.sql new file mode 100644 index 0000000000..325dc50a74 --- /dev/null +++ b/tools/migrate/migrations/20260702111040_widen_currency_codes.up.sql @@ -0,0 +1,19 @@ +-- drop dependent search view before widening charge currency columns +DROP VIEW IF EXISTS "charges_search_v1s"; +-- modify "charge_credit_purchases" table +ALTER TABLE "charge_credit_purchases" ALTER COLUMN "currency" TYPE character varying(24); +-- modify "charge_flat_fee_run_detailed_lines" table +ALTER TABLE "charge_flat_fee_run_detailed_lines" ALTER COLUMN "currency" TYPE character varying(24); +-- modify "charge_flat_fees" table +ALTER TABLE "charge_flat_fees" ALTER COLUMN "currency" TYPE character varying(24); +-- modify "charge_usage_based" table +ALTER TABLE "charge_usage_based" ALTER COLUMN "currency" TYPE character varying(24); +-- modify "charge_usage_based_run_detailed_line" table +ALTER TABLE "charge_usage_based_run_detailed_line" ALTER COLUMN "currency" TYPE character varying(24); +-- modify "credit_realization_lineages" table +ALTER TABLE "credit_realization_lineages" ALTER COLUMN "currency" TYPE character varying(24); +-- modify "ledger_breakage_records" table +ALTER TABLE "ledger_breakage_records" ALTER COLUMN "currency" TYPE character varying(24); +-- recreate charges_search_v1s view after widening charge currency columns +CREATE VIEW "charges_search_v1s" AS +SELECT "id", "namespace", "metadata", "created_at", "updated_at", "deleted_at", "name", "description", "annotations", "customer_id", "service_period_from", "service_period_to", "billing_period_from", "billing_period_to", "full_service_period_from", "full_service_period_to", "status", "unique_reference_id", "currency", "managed_by", "subscription_id", "subscription_phase_id", "subscription_item_id", "advance_after", "tax_code_id", "tax_behavior", NULL::timestamptz AS "base_intent_deleted_at", 'credit_purchase' AS "type" FROM "charge_credit_purchases" UNION ALL SELECT "id", "namespace", "metadata", "created_at", "updated_at", "deleted_at", "name", "description", "annotations", "customer_id", "service_period_from", "service_period_to", "billing_period_from", "billing_period_to", "full_service_period_from", "full_service_period_to", "status", "unique_reference_id", "currency", "managed_by", "subscription_id", "subscription_phase_id", "subscription_item_id", "advance_after", "tax_code_id", "tax_behavior", "intent_deleted_at" AS "base_intent_deleted_at", 'flat_fee' AS "type" FROM "charge_flat_fees" UNION ALL SELECT "id", "namespace", "metadata", "created_at", "updated_at", "deleted_at", "name", "description", "annotations", "customer_id", "service_period_from", "service_period_to", "billing_period_from", "billing_period_to", "full_service_period_from", "full_service_period_to", "status", "unique_reference_id", "currency", "managed_by", "subscription_id", "subscription_phase_id", "subscription_item_id", "advance_after", "tax_code_id", "tax_behavior", "intent_deleted_at" AS "base_intent_deleted_at", 'usage_based' AS "type" FROM "charge_usage_based"; diff --git a/tools/migrate/migrations/20260702112521_add_ledger_sub_account_source.down.sql b/tools/migrate/migrations/20260702112521_add_ledger_sub_account_source.down.sql new file mode 100644 index 0000000000..b88223b717 --- /dev/null +++ b/tools/migrate/migrations/20260702112521_add_ledger_sub_account_source.down.sql @@ -0,0 +1,2 @@ +-- reverse: modify "ledger_sub_account_routes" table +ALTER TABLE "ledger_sub_account_routes" DROP COLUMN "source"; diff --git a/tools/migrate/migrations/20260702112521_add_ledger_sub_account_source.up.sql b/tools/migrate/migrations/20260702112521_add_ledger_sub_account_source.up.sql new file mode 100644 index 0000000000..ab6c80159e --- /dev/null +++ b/tools/migrate/migrations/20260702112521_add_ledger_sub_account_source.up.sql @@ -0,0 +1,2 @@ +-- modify "ledger_sub_account_routes" table +ALTER TABLE "ledger_sub_account_routes" ADD COLUMN "source" character varying NULL; diff --git a/tools/migrate/migrations/atlas.sum b/tools/migrate/migrations/atlas.sum index b1e720a2e2..2a42643946 100644 --- a/tools/migrate/migrations/atlas.sum +++ b/tools/migrate/migrations/atlas.sum @@ -1,4 +1,4 @@ -h1:s2oRo/6B13zS2TWjmu07nC0DNRALM73DHEEgbOEait4= +h1:Vx/0/OYUPEP/6afTtexCxzf7xTTDsZBdum+1TayiJJk= 20240826120919_init.up.sql h1:tc1V91/smlmaeJGQ8h+MzTEeFjjnrrFDbDAjOYJK91o= 20240903155435_entitlement-expired-index.up.sql h1:Hp8u5uckmLXc1cRvWU0AtVnnK8ShlpzZNp8pbiJLhac= 20240917172257_billing-entities.up.sql h1:Q1dAMo0Vjiit76OybClNfYPGC5nmvov2/M2W1ioi4Kw= @@ -223,3 +223,5 @@ h1:s2oRo/6B13zS2TWjmu07nC0DNRALM73DHEEgbOEait4= 20260630163845_add_ledger_entry_schema_version.up.sql h1:WX9XcECusLP6iIlnR2YBe6i654ocAq1q7xzeOgSHo7Q= 20260630181340_add_breakage_record_source_charge_id.up.sql h1:OLPQPKq1S5J941TXL7tHceMSrvMWk5E/nAc3rAnXhZM= 20260701084156_add_currency_cost_basis_effective_to.up.sql h1:31ASikUUszFmPpArKruD3TeAnTIA7ebqMJdt2DzMH6A= +20260702111040_widen_currency_codes.up.sql h1:qGcUsZACAka76pND701/oODHSz3bMg1sB2t7KlKPopQ= +20260702112521_add_ledger_sub_account_source.up.sql h1:ffxGMU95yUF0lzv6tBOCx7y0dq0rAuy/jIHu15+Cuss=