Skip to content

Commit 15270e6

Browse files
authored
CI: extend FA4 test matrix with causal/non-causal correctness and fwd+bwd benchmark seqlen 1K-32K (#2428)
1 parent 65bfd9a commit 15270e6

File tree

3 files changed

+8 -7 lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
 #!/usr/bin/env bash
 set -euo pipefail

-SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)
+SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../../" && pwd)

 python3 "$SCRIPT_DIR/tools/ci/run_fa4_ci.py" \
 --repo-root "$SCRIPT_DIR" \

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ permissions:

 env:
 CI_WORK_DIR: ${{ vars.CI_WORK_DIR || format('/scratch/user/{0}', github.actor) }}
-FA4_TEST_FILTER: "1-1-128-True-0-0.0-False-False-False-mha-dtype0"
+FA4_TEST_FILTER: "1024-1024-128-True-0-0.0-False-False-False-mha-dtype0 or 1024-1024-128-False-0-0.0-False-False-False-mha-dtype0"

 jobs:
 lint:
lint:
@@ -23,13 +23,13 @@ jobs:
 - name: Ruff format
 run: ruff format --check flash_attn/cute/ --exclude "flash_attn/cute/flash_bwd.py,flash_attn/cute/flash_fwd.py,flash_attn/cute/flash_fwd_sm100.py,flash_attn/cute/interface.py"

-test:
+fa4-correctness-and-benchmark:
 strategy:
 fail-fast: false
 matrix:
 gpu: [b200]
 runs-on: [self-hosted, '${{ matrix.gpu }}']
-name: test (${{ matrix.gpu }})
+name: fa4-correctness-and-benchmark (${{ matrix.gpu }})
 timeout-minutes: 60
 steps:
 - uses: actions/checkout@v4

tools/ci/run_fa4_ci.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,12 @@ def build_step_plan(
 ]
 if not skip_benchmark:
 steps.append(Step(
-name="Benchmark (FA4 fwd, hdim=128, seqlen=8192)",
+name="Benchmark (FA4 fwd, hdim=128, causal=both, seqlen=1K-32K)",
 command=[
 "python3", "benchmarks/benchmark_attn.py",
-"--backend", "fa4", "--fwd",
-"--headdim", "128", "--seqlen", "8192",
+"--backend", "fa4", "--fwd", "--bwd",
+"--headdim", "128",
+"--seqlen", "1024,2048,4096,8192,16384,32768",
 "--causal", "both", "--warmup", "1", "--rep", "3",
 ],
 extra_env={"CUDA_VISIBLE_DEVICES": benchmark_visible_devices},

0 commit comments

Comments
 (0)