Skip to content

Commit 6446aec

Browse files
J03D03claude
and committed
Accept Iterable[DatasetEntry] in transformations and exporters
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 9ade43f commit 6446aec

16 files changed

Lines changed: 59 additions & 45 deletions

‎examples/04_enrich_with_commit_data.py‎

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
setup_logging("enrich_with_commit_data")
99

1010
if __name__ == "__main__":
11-
entries = list(vfc_datasets.DevignDataset()) # Devign contains only 2 projects (FFmpeg, QEMU).
12-
13-
entries = transformations.add_commit_information_local(entries)
11+
# Devign contains only 2 projects (FFmpeg, QEMU).
12+
entries = transformations.add_commit_information_local(vfc_datasets.DevignDataset())
1413
log_dataset_stats(entries)

‎examples/05_save_and_export.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
setup_logging("save_and_export")
1010

1111
if __name__ == "__main__":
12-
entries = list(vfc_datasets.DevignDataset())
12+
entries = vfc_datasets.DevignDataset()
1313

1414
output_dir = Path(".data/exports")
1515
save_entries(entries, output_dir / "devign.jsonl")

‎examples/06_create_splits.py‎

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@
99
setup_logging("create_splits")
1010

1111
if __name__ == "__main__":
12-
entries = list(vfc_datasets.DevignDataset())
13-
1412
create_random_split(
15-
entries,
13+
vfc_datasets.DevignDataset(),
1614
name="devign",
1715
output_path=Path(".data/splits"),
1816
seed=42,

‎examples/07_filter_by_language.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
logger = logging.getLogger(__name__)
1111

1212
if __name__ == "__main__":
13-
entries = list(vfc_datasets.BigVulDataset())
13+
dataset = vfc_datasets.BigVulDataset()
1414

15-
logger.info("Entries before filtering: %d", len(entries))
16-
entries = transformations.filter_by_extension(entries, {"c", "h", "cpp", "hpp", "cc"})
15+
logger.info("Entries before filtering: %d", len(dataset))
16+
entries = transformations.filter_by_extension(dataset, {"c", "h", "cpp", "hpp", "cc"})
1717
logger.info("Entries after filtering: %d", len(entries))

‎src/vfc_datasets/transformations/enrichment/add_commit_data_api.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import logging
5+
from collections.abc import Iterable
56
from typing import Any
67

78
from tqdm.asyncio import tqdm
@@ -75,8 +76,9 @@ async def _enrich_entries_async(entries: list[DatasetEntry]) -> tuple[int, int]:
7576
return success, len(entries) - success
7677

7778

78-
def add_commit_information_api(entries: list[DatasetEntry]) -> list[DatasetEntry]:
79+
def add_commit_information_api(entries: Iterable[DatasetEntry]) -> list[DatasetEntry]:
7980
"""Enrich entries with commit data from the GitHub API and return the modified list."""
81+
entries = list(entries)
8082
entries_to_process = [e for e in entries if needs_enrichment(e)]
8183

8284
if not entries_to_process:

‎src/vfc_datasets/transformations/enrichment/add_commit_data_local.py‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from collections.abc import Iterable
23
from datetime import UTC, datetime
34

45
from git import Repo
@@ -53,13 +54,13 @@ def _process_commit_batch(args: tuple[str, list[str], int]) -> dict[str, CommitD
5354
return results
5455

5556

56-
def add_commit_information_local(entries: list[DatasetEntry]) -> list[DatasetEntry]:
57+
def add_commit_information_local(entries: Iterable[DatasetEntry]) -> list[DatasetEntry]:
5758
"""Enrich DatasetEntry objects with commit information from local git repos."""
5859
logger.info("Add commit information [LOCAL]")
5960
logger.info("Max diff size limit: %dK chars", MAX_DIFF_SIZE // 1000)
6061

6162
return process_commits_in_batches(
62-
entries,
63+
list(entries),
6364
filter_fn=needs_enrichment,
6465
batch_fn=_process_commit_batch,
6566
apply_fn=apply_commit_data,

‎src/vfc_datasets/transformations/enrichment/add_no_comment.py‎

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import os
33
import tempfile
4+
from collections.abc import Iterable
45
from pathlib import PurePosixPath
56
from typing import Any
67

@@ -290,18 +291,19 @@ def _apply_diff(entry: DatasetEntry, diff: str) -> None:
290291

291292

292293
def strip_diff_comments(
293-
entries: list[DatasetEntry], *, include_unsupported: bool = True
294+
entries: Iterable[DatasetEntry], *, include_unsupported: bool = True
294295
) -> list[DatasetEntry]:
295296
"""Strip comments from commit diffs in-place using tree-sitter.
296297
297298
Args:
298-
entries: List of dataset entries to process
299+
entries: Dataset entries to process
299300
include_unsupported: If True (default), include original diff for unsupported files.
300301
If False, skip unsupported files entirely.
301302
"""
302303
logger.info("Strip comments from commit diffs [LOCAL]")
303304
logger.info("Max diff size: %dK chars", MAX_DIFF_SIZE // 1000)
304305

306+
entries = list(entries)
305307
needs_processing = [e for e in entries if e.commit_diff]
306308
skipped = sum(
307309
1

‎src/vfc_datasets/transformations/enrichment/commit_id_enrichment.py‎

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import logging
33
from collections import defaultdict
4+
from collections.abc import Iterable
45
from concurrent.futures import ProcessPoolExecutor, as_completed
56

67
from tqdm.asyncio import tqdm as async_tqdm
@@ -114,9 +115,9 @@ async def _extend_commit_id_api_async(
114115
return entry, entry.commit_id, False
115116

116117

117-
def extend_commit_ids_api(entries: list[DatasetEntry]) -> list[DatasetEntry]:
118+
def extend_commit_ids_api(entries: Iterable[DatasetEntry]) -> list[DatasetEntry]:
118119
"""Extend commit IDs using the GitHub API (async) and return the (possibly) modified list of entries."""
119-
# Filter entries that need extension
120+
entries = list(entries)
120121
entries_to_process = [
121122
entry for entry in entries if entry.commit_id and len(entry.commit_id) < 40
122123
]
@@ -139,8 +140,9 @@ def extend_commit_ids_api(entries: list[DatasetEntry]) -> list[DatasetEntry]:
139140
return entries
140141

141142

142-
def extend_commit_ids_local(entries: list[DatasetEntry]) -> list[DatasetEntry]:
143+
def extend_commit_ids_local(entries: Iterable[DatasetEntry]) -> list[DatasetEntry]:
143144
"""Extend commit IDs in-place by cloning repositories locally and return the list of modified entries."""
145+
entries = list(entries)
144146
entries_to_process = [
145147
entry for entry in entries if entry.commit_id and len(entry.commit_id) < 40
146148
]

‎src/vfc_datasets/transformations/enrichment/project_urls/update_project_urls.py‎

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Transformation functions for updating and filtering project URLs.
22
3-
These functions operate on a list of DatasetEntry objects to:
3+
These functions operate on dataset entries to:
44
- Remove entries with unreachable project URLs
55
- Update entries with moved/renamed project URLs
66
"""
77

88
import logging
9+
from collections.abc import Iterable
910

1011
from vfc_datasets.dataset_entry import DatasetEntry
1112

@@ -17,16 +18,17 @@
1718
logger = logging.getLogger(__name__)
1819

1920

20-
def filter_unreachable_project_urls(entries: list[DatasetEntry]) -> list[DatasetEntry]:
21+
def filter_unreachable_project_urls(entries: Iterable[DatasetEntry]) -> list[DatasetEntry]:
2122
"""Remove entries with unreachable project_url's"""
2223
logger.info("REMOVE unreachable project_urls")
2324
unreachable_urls = get_unreachable_urls()
2425
return [e for e in entries if e.project_url not in unreachable_urls]
2526

2627

27-
def update_project_urls_inplace(entries: list[DatasetEntry]) -> list[DatasetEntry]:
28+
def update_project_urls_inplace(entries: Iterable[DatasetEntry]) -> list[DatasetEntry]:
2829
"""Update moved or fixed project_url's in-place and return the modified entries."""
2930
logger.info("UPDATE project_urls")
31+
entries = list(entries)
3032
changed_urls = 0
3133
moved_urls = get_moved_urls()
3234
for entry in entries:

‎src/vfc_datasets/transformations/filters/collapse_to_commit_level.py‎

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from collections import defaultdict
3+
from collections.abc import Iterable
34

45
from vfc_datasets.dataset_entry import DatasetEntry
56

@@ -9,7 +10,7 @@
910

1011

1112
def collapse_to_commit_level(
12-
entries: list[DatasetEntry],
13+
entries: Iterable[DatasetEntry],
1314
*,
1415
include_benign_only: bool = True,
1516
) -> list[DatasetEntry]:
@@ -23,13 +24,13 @@ def collapse_to_commit_level(
2324
"functions. This could lead to mislabeling VFCs as non-VFC."
2425
)
2526

26-
if not entries:
27-
return []
28-
2927
groups: dict[tuple[str, str], list[DatasetEntry]] = defaultdict(list)
3028
for entry in entries:
3129
groups[(entry.project_url, entry.commit_id)].append(entry)
3230

31+
if not groups:
32+
return []
33+
3334
result = []
3435
dropped = 0
3536
for group_entries in groups.values():

0 commit comments

Comments
 (0)