Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft_lance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ._lance import (
compact_files,
create_scalar_index,
create_vector_index,
merge_columns,
merge_columns_df,
read_lance,
Expand All @@ -15,6 +16,7 @@
__all__ = [
"compact_files",
"create_scalar_index",
"create_vector_index",
"merge_columns",
"merge_columns_df",
"read_lance",
Expand Down
143 changes: 143 additions & 0 deletions daft_lance/_lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,149 @@ def create_scalar_index(
)


_VECTOR_INDEX_TYPES = frozenset({"IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "IVF_FLAT", "IVF_SQ"})


@PublicAPI
def create_vector_index(
uri: str | pathlib.Path,
io_config: IOConfig | None = None,
*,
column: str,
index_type: str = "IVF_PQ",
name: str | None = None,
metric: str = "L2",
replace: bool = True,
num_partitions: int | None = None,
num_sub_vectors: int | None = None,
accelerator: str | None = None,
storage_options: dict[str, Any] | None = None,
version: int | str | None = None,
asof: str | None = None,
block_size: int | None = None,
commit_lock: Any | None = None,
index_cache_size: int | None = None,
default_scan_options: dict[str, Any] | None = None,
metadata_cache_size_bytes: int | None = None,
**kwargs: Any,
) -> None:
"""Create a vector index on a Lance dataset column.

This is a passthrough to ``lance.LanceDataset.create_index()`` with the same
dataset construction pattern used by other daft-lance functions (URI,
io_config, storage_options, version, etc.).

Vector indices (IVF_PQ, IVF_HNSW_PQ, etc.) require a global training phase
across all data, so this function runs the index build in-process rather than
distributing it across workers.

Args:
uri: The URI of the Lance table (supports remote URLs to object stores such as ``s3://`` or ``gs://``).
io_config: A custom IOConfig to use when accessing LanceDB data. Defaults to None.
column: Column name to index. Must be a fixed-size list (vector) column.
index_type: Type of vector index to build. Supported values:
``"IVF_PQ"``, ``"IVF_HNSW_PQ"``, ``"IVF_HNSW_SQ"``, ``"IVF_FLAT"``, ``"IVF_SQ"``.
Defaults to ``"IVF_PQ"``.
name: Name of the index. If not provided, Lance generates one from the column name.
metric: Distance metric type. One of ``"L2"`` (euclidean), ``"cosine"``, or ``"dot"``.
Defaults to ``"L2"``.
replace: Whether to replace an existing index with the same name. Defaults to True.
num_partitions: Number of IVF partitions. If None, Lance picks a default based on dataset size.
num_sub_vectors: Number of sub-vectors for Product Quantization (PQ). Only used with PQ-based
index types (``IVF_PQ``, ``IVF_HNSW_PQ``).
accelerator: Hardware accelerator for training. ``"cuda"`` (Nvidia GPU) or ``"mps"``
(Apple Silicon GPU). If None, uses CPU.
storage_options: Extra options for storage connection.
version: If specified, load a specific version of the Lance dataset.
asof: If specified, find the latest version created on or earlier than the given argument value.
block_size: Block size in bytes for I/O.
commit_lock: A custom commit lock.
index_cache_size: Index cache size (number of entries).
default_scan_options: Default scan options for the dataset.
metadata_cache_size_bytes: Size of the metadata cache in bytes.
**kwargs: Additional keyword arguments forwarded to ``lance.LanceDataset.create_index()``
(e.g., ``ivf_centroids``, ``pq_codebook``, ``target_partition_size``,
``filter_nan``, ``shuffle_partition_batches``).

Returns:
None

Raises:
ValueError: If the column does not exist, is not a vector (fixed-size list) type,
or if ``index_type`` is not a supported vector index type.

Note:
This function requires the use of `LanceDB <https://lancedb.github.io/lancedb/>`_.
Install with: ``pip install daft[lance]``

Examples:
Create an IVF_PQ index on a vector column:
>>> import daft_lance
>>> daft_lance.create_vector_index("s3://my-bucket/dataset/", column="embedding")

Create an IVF_HNSW_SQ index with cosine distance:
>>> daft_lance.create_vector_index(
... "s3://my-bucket/dataset/",
... column="embedding",
... index_type="IVF_HNSW_SQ",
... metric="cosine",
... )

Create an index with GPU acceleration:
>>> daft_lance.create_vector_index(
... "/path/to/dataset/",
... column="vector",
... accelerator="cuda",
... num_partitions=256,
... num_sub_vectors=16,
... )
"""
index_type_upper = index_type.upper()
if index_type_upper not in _VECTOR_INDEX_TYPES:
raise ValueError(
f"Unsupported vector index type: {index_type!r}. Supported types: {sorted(_VECTOR_INDEX_TYPES)}"
)

io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
storage_options = storage_options or io_config_to_storage_options(io_config, str(uri))

lance_ds = construct_lance_dataset(
uri,
storage_options=storage_options,
version=version,
asof=asof,
block_size=block_size,
commit_lock=commit_lock,
index_cache_size=index_cache_size,
default_scan_options=default_scan_options,
metadata_cache_size_bytes=metadata_cache_size_bytes,
)

# Validate column exists and is a vector type
schema = lance_ds.schema
if column not in schema.names:
raise ValueError(f"Column {column!r} not found in dataset. Available columns: {schema.names}")

col_type = schema.field(column).type
if not hasattr(col_type, "list_size"):
raise ValueError(
f"Column {column!r} has type {col_type}, which is not a vector type. "
f"Vector index requires a fixed-size list column (e.g., FixedSizeList(float32, N))."
)

lance_ds.create_index(
column,
index_type_upper,
name=name,
metric=metric,
replace=replace,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
accelerator=accelerator,
**kwargs,
)


@PublicAPI
def compact_files(
uri: str | pathlib.Path,
Expand Down
184 changes: 184 additions & 0 deletions tests/io/lancedb/test_lancedb_vector_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from __future__ import annotations

import tempfile
from pathlib import Path

import lance
import numpy as np
import pyarrow as pa
import pytest

import daft
from daft_lance import create_vector_index


@pytest.fixture
def temp_dir():
"""Create a temporary directory for testing."""
with tempfile.TemporaryDirectory() as temp_dir:
yield temp_dir


def _make_vector_dataset(path: str | Path, *, num_rows: int = 256, dim: int = 16) -> str:
"""Create a Lance dataset with a fixed-size list vector column."""
rng = np.random.default_rng(42)
vectors = rng.standard_normal((num_rows, dim)).astype(np.float32)
vector_type = pa.list_(pa.float32(), dim)
vector_array = pa.FixedSizeListArray.from_arrays(vectors.flatten(), list_size=dim)
ids = pa.array(range(num_rows), type=pa.int64())
table = pa.table({"id": ids, "vector": vector_array.cast(vector_type)})
uri = str(Path(path) / "vector_dataset.lance")
lance.write_dataset(table, uri)
return uri


class TestCreateVectorIndex:
"""Tests for the create_vector_index API."""

def test_create_ivf_pq_index(self, temp_dir):
"""Create an IVF_PQ index and verify it appears in describe_indices."""
uri = _make_vector_dataset(temp_dir)

create_vector_index(uri, column="vector", index_type="IVF_PQ", num_partitions=2, num_sub_vectors=2)

ds = lance.dataset(uri)
indices = ds.describe_indices()
assert len(indices) == 1
assert indices[0].name is not None

def test_create_ivf_pq_index_default_type(self, temp_dir):
"""Default index_type should be IVF_PQ."""
uri = _make_vector_dataset(temp_dir)

create_vector_index(uri, column="vector", num_partitions=2, num_sub_vectors=2)

ds = lance.dataset(uri)
indices = ds.describe_indices()
assert len(indices) == 1

def test_create_index_cosine_metric(self, temp_dir):
"""Create an index with cosine metric."""
uri = _make_vector_dataset(temp_dir)

create_vector_index(
uri,
column="vector",
index_type="IVF_PQ",
metric="cosine",
num_partitions=2,
num_sub_vectors=2,
)

ds = lance.dataset(uri)
indices = ds.describe_indices()
assert len(indices) == 1

def test_create_index_custom_name(self, temp_dir):
"""Create an index with a custom name and verify it."""
uri = _make_vector_dataset(temp_dir)
index_name = "my_custom_vector_idx"

create_vector_index(
uri,
column="vector",
name=index_name,
num_partitions=2,
num_sub_vectors=2,
)

ds = lance.dataset(uri)
indices = ds.describe_indices()
assert len(indices) == 1
assert indices[0].name == index_name

def test_replace_existing_index(self, temp_dir):
"""replace=True should overwrite an existing index without error."""
uri = _make_vector_dataset(temp_dir)

create_vector_index(uri, column="vector", num_partitions=2, num_sub_vectors=2)
# Should not raise because replace=True is the default
create_vector_index(uri, column="vector", num_partitions=2, num_sub_vectors=2, replace=True)

ds = lance.dataset(uri)
indices = ds.describe_indices()
# Should still have exactly one index (replaced, not duplicated)
vector_indices = [idx for idx in indices if "vector" in idx.field_names]
assert len(vector_indices) == 1

def test_replace_false_raises_on_existing(self, temp_dir):
"""replace=False should raise when an index already exists on the column."""
uri = _make_vector_dataset(temp_dir)

create_vector_index(uri, column="vector", num_partitions=2, num_sub_vectors=2)

with pytest.raises(Exception):
create_vector_index(uri, column="vector", replace=False, num_partitions=2, num_sub_vectors=2)

def test_invalid_column_not_found(self, temp_dir):
"""Raise ValueError when the column does not exist in the dataset."""
uri = _make_vector_dataset(temp_dir)

with pytest.raises(ValueError, match="not found in dataset"):
create_vector_index(uri, column="nonexistent_column")

def test_invalid_column_not_vector(self, temp_dir):
"""Raise ValueError when the column is not a vector (fixed-size list) type."""
uri = _make_vector_dataset(temp_dir)

with pytest.raises(ValueError, match="not a vector type"):
create_vector_index(uri, column="id")

def test_invalid_index_type(self, temp_dir):
"""Raise ValueError for unsupported index type."""
uri = _make_vector_dataset(temp_dir)

with pytest.raises(ValueError, match="Unsupported vector index type"):
create_vector_index(uri, column="vector", index_type="INVERTED")

def test_vector_search_after_index_creation(self, temp_dir):
"""Verify that vector search works after creating an index."""
uri = _make_vector_dataset(temp_dir, num_rows=256, dim=8)

create_vector_index(
uri,
column="vector",
index_type="IVF_PQ",
num_partitions=2,
num_sub_vectors=2,
)

# Read with vector search
query = pa.array([1.0] * 8, type=pa.float32())
nearest = {"column": "vector", "q": query, "k": 5}
df = daft.read_lance(uri, default_scan_options={"nearest": nearest})
result = df.select("id").to_pydict()

assert len(result["id"]) == 5

def test_case_insensitive_index_type(self, temp_dir):
"""Index type matching should be case-insensitive."""
uri = _make_vector_dataset(temp_dir)

# lowercase should work
create_vector_index(uri, column="vector", index_type="ivf_pq", num_partitions=2, num_sub_vectors=2)

ds = lance.dataset(uri)
indices = ds.describe_indices()
assert len(indices) == 1

def test_kwargs_forwarded(self, temp_dir):
"""Extra kwargs should be forwarded to lance without error."""
uri = _make_vector_dataset(temp_dir)

# filter_nan is a valid lance create_index kwarg
create_vector_index(
uri,
column="vector",
num_partitions=2,
num_sub_vectors=2,
filter_nan=True,
)

ds = lance.dataset(uri)
indices = ds.describe_indices()
assert len(indices) == 1
Loading
Loading