Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions internal/cli/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func runDatasetPush(ctx context.Context, out, errOut io.Writer, a runDatasetPush
case a.Spec.Category == "":
// Left empty by a caller; let the schema produce the canonical
// "category is required" error downstream.
case push.IsTabular(a.Spec.Category) || a.Spec.Category == "image_classification":
case push.IsTabular(a.Spec.Category) || push.IsText(a.Spec.Category) || a.Spec.Category == "image_classification":
// supported
case push.IsImage(a.Spec.Category):
return &exitError{code: 2, err: fmt.Errorf(
Expand All @@ -290,9 +290,11 @@ func runDatasetPush(ctx context.Context, out, errOut io.Writer, a runDatasetPush
default:
return &exitError{code: 2, err: fmt.Errorf(
"category %q isn't supported by the CLI yet. Supported: image_classification, "+
"tabular_classification, tabular_regression, time_series_forecasting, "+
"time_to_event_prediction. (Text / detection / segmentation are coming; "+
"use the helm flow for those meanwhile.)", a.Spec.Category)}
"text_classification, masked_language_modeling, and the tabular / "+
"time-series family (tabular_classification, tabular_regression, "+
"time_series_forecasting, time_to_event_prediction). (Object detection / "+
"keypoint / segmentation are coming; use the helm flow for those meanwhile.)",
a.Spec.Category)}
}

// 3. Walk the local directory FIRST (local "fail fast"), dispatched
Expand All @@ -305,9 +307,12 @@ func runDatasetPush(ctx context.Context, out, errOut io.Writer, a runDatasetPush
layout *push.LocalLayout
err error
)
if push.IsTabular(a.Spec.Category) {
switch {
case push.IsTabular(a.Spec.Category):
layout, err = push.DiscoverTabular(a.LocalPath)
} else {
case push.IsText(a.Spec.Category):
layout, err = push.DiscoverText(a.Spec.Category, a.LocalPath)
default:
layout, err = push.Discover(a.LocalPath)
}
if err != nil {
Expand All @@ -316,7 +321,8 @@ func runDatasetPush(ctx context.Context, out, errOut io.Writer, a runDatasetPush

// 3a. Per-category spec resolution from the local data, so the
// synthesized spec carries the right fields before validation.
if push.IsTabular(a.Spec.Category) {
switch {
case push.IsTabular(a.Spec.Category):
// Column schema: an explicit --schema wins; otherwise infer
// INT/FLOAT/VARCHAR types from the CSV so the customer doesn't
// hand-write one for the common case.
Expand All @@ -340,7 +346,7 @@ func runDatasetPush(ctx context.Context, out, errOut io.Writer, a runDatasetPush
" (skipped framework-managed column(s): %s)\n", strings.Join(skipped, ", "))
}
}
} else {
case push.IsImage(a.Spec.Category):
// Image target resolution: the ingestor's image_classification
// default is 512x512 and it VALIDATES (it does not resize), so
// a mismatch hard-fails. Honour an explicit --target-size;
Expand All @@ -365,6 +371,10 @@ func runDatasetPush(ctx context.Context, out, errOut io.Writer, a runDatasetPush
"resolution mismatch.\n", derr)
}
}
default:
// Text family: no extra per-category resolution. The label (for
// text_classification) comes straight from --label-column;
// masked_language_modeling needs neither a label nor a schema.
}

// 4. Synthesize the spec from flags + validate against schema.
Expand Down Expand Up @@ -604,16 +614,23 @@ func printPushPreflight(
// shouldn't convert success into failure. The exit code is
// the contract.
cat, _ := spec["category"].(string)
tabular := push.IsTabular(cat)

_, _ = fmt.Fprintf(out, "Local dataset:\n")
_, _ = fmt.Fprintf(out, " root: %s\n", layout.Root)
if tabular {
switch {
case push.IsTabular(cat):
_, _ = fmt.Fprintf(out, " data CSV: %s\n", layout.LabelsCSV)
if sch, ok := spec["schema"].(map[string]string); ok {
_, _ = fmt.Fprintf(out, " columns: %d\n", len(sch))
}
} else {
case push.IsText(cat):
dir := push.TextSidecarDir(cat)
_, _ = fmt.Fprintf(out, " labels.csv: %s\n", layout.LabelsCSV)
_, _ = fmt.Fprintf(out, " %-15s%d files\n", dir+":", len(layout.Sidecars[dir]))
if _, ok := layout.ExtraFiles["tokenizer.json"]; ok {
_, _ = fmt.Fprintf(out, " %-15s%s\n", "tokenizer:", "tokenizer.json")
}
default:
_, _ = fmt.Fprintf(out, " labels.csv: %s\n", layout.LabelsCSV)
_, _ = fmt.Fprintf(out, " images: %d files\n", len(layout.Images))
}
Expand Down Expand Up @@ -651,7 +668,7 @@ func printPushPreflight(

if !dryRun {
_, _ = fmt.Fprintf(out, "Next: stage %d files (%s) for table %q\n",
1+len(layout.Images), push.HumanBytes(layout.TotalBytes), spec["table"])
layout.FileCount(), push.HumanBytes(layout.TotalBytes), spec["table"])
_, _ = fmt.Fprintln(out)
}
}
4 changes: 2 additions & 2 deletions internal/cli/dataset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ func execDatasetPush(t *testing.T, args []string) (exitCode int, stdout, stderr
func TestDatasetPush_UnsupportedCategory_ExitsTwo(t *testing.T) {
root := imgcLayout(t)
for _, badCategory := range []string{
"object_detection", // image category, needs sidecar staging (later)
"text_classification", // text family (later)
"object_detection", // image category, needs annotation sidecar (later)
"keypoint_detection", // image category, needs keypoint flags (later)
"definitely-not-a-category", // nonsense; gate catches this too
} {
t.Run(badCategory, func(t *testing.T) {
Expand Down
24 changes: 24 additions & 0 deletions internal/push/category.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ var regressionClassCategories = map[string]bool{
"time_to_event_prediction": true,
}

// textCategories take a labels CSV + a directory of text files
// (texts/ for classification, sequences/ for masked language
// modeling). masked_language_modeling additionally needs a
// tokenizer.json at the dataset root and has NO label.
var textCategories = map[string]bool{
"text_classification": true,
"masked_language_modeling": true,
}

// IsImage reports whether category uses the labels.csv + images/
// local layout.
func IsImage(category string) bool { return imageCategories[category] }
Expand All @@ -49,3 +58,18 @@ func IsTabular(category string) bool { return tabularCategories[category] }
// IsRegressionClass reports whether category predicts a numeric
// target and therefore needs label.policy (object label form).
func IsRegressionClass(category string) bool { return regressionClassCategories[category] }

// IsText reports whether category uses the labels.csv + text-file
// directory (texts/ or sequences/) local layout.
func IsText(category string) bool { return textCategories[category] }

// TextSidecarDir returns the sidecar directory name a text category
// expects: "sequences" for masked_language_modeling, "texts" for
// text_classification. (Used both as the local subdir to stage and
// the spec field to emit.)
func TextSidecarDir(category string) string {
if category == "masked_language_modeling" {
return "sequences"
}
return "texts"
}
21 changes: 19 additions & 2 deletions internal/push/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,16 +194,33 @@ func (a SpecArgs) Build() map[string]any {
"intent": a.Intent,
"csv": path.Join(prefix, "labels.csv"),
}
if IsTabular(a.Category) {
switch {
case IsTabular(a.Category):
a.buildTabular(spec)
} else {
case IsText(a.Category):
a.buildText(spec, prefix)
default:
// Image categories (and any not-yet-special-cased category —
// the schema validator produces the canonical error for those).
a.buildImage(spec, prefix)
}
return spec
}

// buildText fills in the text-family fields: the text-file sidecar
// directory (texts/ for text_classification, sequences/ for
// masked_language_modeling) and the label. masked_language_modeling
// has NO label (the schema doesn't require one for it).
func (a SpecArgs) buildText(spec map[string]any, prefix string) {
dir := TextSidecarDir(a.Category)
// Trailing slash matches the directory-glob convention the
// ingestor uses for sidecar dirs.
spec[dir] = path.Join(prefix, dir) + "/"
if a.Category == "text_classification" {
spec["label"] = a.LabelColumn
}
}

// buildImage fills in the image-category fields: the images/ sidecar
// dir, the label column, and the optional target_size override.
func (a SpecArgs) buildImage(spec map[string]any, prefix string) {
Expand Down
4 changes: 2 additions & 2 deletions internal/push/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func Stage(ctx context.Context, opts StageOptions) error {
// 5. Stream the tar. This is where actual bytes flow. The
// progress bar (if TTY) renders during this call.
_, _ = fmt.Fprintf(opts.Out, "Streaming %d files (%s) for table %q...\n",
1+len(opts.Layout.Images), HumanBytes(opts.Layout.TotalBytes), opts.Table)
opts.Layout.FileCount(), HumanBytes(opts.Layout.TotalBytes), opts.Table)

if err := StreamLayout(ctx, opts.Executor,
opts.Namespace, podName, "stage",
Expand All @@ -148,6 +148,6 @@ func Stage(ctx context.Context, opts StageOptions) error {

// 6. Print "done" message. The deferred cleanup runs after this.
_, _ = fmt.Fprintf(opts.Out, "Staged %d files for table %q\n",
1+len(opts.Layout.Images), opts.Table)
opts.Layout.FileCount(), opts.Table)
return nil
}
47 changes: 47 additions & 0 deletions internal/push/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"path"
"path/filepath"
"sort"

corev1 "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes"
Expand Down Expand Up @@ -388,12 +389,58 @@ func writeLayoutTar(w io.Writer, layout *LocalLayout) (err error) {
}
}

// Extra root-level files (e.g. masked_language_modeling's
// tokenizer.json), staged at the table root under their dest name.
// Sorted for deterministic stream order.
for _, dest := range sortedKeys(layout.ExtraFiles) {
n, err := writeTarFile(tw, layout.ExtraFiles[dest], dest)
if err != nil {
return fmt.Errorf("packaging %s: %w", dest, err)
}
totalBytes += n
if totalBytes > MaxTotalBytes {
return fmt.Errorf(
"dataset exceeded v0.1 total cap of %s after streaming %s (reached %s)",
HumanBytes(MaxTotalBytes), dest, HumanBytes(totalBytes))
}
}

// Generic sidecar directories (texts/, sequences/, and — later —
// annotations/, masks/), each staged under "<name>/<basename>".
// Sorted by dir name for deterministic stream order.
for _, name := range sortedKeys(layout.Sidecars) {
for _, abs := range layout.Sidecars[name] {
dst := path.Join(name, filepath.Base(abs))
n, err := writeTarFile(tw, abs, dst)
if err != nil {
return fmt.Errorf("packaging %s: %w", abs, err)
}
totalBytes += n
if totalBytes > MaxTotalBytes {
return fmt.Errorf(
"dataset exceeded v0.1 total cap of %s after streaming %s (reached %s)",
HumanBytes(MaxTotalBytes), dst, HumanBytes(totalBytes))
}
}
}

// tw.Close() in the defer above writes the tar footer
// (two zero blocks). Without that, GNU tar treats the archive
// as truncated and refuses to extract.
return nil
}

// sortedKeys returns a map's string keys in sorted order, for
// deterministic iteration when packaging ExtraFiles / Sidecars.
func sortedKeys[V any](m map[string]V) []string {
ks := make([]string, 0, len(m))
for k := range m {
ks = append(ks, k)
}
sort.Strings(ks)
return ks
}

// writeTarFile writes one file from `src` into tw under the
// archive-relative name `dst`. Streams the file body — no full-
// read into memory — so a single 500 MiB image doesn't balloon
Expand Down
Loading
Loading