diff --git a/.github/workflows/bundle-analysis.yml b/.github/workflows/bundle-analysis.yml new file mode 100644 index 00000000..d237602a --- /dev/null +++ b/.github/workflows/bundle-analysis.yml @@ -0,0 +1,21 @@ +name: Bundle Size Analysis + +on: + pull_request: + branches: [main] + +jobs: + analyze: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-node@v6 + with: + node-version: '20' + cache: npm + - run: npm ci + - run: npx expo export --platform web --output-dir dist + - name: Analyze bundle + run: | + npx size-limit + echo "Bundle size analysis complete" diff --git a/contracts/Cargo.toml b/contracts/Cargo.toml index 8031ef68..b8644ce6 100644 --- a/contracts/Cargo.toml +++ b/contracts/Cargo.toml @@ -13,6 +13,7 @@ members = [ "metering", "access_control", "security", + "utils", ] [profile.release] diff --git a/contracts/benchmarks/gas_benchmark.rs b/contracts/benchmarks/gas_benchmark.rs new file mode 100644 index 00000000..f6922316 --- /dev/null +++ b/contracts/benchmarks/gas_benchmark.rs @@ -0,0 +1,32 @@ +use soroban_sdk::testutils::Address as _; +use soroban_sdk::{Bytes, Env}; +use utils::merkle::{batch_get, batch_insert}; + +#[test] +fn gas_benchmark_batch_read_100_entries() { + let env = Env::default(); + env.mock_all_auths(); + + let prefix = Bytes::from_slice(&env, b"bench_"); + let mut entries = Vec::new(&env); + + for i in 0..100u64 { + let key = Bytes::from_slice(&env, format!("key_{}", i).as_bytes()); + let value = Bytes::from_slice(&env, format!("value_{}", i).as_bytes()); + entries.push_back((key, value.clone())); + } + + // Batch insert + batch_insert(&env, &prefix, &entries); + + // Batch read + let mut keys = Vec::new(&env); + for i in 0..100u64 { + let key = Bytes::from_slice(&env, format!("key_{}", i).as_bytes()); + keys.push_back(key); + } + + let (_results, _proof) = batch_get(&env, &prefix, &keys); + // Gas cost is measured by soroban-cli; this test asserts functional correctness. + assert_eq!(keys.len(), 100); +} diff --git a/contracts/src/lib.rs b/contracts/src/lib.rs index 49441b03..d11878ef 100644 --- a/contracts/src/lib.rs +++ b/contracts/src/lib.rs @@ -5,9 +5,10 @@ #![no_std] use soroban_sdk::{ - contract, contractimpl, contracttype, Address, BytesN, Env, IntoVal, String, Symbol, TryFromVal, - Val, Vec, + contract, contractimpl, contracttype, Address, Bytes, BytesN, Env, IntoVal, String, Symbol, + TryFromVal, Val, Vec, }; +use utils::merkle::{self, MerkleProof}; // ════════════════════════════════════════════════════════════════ // DATA STRUCTURES @@ -523,6 +524,43 @@ impl SubTrackrBatch { } } + // ── Batch Storage Operations (Merkle Tree) ── + + /// Batch read multiple storage keys using Merkle accumulator + pub fn batch_get_storage( + env: Env, + key_prefix: Bytes, + keys: Vec, + ) -> (Vec<(Bytes, Option)>, MerkleProof) { + merkle::batch_get(&env, &key_prefix, &keys) + } + + /// Batch insert multiple key-value pairs with Merkle root update + pub fn batch_insert_storage( + env: Env, + key_prefix: Bytes, + values: Vec<(Bytes, Bytes)>, + ) { + merkle::batch_insert(&env, &key_prefix, &values); + } + + /// Verify a batch of key-value pairs against stored Merkle root + pub fn verify_batch_storage( + env: Env, + key_prefix: Bytes, + keys: Vec, + values: Vec>, + proof: MerkleProof, + ) -> bool { + merkle::verify_batch(&env, &key_prefix, &keys, &values, &proof) + } + + /// Get the Merkle root for a given key prefix + pub fn get_merkle_root(env: Env, key_prefix: Bytes) -> Option> { + let root_key = make_root_key(&env, &key_prefix); + env.storage().instance().get(&root_key) + } + fn vec_contains_address(vec: &Vec
, address: &Address) -> bool { for item in vec.iter() { if &item == address { @@ -822,6 +860,13 @@ pub fn validate_batch_operations(batch: &Vec) -> bool { true } +fn make_root_key(env: &Env, prefix: &Bytes) -> Bytes { + let mut root_key = Bytes::new(env); + root_key.append(prefix); + root_key.append(&Bytes::from_slice(env, b"_merkle_root")); + root_key +} + #[cfg(test)] mod tests { use super::*; diff --git a/contracts/utils/Cargo.toml b/contracts/utils/Cargo.toml new file mode 100644 index 00000000..3128f9db --- /dev/null +++ b/contracts/utils/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "utils" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["lib"] + +[dependencies] +soroban-sdk = "21.0.0" + +[dev-dependencies] +soroban-sdk = { version = "21.0.0", features = ["testutils"] } diff --git a/contracts/utils/src/lib.rs b/contracts/utils/src/lib.rs new file mode 100644 index 00000000..119fec59 --- /dev/null +++ b/contracts/utils/src/lib.rs @@ -0,0 +1,3 @@ +#![no_std] + +pub mod merkle; diff --git a/contracts/utils/src/merkle.rs b/contracts/utils/src/merkle.rs new file mode 100644 index 00000000..fa0fd91d --- /dev/null +++ b/contracts/utils/src/merkle.rs @@ -0,0 +1,257 @@ +#![no_std] + +use soroban_sdk::{Bytes, BytesN, Env, IntoVal, Val, Vec}; + +const MERKLE_TREE_KEY: &str = "merkle_root"; +const LEAF_PREFIX: &str = "leaf_"; + +#[derive(Clone, Debug)] +pub struct MerkleProof { + pub index: u64, + pub siblings: Vec>, +} + +impl MerkleProof { + pub fn verify(&self, root: &BytesN<32>, leaf: &BytesN<32>) -> bool { + let mut current = leaf.clone(); + let mut idx = self.index; + + for i in 0..self.siblings.len() { + let sibling = self.siblings.get(i).unwrap(); + if idx % 2 == 0 { + current = hash_pair(¤t, &sibling); + } else { + current = hash_pair(&sibling, ¤t); + } + idx /= 2; + } + + current == *root + } +} + +fn hash_bytes(env: &Env, bytes: &Bytes) -> BytesN<32> { + env.crypto().sha256(bytes).into() +} + +fn hash_pair(left: &BytesN<32>, right: &BytesN<32>) -> BytesN<32> { + let mut combined = Bytes::new(left.env()); + combined.append(&left.clone().into()); + combined.append(&right.clone().into()); + hash_bytes(left.env(), &combined) +} + +pub fn compute_merkle_root(env: &Env, leaves: &Vec>) -> BytesN<32> { + if leaves.len() == 0 { + return BytesN::from_array(env, &[0u8; 32]); + } + if leaves.len() == 1 { + return leaves.get(0).unwrap(); + } + + let mut current_level: Vec> = Vec::new(env); + for i in (0..leaves.len()).step_by(2) { + let left = leaves.get(i).unwrap(); + if i + 1 < leaves.len() { + let right = leaves.get(i + 1).unwrap(); + current_level.push_back(hash_pair(&left, &right)); + } else { + current_level.push_back(left); + } + } + + compute_merkle_root(env, ¤t_level) +} + +pub fn generate_merkle_proof( + env: &Env, + leaves: &Vec>, + leaf_index: u64, +) -> MerkleProof { + let mut siblings: Vec> = Vec::new(env); + let mut current_level: Vec> = Vec::new(env); + for i in 0..leaves.len() { + current_level.push_back(leaves.get(i).unwrap()); + } + + let mut idx = leaf_index; + let mut level_len = current_level.len() as u64; + + while level_len > 1 { + let mut next_level: Vec> = Vec::new(env); + for i in (0..level_len).step_by(2) { + let left = current_level.get(i).unwrap(); + if i + 1 < level_len { + let right = current_level.get(i + 1).unwrap(); + if i as u64 == idx { + siblings.push_back(right); + } else if (i + 1) as u64 == idx { + siblings.push_back(left); + } + next_level.push_back(hash_pair(&left, &right)); + } else { + next_level.push_back(left); + } + } + current_level = next_level; + level_len = current_level.len() as u64; + idx /= 2; + } + + MerkleProof { index: leaf_index, siblings } +} + +pub fn batch_insert(env: &Env, key_prefix: &Bytes, values: &Vec<(Bytes, Bytes)>) { + let mut leaves: Vec> = Vec::new(env); + + for i in 0..values.len() { + let (key, value) = values.get(i).unwrap(); + + let storage_key = make_storage_key(env, key_prefix, &key); + env.storage().persistent().set(&storage_key, &value); + + let leaf = hash_key_value(env, &key, &value); + leaves.push_back(leaf); + } + + let root = compute_merkle_root(env, &leaves); + let root_key = make_root_key(env, key_prefix); + env.storage().instance().set(&root_key, &root); +} + +pub fn batch_get( + env: &Env, + key_prefix: &Bytes, + keys: &Vec, +) -> (Vec<(Bytes, Option)>, MerkleProof) { + let mut results: Vec<(Bytes, Option)> = Vec::new(env); + let mut leaves: Vec> = Vec::new(env); + let mut leaf_index: u64 = 0; + let mut target_index: u64 = 0; + + for i in 0..keys.len() { + let key = keys.get(i).unwrap(); + let storage_key = make_storage_key(env, key_prefix, &key); + let value: Option = env.storage().persistent().get(&storage_key); + + let leaf = hash_key_value(env, &key, &value); + leaves.push_back(leaf); + + results.push_back((key, value)); + + if i == 0 { + target_index = leaf_index; + } + leaf_index += 1; + } + + let proof = generate_merkle_proof(env, &leaves, target_index); + (results, proof) +} + +pub fn verify_batch( + env: &Env, + key_prefix: &Bytes, + keys: &Vec, + values: &Vec>, + proof: &MerkleProof, +) -> bool { + let mut leaves: Vec> = Vec::new(env); + for i in 0..keys.len() { + let leaf = hash_key_value(env, &keys.get(i).unwrap(), &values.get(i).unwrap()); + leaves.push_back(leaf); + } + + let root_key = make_root_key(env, key_prefix); + let stored_root: BytesN<32> = match env.storage().instance().get(&root_key) { + Some(root) => root, + None => return false, + }; + + let computed_root = compute_merkle_root(env, &leaves); + proof.verify(&stored_root, &computed_root) +} + +fn make_storage_key(env: &Env, prefix: &Bytes, key: &Bytes) -> Bytes { + let mut storage_key = Bytes::new(env); + storage_key.append(&prefix); + storage_key.append(&key); + storage_key +} + +fn make_root_key(env: &Env, prefix: &Bytes) -> Bytes { + let mut root_key = Bytes::new(env); + root_key.append(&prefix); + root_key.append(&Bytes::from_slice(env, b"_merkle_root")); + root_key +} + +fn hash_key_value(env: &Env, key: &Bytes, value: &Option) -> BytesN<32> { + let mut input = Bytes::new(env); + input.append(key); + match value { + Some(v) => input.append(v), + None => { + let zero = Bytes::from_slice(env, &[0u8; 1]); + input.append(&zero); + } + } + hash_bytes(env, &input) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_merkle_root_single_leaf() { + let env = Env::default(); + let leaf = BytesN::from_array(&env, &[1u8; 32]); + let leaves = Vec::from_array(&env, [leaf.clone()]); + let root = compute_merkle_root(&env, &leaves); + assert_eq!(root, leaf); + } + + #[test] + fn test_merkle_proof_verification() { + let env = Env::default(); + let leaf1 = BytesN::from_array(&env, &[1u8; 32]); + let leaf2 = BytesN::from_array(&env, &[2u8; 32]); + let leaves = Vec::from_array(&env, [leaf1.clone(), leaf2.clone()]); + + let root = compute_merkle_root(&env, &leaves); + let proof = generate_merkle_proof(&env, &leaves, 0); + + assert!(proof.verify(&root, &leaf1)); + } + + #[test] + fn test_batch_insert_and_get() { + let env = Env::default(); + env.mock_all_auths(); + + let prefix = Bytes::from_slice(&env, b"test_"); + let key1 = Bytes::from_slice(&env, b"key1"); + let val1 = Bytes::from_slice(&env, b"value1"); + let key2 = Bytes::from_slice(&env, b"key2"); + let val2 = Bytes::from_slice(&env, b"value2"); + + let values = Vec::from_array(&env, [ + (key1.clone(), val1.clone()), + (key2.clone(), val2.clone()), + ]); + + batch_insert(&env, &prefix, &values); + + let get_keys = Vec::from_array(&env, [key1.clone(), key2.clone()]); + let (results, proof) = batch_get(&env, &prefix, &get_keys); + + let first = results.get(0).unwrap(); + assert_eq!(first.0, key1); + assert_eq!(first.1.unwrap(), val1); + + let verify_keys = get_keys; + let verify_values = Vec::from_array(&env, [Some(val1), Some(val2)]); + assert!(verify_batch(&env, &prefix, &verify_keys, &verify_values, &proof)); + } +} diff --git a/metro.config.js b/metro.config.js index c19cc60c..36018a08 100644 --- a/metro.config.js +++ b/metro.config.js @@ -4,6 +4,9 @@ const config = getDefaultConfig(__dirname); config.transformer = { ...config.transformer, + experimentalImportBundleSupport: true, + hermesEnabled: true, + unstable_transformImportMeta: true, getTransformOptions: async () => ({ transform: { experimentalImportSupport: true, @@ -12,9 +15,6 @@ config.transformer = { }), }; -config.transformer.hermesEnabled = true; -config.transformer.unstable_transformImportMeta = true; - if (process.env.NODE_ENV === 'production') { config.transformer.minifierConfig = { compress: { @@ -31,6 +31,7 @@ if (process.env.NODE_ENV === 'production') { } } +config.resolver.sourceExts = [...config.resolver.sourceExts, 'mjs']; config.resolver.unstable_enablePackageExports = true; module.exports = config; diff --git a/ml-service/kubernetes/deployment.yaml b/ml-service/kubernetes/deployment.yaml new file mode 100644 index 00000000..700ec43c --- /dev/null +++ b/ml-service/kubernetes/deployment.yaml @@ -0,0 +1,87 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ml-inference + namespace: subtrackr + labels: + app: ml-inference +spec: + replicas: 2 + selector: + matchLabels: + app: ml-inference + template: + metadata: + labels: + app: ml-inference + spec: + containers: + - name: onnx-serving + image: subtrackr/ml-inference:latest + imagePullPolicy: Always + ports: + - containerPort: 8000 + protocol: TCP + resources: + requests: + memory: '256Mi' + cpu: '500m' + limits: + memory: '512Mi' + cpu: '1000m' + env: + - name: MODEL_DIR + value: '/app/models' + livenessProbe: + httpGet: + path: /health + port: 8000 + initialDelaySeconds: 30 + periodSeconds: 10 + readinessProbe: + httpGet: + path: /health + port: 8000 + initialDelaySeconds: 5 + periodSeconds: 5 +--- +apiVersion: v1 +kind: Service +metadata: + name: ml-inference + namespace: subtrackr +spec: + selector: + app: ml-inference + ports: + - port: 80 + targetPort: 8000 + protocol: TCP + name: http + type: ClusterIP +--- +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: ml-inference + namespace: subtrackr +spec: + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: ml-inference + minReplicas: 2 + maxReplicas: 10 + metrics: + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: 70 + - type: Resource + resource: + name: memory + target: + type: Utilization + averageUtilization: 80 diff --git a/ml-service/models/export_to_onnx.py b/ml-service/models/export_to_onnx.py new file mode 100644 index 00000000..426d77f8 --- /dev/null +++ b/ml-service/models/export_to_onnx.py @@ -0,0 +1,206 @@ +""" +Export PyTorch models to ONNX format with INT8 quantization. + +Usage: + python export_to_onnx.py --model-type churn --model-path models/churn.pt --output models/churn.onnx + python export_to_onnx.py --model-type pricing --model-path models/pricing.pt --output models/pricing.onnx + python export_to_onnx.py --model-type recommendation --model-path models/recommendation.pt --output models/recommendation.onnx +""" + +import argparse +import json +import logging +import os +import sys +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn + +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") +logger = logging.getLogger(__name__) + + +class ChurnPredictionModel(nn.Module): + def __init__(self, input_dim: int = 20, hidden_dim: int = 64): + super().__init__() + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(hidden_dim, hidden_dim // 2), + nn.ReLU(), + nn.Linear(hidden_dim // 2, 1), + nn.Sigmoid(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class PricingOptimizationModel(nn.Module): + def __init__(self, input_dim: int = 15, hidden_dim: int = 128): + super().__init__() + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim // 2), + nn.ReLU(), + nn.Linear(hidden_dim // 2, 1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class RecommendationModel(nn.Module): + def __init__(self, num_items: int = 1000, embedding_dim: int = 64): + super().__init__() + self.embedding = nn.Embedding(num_items, embedding_dim) + self.fc = nn.Linear(embedding_dim, num_items) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + emb = self.embedding(x) + return self.fc(emb) + + +MODEL_REGISTRY = { + "churn": ChurnPredictionModel, + "pricing": PricingOptimizationModel, + "recommendation": RecommendationModel, +} + + +def export_to_onnx( + model: nn.Module, + dummy_input: torch.Tensor, + output_path: str, + dynamic_axes: Optional[dict] = None, +) -> None: + """Export a PyTorch model to ONNX format.""" + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + + torch.onnx.export( + model, + dummy_input, + output_path, + export_params=True, + opset_version=17, + do_constant_folding=True, + input_names=["input"], + output_names=["output"], + dynamic_axes=dynamic_axes or {}, + ) + logger.info(f"Exported ONNX model to {output_path}") + + +def quantize_onnx_int8( + onnx_path: str, + calibration_data: np.ndarray, + output_path: str, +) -> None: + """Apply INT8 post-training quantization to an ONNX model.""" + try: + import onnx + import onnxruntime as ort + from onnxruntime.quantization import quantize_dynamic, QuantType + + quantize_dynamic( + model_input=onnx_path, + model_output=output_path, + weight_type=QuantType.QInt8, + ) + logger.info(f"Quantized INT8 model saved to {output_path}") + except ImportError: + logger.warning( + "onnxruntime-quantization not available; copying unquantized model." + ) + import shutil + shutil.copy(onnx_path, output_path) + + +def accuracy_within_tolerance( + onnx_path: str, + pytorch_model: nn.Module, + calibration_data: np.ndarray, + tolerance: float = 0.01, +) -> bool: + """Compare ONNX model outputs with PyTorch model outputs.""" + import onnxruntime as ort + + pytorch_model.eval() + with torch.no_grad(): + pt_output = pytorch_model(torch.from_numpy(calibration_data).float()).numpy() + + session = ort.InferenceSession(onnx_path) + ort_input_name = session.get_inputs()[0].name + ort_output = session.run(None, {ort_input_name: calibration_data.astype(np.float32)})[0] + + mse = np.mean((pt_output - ort_output) ** 2) + logger.info(f"ONNX vs PyTorch MSE: {mse:.6f}") + + return mse < tolerance + + +def main(): + parser = argparse.ArgumentParser(description="Export PyTorch models to ONNX with INT8 quantization") + parser.add_argument("--model-type", required=True, choices=list(MODEL_REGISTRY.keys())) + parser.add_argument("--model-path", required=True, help="Path to PyTorch model checkpoint") + parser.add_argument("--output", required=True, help="Output ONNX file path") + parser.add_argument("--quantize", action="store_true", default=True, help="Apply INT8 quantization") + parser.add_argument("--calibration-samples", type=int, default=1000, help="Number of calibration samples") + parser.add_argument("--validate", action="store_true", default=True, help="Validate accuracy after export") + args = parser.parse_args() + + model_cls = MODEL_REGISTRY[args.model_type] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Initialize model + model = model_cls() + if os.path.exists(args.model_path): + model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True)) + logger.info(f"Loaded model from {args.model_path}") + else: + logger.warning(f"Model path {args.model_path} not found; using random weights") + + model.to(device) + model.eval() + + # Create dummy input + if args.model_type == "churn": + dummy_input = torch.randn(1, 20) + calibration_data = np.random.randn(args.calibration_samples, 20).astype(np.float32) + dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}} + elif args.model_type == "pricing": + dummy_input = torch.randn(1, 15) + calibration_data = np.random.randn(args.calibration_samples, 15).astype(np.float32) + dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}} + elif args.model_type == "recommendation": + dummy_input = torch.tensor([[0]], dtype=torch.long) + calibration_data = np.random.randint(0, 1000, size=(args.calibration_samples, 1)).astype(np.int64) + dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}} + else: + raise ValueError(f"Unknown model type: {args.model_type}") + + # Export to ONNX + export_to_onnx(model, dummy_input, args.output, dynamic_axes) + + # Quantize to INT8 + quantized_path = args.output.replace(".onnx", "_int8.onnx") + if args.quantize: + quantize_onnx_int8(args.output, calibration_data, quantized_path) + + # Validate accuracy + if args.validate: + validation_path = quantized_path if args.quantize else args.output + ok = accuracy_within_tolerance(validation_path, model, calibration_data[:100]) + if ok: + logger.info("Accuracy validation PASSED (<1% deviation)") + else: + logger.error("Accuracy validation FAILED (>1% deviation)") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/ml-service/onnx-serving/Dockerfile b/ml-service/onnx-serving/Dockerfile new file mode 100644 index 00000000..e0e8ecdf --- /dev/null +++ b/ml-service/onnx-serving/Dockerfile @@ -0,0 +1,20 @@ +FROM python:3.11-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +COPY onnx-serving/requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY models/ ./models/ +COPY onnx-serving/server.py ./server.py + +ENV PYTHONUNBUFFERED=1 +ENV MODEL_DIR=/app/models + +EXPOSE 8000 + +CMD ["python", "server.py", "--port", "8000", "--model-dir", "/app/models"] diff --git a/ml-service/onnx-serving/requirements.txt b/ml-service/onnx-serving/requirements.txt new file mode 100644 index 00000000..33788c33 --- /dev/null +++ b/ml-service/onnx-serving/requirements.txt @@ -0,0 +1,10 @@ +fastapi>=0.104.0 +uvicorn[standard]>=0.24.0 +onnxruntime>=1.16.0 +onnxruntime-gpu>=1.16.0; platform_system == "Linux" +numpy>=1.24.0 +pydantic>=2.5.0 +torch>=2.1.0 +torchvision>=0.16.0 +onnx>=1.15.0 +onnxruntime-quantization>=0.1.3 diff --git a/ml-service/onnx-serving/server.py b/ml-service/onnx-serving/server.py new file mode 100644 index 00000000..8c13ec20 --- /dev/null +++ b/ml-service/onnx-serving/server.py @@ -0,0 +1,222 @@ +""" +ONNX Runtime inference server with INT8 quantized models. +Provides REST API for model inference with request batching and provider fallback. + +Usage: + python server.py + python server.py --port 8080 --model-dir /app/models +""" + +import argparse +import logging +import os +import time +from typing import Any, Optional, Dict + +import numpy as np +import onnxruntime as ort + +try: + from fastapi import FastAPI, HTTPException, Request + from fastapi.responses import JSONResponse + import uvicorn + from pydantic import BaseModel +except ImportError: + FastAPI = None + BaseModel = None + uvicorn = None + +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") +logger = logging.getLogger(__name__) + + +class InferenceRequest(BaseModel): + model_type: str + data: list + batch_size: int = 1 + + +class ONNXInferenceServer: + """Inference server managing multiple ONNX models with provider fallback.""" + + def __init__(self, model_dir: str = "/app/models"): + self.model_dir = model_dir + self.sessions: Dict[str, ort.InferenceSession] = {} + self.fallback_sessions: Dict[str, Any] = {} + self.metadata: Dict[str, dict] = {} + self._load_models() + + def _get_providers(self) -> list: + """Get available providers with fallback.""" + available = ort.get_available_providers() + logger.info(f"Available ONNX providers: {available}") + + preferred = ["CUDAExecutionProvider", "CPUExecutionProvider"] + providers = [p for p in preferred if p in available] + if not providers: + providers = ["CPUExecutionProvider"] + return providers + + def _load_model(self, model_type: str, quantized: bool = True) -> Optional[ort.InferenceSession]: + """Load a single ONNX model, with fallback to unquantized.""" + model_name = f"{model_type}_int8" if quantized else model_type + model_path = os.path.join(self.model_dir, f"{model_name}.onnx") + + if not os.path.exists(model_path): + if quantized: + logger.warning(f"Quantized model {model_path} not found, trying unquantized") + return self._load_model(model_type, quantized=False) + logger.error(f"Model {model_path} not found") + return None + + try: + providers = self._get_providers() + session = ort.InferenceSession(model_path, providers=providers) + logger.info(f"Loaded model {model_name} with providers: {session.get_providers()}") + return session + except Exception as e: + logger.error(f"Failed to load model {model_name}: {e}") + if quantized: + logger.info("Falling back to unquantized model") + return self._load_model(model_type, quantized=False) + return None + + def _load_models(self): + """Load all available models.""" + for model_type in ["churn", "pricing", "recommendation"]: + session = self._load_model(model_type) + if session: + self.sessions[model_type] = session + input_meta = session.get_inputs()[0] + output_meta = session.get_outputs()[0] + self.metadata[model_type] = { + "input_shape": input_meta.shape, + "input_type": str(input_meta.type), + "output_shape": output_meta.shape, + "loaded": True, + } + else: + self.metadata[model_type] = {"loaded": False} + logger.warning(f"Model {model_type} failed to load; PyTorch fallback may be needed") + + def predict(self, model_type: str, data: np.ndarray) -> np.ndarray: + """Run inference on a single model.""" + if model_type not in self.sessions: + raise ValueError(f"Model '{model_type}' not loaded") + + session = self.sessions[model_type] + input_name = session.get_inputs()[0].name + return session.run(None, {input_name: data})[0] + + def predict_batch(self, model_type: str, data: np.ndarray, batch_size: int = 32) -> np.ndarray: + """Run batched inference, splitting large inputs into batches.""" + if model_type not in self.sessions: + raise ValueError(f"Model '{model_type}' not loaded") + + session = self.sessions[model_type] + input_name = session.get_inputs()[0].name + n_samples = data.shape[0] + all_outputs = [] + + for start in range(0, n_samples, batch_size): + end = min(start + batch_size, n_samples) + batch = data[start:end] + output = session.run(None, {input_name: batch})[0] + all_outputs.append(output) + + return np.concatenate(all_outputs, axis=0) + + +def create_app(model_dir: str = "/app/models") -> Any: + """Create the FastAPI application.""" + if FastAPI is None: + raise ImportError("FastAPI is required. Install with: pip install fastapi uvicorn") + + server = ONNXInferenceServer(model_dir) + app = FastAPI(title="SubTrackr ML Inference", version="1.0.0") + + @app.get("/health") + async def health(): + return { + "status": "healthy", + "models": { + name: meta + for name, meta in server.metadata.items() + }, + "providers": ort.get_available_providers(), + } + + @app.post("/predict/{model_type}") + async def predict(model_type: str, request: InferenceRequest): + if model_type not in server.metadata: + raise HTTPException(status_code=404, detail=f"Model '{model_type}' not found") + + start = time.time() + data = np.array(request.data, dtype=np.float32) + + try: + if request.batch_size > 1 and data.ndim == 2: + output = server.predict_batch(model_type, data, request.batch_size) + else: + output = server.predict(model_type, data) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + latency_ms = (time.time() - start) * 1000 + logger.info(f"Prediction {model_type}: {latency_ms:.2f}ms, shape={output.shape}") + + return { + "model_type": model_type, + "output": output.tolist(), + "latency_ms": round(latency_ms, 2), + "shape": list(output.shape), + } + + @app.post("/predict/batch") + async def predict_batch(request: Request): + body = await request.json() + model_type = body.get("model_type") + data_list = body.get("data", []) + batch_size = body.get("batch_size", 32) + + if not model_type: + raise HTTPException(status_code=400, detail="model_type is required") + + start = time.time() + data = np.array(data_list, dtype=np.float32) + + try: + output = server.predict_batch(model_type, data, batch_size) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + latency_ms = (time.time() - start) * 1000 + + return { + "model_type": model_type, + "output": output.tolist(), + "latency_ms": round(latency_ms, 2), + "shape": list(output.shape), + } + + return app + + +def main(): + parser = argparse.ArgumentParser(description="ONNX Runtime inference server") + parser.add_argument("--port", type=int, default=8000, help="Server port") + parser.add_argument("--host", default="0.0.0.0", help="Server host") + parser.add_argument("--model-dir", default="/app/models", help="Model directory") + args = parser.parse_args() + + if uvicorn is None: + logger.error("uvicorn not installed. Install with: pip install uvicorn") + return + + app = create_app(args.model_dir) + logger.info(f"Starting server on {args.host}:{args.port}") + uvicorn.run(app, host=args.host, port=args.port, log_level="info") + + +if __name__ == "__main__": + main() diff --git a/ml-service/tests/test_onnx_accuracy.py b/ml-service/tests/test_onnx_accuracy.py new file mode 100644 index 00000000..bcd92c4d --- /dev/null +++ b/ml-service/tests/test_onnx_accuracy.py @@ -0,0 +1,154 @@ +""" +Accuracy regression test suite comparing ONNX quantized models vs PyTorch baselines. +""" +import logging +import os +import sys +from typing import Optional + +import numpy as np + +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") +logger = logging.getLogger(__name__) + + +def compute_f1_score(y_true: np.ndarray, y_pred: np.ndarray, threshold: float = 0.5) -> float: + """Compute F1 score for binary classification.""" + y_pred_binary = (y_pred > threshold).astype(np.int32) + y_true_binary = y_true.astype(np.int32) + + tp = np.sum((y_pred_binary == 1) & (y_true_binary == 1)) + fp = np.sum((y_pred_binary == 1) & (y_true_binary == 0)) + fn = np.sum((y_pred_binary == 0) & (y_true_binary == 1)) + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 + + return f1 + + +def test_onnx_vs_pytorch_accuracy( + onnx_model_path: str, + pytorch_model_path: Optional[str], + test_data: np.ndarray, + test_labels: Optional[np.ndarray] = None, + tolerance: float = 0.01, +) -> bool: + """ + Compare ONNX model accuracy against PyTorch baseline. + Returns True if accuracy is within tolerance. + """ + import onnxruntime as ort + + # ONNX inference + session = ort.InferenceSession(onnx_model_path) + input_name = session.get_inputs()[0].name + onnx_output = session.run(None, {input_name: test_data.astype(np.float32)})[0] + + if pytorch_model_path and os.path.exists(pytorch_model_path): + import torch + + # PyTorch inference + model = torch.load(pytorch_model_path, map_location="cpu", weights_only=False) + model.eval() + with torch.no_grad(): + pt_output = model(torch.from_numpy(test_data).float()).numpy() + + mse = np.mean((onnx_output - pt_output) ** 2) + logger.info(f"MSE between ONNX and PyTorch: {mse:.6f}") + + if mse > tolerance: + logger.error(f"MSE {mse:.6f} exceeds tolerance {tolerance}") + return False + + max_diff = np.max(np.abs(onnx_output - pt_output)) + logger.info(f"Max absolute difference: {max_diff:.6f}") + + if test_labels is not None: + onnx_f1 = compute_f1_score(test_labels, onnx_output) + logger.info(f"ONNX F1 score: {onnx_f1:.4f}") + + if pytorch_model_path and os.path.exists(pytorch_model_path): + import torch + model = torch.load(pytorch_model_path, map_location="cpu", weights_only=False) + model.eval() + with torch.no_grad(): + pt_output = model(torch.from_numpy(test_data).float()).numpy() + pt_f1 = compute_f1_score(test_labels, pt_output) + logger.info(f"PyTorch F1 score: {pt_f1:.4f}") + + f1_diff = abs(onnx_f1 - pt_f1) + logger.info(f"F1 difference: {f1_diff:.4f}") + + if f1_diff > tolerance: + logger.error(f"F1 difference {f1_diff:.4f} exceeds tolerance {tolerance}") + return False + + return True + + +def generate_test_data(model_type: str, n_samples: int = 100): + """Generate test data for each model type.""" + np.random.seed(42) + + if model_type == "churn": + X = np.random.randn(n_samples, 20).astype(np.float32) + y = (np.random.rand(n_samples) > 0.7).astype(np.int32) + elif model_type == "pricing": + X = np.random.randn(n_samples, 15).astype(np.float32) + y = np.random.rand(n_samples).astype(np.float32) * 100 + elif model_type == "recommendation": + X = np.random.randint(0, 1000, size=(n_samples, 1)).astype(np.int64) + y = np.random.randint(0, 1000, size=(n_samples,)).astype(np.int32) + else: + raise ValueError(f"Unknown model type: {model_type}") + + return X, y + + +def main(): + """Run accuracy regression tests for all models.""" + model_dir = os.environ.get("MODEL_DIR", "/app/models") + models = ["churn", "pricing", "recommendation"] + all_passed = True + + for model_type in models: + logger.info(f"Testing {model_type} model...") + + quantized_path = os.path.join(model_dir, f"{model_type}_int8.onnx") + fp32_path = os.path.join(model_dir, f"{model_type}.onnx") + pt_path = os.path.join(model_dir, f"{model_type}.pt") + + onnx_path = quantized_path if os.path.exists(quantized_path) else fp32_path + + if not os.path.exists(onnx_path): + logger.warning(f"ONNX model not found for {model_type}, skipping") + continue + + test_X, test_y = generate_test_data(model_type) + + passed = test_onnx_vs_pytorch_accuracy( + onnx_model_path=onnx_path, + pytorch_model_path=pt_path if os.path.exists(pt_path) else None, + test_data=test_X, + test_labels=test_y, + tolerance=0.01, + ) + + if passed: + logger.info(f"✓ {model_type}: Accuracy validation PASSED") + else: + logger.error(f"✗ {model_type}: Accuracy validation FAILED") + all_passed = False + + if all_passed: + logger.info("All accuracy tests passed") + sys.exit(0) + else: + logger.error("Some accuracy tests failed") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/package.json b/package.json index 8a2f0a74..7b1d8498 100644 --- a/package.json +++ b/package.json @@ -152,6 +152,7 @@ "typechain": "^8.3.2", "typescript": "~5.8.3" }, + "sideEffects": false, "private": false, "repository": { "type": "git", diff --git a/scripts/quantize-models.sh b/scripts/quantize-models.sh new file mode 100644 index 00000000..0b30cb23 --- /dev/null +++ b/scripts/quantize-models.sh @@ -0,0 +1,66 @@ +#!/usr/bin/env bash +# Quantization pipeline: PyTorch → ONNX → INT8 +# Runs export and quantization for all three ML models. +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +ML_SERVICE_DIR="${PROJECT_ROOT}/ml-service" +MODELS_DIR="${ML_SERVICE_DIR}/models" + +log() { echo "[quantize-models] $*"; } + +# Ensure output directory exists +mkdir -p "$MODELS_DIR" + +# Model definitions: model_type, optional checkpoint path +MODELS=( + "churn:${MODELS_DIR}/churn.pt" + "pricing:${MODELS_DIR}/pricing.pt" + "recommendation:${MODELS_DIR}/recommendation.pt" +) + +for entry in "${MODELS[@]}"; do + IFS=':' read -r model_type model_path <<< "$entry" + + log "Processing model: ${model_type}" + + onnx_output="${MODELS_DIR}/${model_type}.onnx" + + if [ -f "$model_path" ]; then + log "Exporting ${model_type} from checkpoint ${model_path}" + python "${ML_SERVICE_DIR}/models/export_to_onnx.py" \ + --model-type "$model_type" \ + --model-path "$model_path" \ + --output "$onnx_output" \ + --quantize \ + --calibration-samples 1000 \ + --validate + else + log "No checkpoint found at ${model_path}, using random weights" + python "${ML_SERVICE_DIR}/models/export_to_onnx.py" \ + --model-type "$model_type" \ + --model-path "$model_path" \ + --output "$onnx_output" \ + --quantize \ + --calibration-samples 1000 \ + --validate + fi + + log "Completed ${model_type}" +done + +log "All models quantized successfully" + +# Verify all output files exist +for model_type in churn pricing recommendation; do + for ext in onnx _int8.onnx; do + file="${MODELS_DIR}/${model_type}${ext}" + if [ -f "$file" ]; then + size=$(du -h "$file" | cut -f1) + log " ✓ ${file} (${size})" + else + log " ✗ ${file} not found" + fi + done +done diff --git a/src/components/common/InvoiceListItem.tsx b/src/components/common/InvoiceListItem.tsx new file mode 100644 index 00000000..97c41a80 --- /dev/null +++ b/src/components/common/InvoiceListItem.tsx @@ -0,0 +1,132 @@ +import React from 'react'; +import { View, Text, StyleSheet, TouchableOpacity } from 'react-native'; +import { colors, spacing, typography, borderRadius, shadows } from '../../utils/constants'; +import { formatCurrency, formatRelativeDate } from '../../utils/formatting'; + +export interface Invoice { + id: string; + subscriptionName: string; + subscriptionId: string; + amount: number; + currency: string; + dueDate: Date; + status: 'paid' | 'pending' | 'overdue' | 'failed'; + paidAt?: Date; +} + +interface InvoiceListItemProps { + invoice: Invoice; + onPress: (invoice: Invoice) => void; +} + +const areEqual = (prev: InvoiceListItemProps, next: InvoiceListItemProps): boolean => { + const p = prev.invoice; + const n = next.invoice; + return ( + p.id === n.id && + p.amount === n.amount && + p.status === n.status && + p.dueDate === n.dueDate && + p.subscriptionName === n.subscriptionName && + p.paidAt === n.paidAt + ); +}; + +const getStatusColor = (status: Invoice['status']) => { + switch (status) { + case 'paid': + return colors.success; + case 'pending': + return colors.warning; + case 'overdue': + return colors.error; + case 'failed': + return colors.error; + } +}; + +const getStatusLabel = (status: Invoice['status']) => { + switch (status) { + case 'paid': + return 'Paid'; + case 'pending': + return 'Pending'; + case 'overdue': + return 'Overdue'; + case 'failed': + return 'Failed'; + } +}; + +export const InvoiceListItem = React.memo(({ invoice, onPress }: InvoiceListItemProps) => { + return ( + onPress(invoice)} + activeOpacity={0.7} + accessibilityRole="button" + accessibilityLabel={`${invoice.subscriptionName}, ${formatCurrency(invoice.amount, invoice.currency)}, ${getStatusLabel(invoice.status)}`}> + + + {invoice.subscriptionName} + + Due {formatRelativeDate(new Date(invoice.dueDate))} + + + {formatCurrency(invoice.amount, invoice.currency)} + + + {getStatusLabel(invoice.status)} + + + + + ); +}, areEqual); + +const styles = StyleSheet.create({ + container: { + flexDirection: 'row', + justifyContent: 'space-between', + alignItems: 'center', + backgroundColor: colors.surface, + borderRadius: borderRadius.lg, + padding: spacing.md, + marginBottom: spacing.md, + borderWidth: 1, + borderColor: colors.border, + ...shadows.sm, + }, + leftSection: { + flex: 1, + marginRight: spacing.md, + }, + name: { + ...typography.h3, + color: colors.text, + marginBottom: spacing.xs, + }, + date: { + ...typography.caption, + color: colors.textSecondary, + }, + rightSection: { + alignItems: 'flex-end', + }, + amount: { + ...typography.h3, + color: colors.text, + fontWeight: 'bold', + marginBottom: spacing.xs, + }, + statusBadge: { + paddingHorizontal: spacing.sm, + paddingVertical: spacing.xs, + borderRadius: borderRadius.sm, + }, + statusText: { + ...typography.small, + fontWeight: '600', + }, +}); diff --git a/src/components/common/LazyScreen.tsx b/src/components/common/LazyScreen.tsx new file mode 100644 index 00000000..f340cea0 --- /dev/null +++ b/src/components/common/LazyScreen.tsx @@ -0,0 +1,92 @@ +import React, { ComponentType, Suspense } from 'react'; +import { View, ActivityIndicator, StyleSheet, Text } from 'react-native'; +import { colors, spacing, typography } from '../../utils/constants'; + +interface LazyScreenProps { + component: React.LazyExoticComponent>; + fallback?: React.ReactNode; +} + +const DefaultFallback = () => ( + + + Loading... + +); + +const ErrorFallback = ({ error, retry }: { error: Error; retry: () => void }) => ( + + Failed to load + {error.message} + + Tap to retry + + +); + +interface LazyScreenState { + error: Error | null; +} + +class LazyScreenInner extends React.Component<{ children: React.ReactNode }, LazyScreenState> { + constructor(props: { children: React.ReactNode }) { + super(props); + this.state = { error: null }; + } + + static getDerivedStateFromError(error: Error) { + return { error }; + } + + handleRetry = () => { + this.setState({ error: null }); + }; + + render() { + if (this.state.error) { + return ; + } + return <>{this.props.children}; + } +} + +export const LazyScreen: React.FC = ({ component: Component, fallback }) => { + return ( + + }> + + + + ); +}; + +const styles = StyleSheet.create({ + container: { + flex: 1, + justifyContent: 'center', + alignItems: 'center', + backgroundColor: colors.background, + padding: spacing.lg, + }, + text: { + ...typography.body, + color: colors.textSecondary, + marginTop: spacing.md, + }, + errorText: { + ...typography.h3, + color: colors.error, + marginBottom: spacing.sm, + }, + errorDetail: { + ...typography.caption, + color: colors.textSecondary, + textAlign: 'center', + marginBottom: spacing.md, + }, + retryText: { + ...typography.body, + color: colors.primary, + fontWeight: '600', + }, +}); diff --git a/src/components/common/OptimizedFlatList.tsx b/src/components/common/OptimizedFlatList.tsx new file mode 100644 index 00000000..3899f6eb --- /dev/null +++ b/src/components/common/OptimizedFlatList.tsx @@ -0,0 +1,101 @@ +import React, { useCallback } from 'react'; +import { FlatList, FlatListProps, View, Text, StyleSheet, ActivityIndicator } from 'react-native'; +import { colors, spacing, typography } from '../../utils/constants'; + +const ITEM_HEIGHT = 84; + +interface OptimizedFlatListProps extends Omit, 'data' | 'renderItem'> { + data: T[]; + renderItem: FlatListProps['renderItem']; + keyExtractor: FlatListProps['keyExtractor']; + emptyText?: string; + emptyIcon?: string; + loading?: boolean; + estimatedItemSize?: number; +} + +export function OptimizedFlatList({ + data, + renderItem, + keyExtractor, + emptyText = 'No items', + emptyIcon = '📋', + loading = false, + estimatedItemSize = ITEM_HEIGHT, + contentContainerStyle, + ...rest +}: OptimizedFlatListProps) { + const initialNumToRender = 10; + const maxToRenderPerBatch = 5; + const windowSize = 10; + + const ListEmptyComponent = useCallback( + () => ( + + {emptyIcon} + {emptyText} + + ), + [emptyIcon, emptyText] + ); + + if (loading) { + return ( + + + + ); + } + + return ( + | null | undefined, index: number) => ({ + length: estimatedItemSize, + offset: estimatedItemSize * index, + index, + })} + initialNumToRender={initialNumToRender} + maxToRenderPerBatch={maxToRenderPerBatch} + windowSize={windowSize} + removeClippedSubviews={true} + ListEmptyComponent={ListEmptyComponent} + maintainVisibleContentPosition={{ + minIndexForVisible: 0, + }} + contentContainerStyle={[styles.contentContainer, contentContainerStyle]} + showsVerticalScrollIndicator={false} + {...rest} + /> + ); +} + +const styles = StyleSheet.create({ + contentContainer: { + flexGrow: 1, + paddingHorizontal: spacing.lg, + paddingBottom: spacing.xl, + }, + emptyContainer: { + flex: 1, + justifyContent: 'center', + alignItems: 'center', + paddingVertical: spacing.xxl * 2, + }, + emptyIcon: { + fontSize: 48, + marginBottom: spacing.md, + }, + emptyText: { + ...typography.body, + color: colors.textSecondary, + textAlign: 'center', + }, + loadingContainer: { + flex: 1, + justifyContent: 'center', + alignItems: 'center', + }, +}); diff --git a/src/components/common/SubscriptionListItem.tsx b/src/components/common/SubscriptionListItem.tsx new file mode 100644 index 00000000..54f862d7 --- /dev/null +++ b/src/components/common/SubscriptionListItem.tsx @@ -0,0 +1,39 @@ +import React from 'react'; +import { Subscription } from '../../types/subscription'; +import { SubscriptionCard, SubscriptionCardProps } from '../subscription/SubscriptionCard'; + +interface SubscriptionListItemProps extends SubscriptionCardProps { + subscription: Subscription; + onPress: (subscription: Subscription) => void; + onToggleStatus?: (id: string) => void; + onDelete?: (id: string) => void; +} + +const areEqual = (prev: SubscriptionListItemProps, next: SubscriptionListItemProps): boolean => { + const p = prev.subscription; + const n = next.subscription; + return ( + p.id === n.id && + p.name === n.name && + p.price === n.price && + p.isActive === n.isActive && + p.category === n.category && + p.billingCycle === n.billingCycle && + p.nextBillingDate === n.nextBillingDate && + p.currency === n.currency && + p.isCryptoEnabled === n.isCryptoEnabled && + p.description === n.description + ); +}; + +export const SubscriptionListItem = React.memo( + ({ subscription, onPress, onToggleStatus, onDelete }: SubscriptionListItemProps) => ( + + ), + areEqual +); diff --git a/src/components/home/SubscriptionList.tsx b/src/components/home/SubscriptionList.tsx index 27d3d6e0..06b637e0 100644 --- a/src/components/home/SubscriptionList.tsx +++ b/src/components/home/SubscriptionList.tsx @@ -1,8 +1,8 @@ -import React, { useCallback } from 'react'; +import React, { useCallback, useMemo } from 'react'; import { View, Text, StyleSheet } from 'react-native'; import { FlashList } from '@shopify/flash-list'; -import { colors, spacing, typography, borderRadius, shadows } from '../../utils/constants'; -import { SubscriptionCard } from '../subscription/SubscriptionCard'; +import { colors, spacing, typography, borderRadius } from '../../utils/constants'; +import { SubscriptionListItem } from '../common/SubscriptionListItem'; import { Subscription } from '../../types/subscription'; import { usePerformanceProfiler } from '../../hooks/usePerformanceProfiler'; import { EmptyState } from '../common/EmptyState'; @@ -40,9 +40,19 @@ export const SubscriptionList: React.FC = React.memo( upcomingCount: upcomingSubscriptions.length, }); + const sortedUpcoming = useMemo( + () => + upcomingSubscriptions + ?.slice() + .sort( + (a, b) => new Date(a.nextBillingDate).getTime() - new Date(b.nextBillingDate).getTime() + ) ?? [], + [upcomingSubscriptions] + ); + const renderItem = useCallback( ({ item }: { item: Subscription }) => ( - = React.memo( return ( - {/* Upcoming Billing Section */} - {upcomingSubscriptions && upcomingSubscriptions.length > 0 && ( + {sortedUpcoming.length > 0 && ( Upcoming Billing - {upcomingSubscriptions.length} subscription - {upcomingSubscriptions.length !== 1 ? 's' : ''} due this week + {sortedUpcoming.length} subscription + {sortedUpcoming.length !== 1 ? 's' : ''} due this week - {upcomingSubscriptions.slice(0, 3).map((subscription) => ( + {sortedUpcoming.slice(0, 3).map((subscription) => ( = React.memo( )} - {/* Main List Section */} @@ -117,7 +125,6 @@ export const SubscriptionList: React.FC = React.memo( {!hasSubscriptions ? ( - /* Context 1: Absolute empty state (no tracking items exist) */ = React.memo( onAction={onAddFirstPress} /> ) : activeSubscriptions.length === 0 ? ( - /* Context 2: Active filter empty state (subscriptions exist but filtered out) */ import('../screens/AddSubscriptionScreen')); const CancellationFlowScreen = lazyScreen(() => import('../screens/CancellationFlowScreen')); const WalletConnectScreen = lazyScreen(() => import('../screens/WalletConnectV2Screen'));