Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions scripts/test_vuln_remediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,34 @@ def test_manager_limits_socket_plan_batch_size(self) -> None:
self.assertEqual(context["deferred"][0]["id"], "GHSA-two")
self.assertIn("limited to 1 fix", context["deferred"][0]["reason"])

def test_manager_prefers_simple_direct_socket_fixes(self) -> None:
context = vr.build_context(
{"alerts": []},
{
"type": "only-direct-dependency-upgrades",
"fixes": {
"GHSA-transitive-first-alphabetically": {
"directDependencies": [
{
"purl": "pkg:npm/wrapper@1.0.0",
"transitiveFixes": [
{"purl": "pkg:npm/transitive@1.0.0", "fixedVersion": "1.0.1"}
],
}
]
},
"GHSA-direct-second-alphabetically": {
"directDependencies": [
{"purl": "pkg:npm/direct@1.0.0", "fixedVersion": "1.0.1"}
]
},
},
},
max_fixes=1,
)

self.assertEqual(context["fixes"][0]["id"], "GHSA-direct-second-alphabetically")

def test_manager_focuses_on_reachable_and_potentially_reachable(self) -> None:
fix_plan = {
"type": "only-direct-dependency-upgrades",
Expand Down
33 changes: 32 additions & 1 deletion scripts/vuln_remediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def build_context(remediation_input: dict[str, Any], fix_plan: dict[str, Any] |
)
seen_ids.add(vuln_id)

for vuln_id, plan_entry in plan_by_id.items():
for vuln_id, plan_entry in sorted_plan_entries(plan_by_id):
if vuln_id in seen_ids:
continue
state = fix_plan_state(plan_entry) or default_plan_state
Expand Down Expand Up @@ -272,6 +272,37 @@ def build_context(remediation_input: dict[str, Any], fix_plan: dict[str, Any] |
return {"fixes": items, "deferred": deferred}


def sorted_plan_entries(plan_by_id: dict[str, Any]) -> list[tuple[str, dict[str, Any]]]:
entries = [
(vuln_id, plan_entry)
for vuln_id, plan_entry in plan_by_id.items()
if isinstance(plan_entry, dict)
]
return sorted(entries, key=lambda item: plan_complexity_score(item[0], item[1]))


def plan_complexity_score(vuln_id: str, plan_entry: dict[str, Any]) -> tuple[int, int, int, int, str]:
direct_dependencies = [
dependency
for dependency in plan_entry.get("directDependencies") or []
if isinstance(dependency, dict)
]
direct_updates = sum(1 for dependency in direct_dependencies if dependency.get("fixedVersion"))
transitive_updates = sum(len(dependency.get("transitiveFixes") or []) for dependency in direct_dependencies)
package_count = len(responsible_direct_dependencies(plan_entry))

# Prefer the fixes Socket is most likely to apply cleanly:
# direct dependency bumps, fewer packages, fewer transitive edges.
direct_priority = 0 if direct_updates > 0 else 1
return (
direct_priority,
package_count or 999,
transitive_updates,
-direct_updates,
vuln_id,
)


def responsible_direct_dependencies(plan_entry: dict[str, Any]) -> list[str]:
details = ((plan_entry.get("value") or {}).get("fixDetails") or {})
raw = details.get("responsibleDirectDependencies") or {}
Expand Down
Loading