
Commit 79c87e8

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent df6e0e9 commit 79c87e8


43 files changed (+114, -121 lines)

ac_dc/anonymization.py (+1, -1)

@@ -30,7 +30,7 @@ def apply_regex_anonymization(
         tag_type=tag_type,
     )
     if anonymize_condition:
-        for (ent, start, end, tag) in ner:
+        for ent, start, end, tag in ner:
             # we need to actually walk through and replace by start, end span.
             sentence = sentence.replace(ent, f" <{tag}> ")
     return sentence, ner
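
Note (illustration, not part of the commit): the only change in this hunk drops the redundant parentheses around the loop targets; both spellings unpack identically. A minimal sketch with made-up data:

pairs = [("ent", 0, 3, "PER"), ("loc", 4, 7, "LOC")]

# Before the fix: parentheses around the unpacking targets.
for (ent, start, end, tag) in pairs:
    print(ent, start, end, tag)

# After the fix: same semantics, just less visual noise.
for ent, start, end, tag in pairs:
    print(ent, start, end, tag)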

ac_dc/deduplicate/self_deduplicate.py (+1, -2)

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
 # @Date : 2022-01-08 22:39:29
 # @Author : Chenghao Mou ([email protected])
 # @Description: Self-deduplication with `datasets`
@@ -28,7 +27,7 @@
 
 def main(conf: str) -> None:
 
-    with open(conf, "r") as f:
+    with open(conf) as f:
         conf = yaml.safe_load(f.read())
 
     if conf["load_from_disk"]["path"]:

ac_dc/visualization/get_data_for_visualization.py (+3, -3)

@@ -90,9 +90,9 @@ def compute_stats(self):
            )
            for n in range(2, 16)
        }
-        stats_document[
-            "character_repetition_ratio"
-        ] = character_repetition_ratios
+        stats_document["character_repetition_ratio"] = (
+            character_repetition_ratios
+        )
 
        word_repetition_ratios = {
            n: round(
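
Note (illustration, not part of the commit): this hunk and several later ones appear to be the formatter's newer wrapping preference for long subscript assignments: keep the subscript on one line and parenthesize the right-hand side instead of splitting the key across lines. A before/after sketch with invented values:

stats_document = {}
character_repetition_ratios = {2: 0.10, 3: 0.05}

# Old wrapping: the subscript itself is split across lines.
stats_document[
    "character_repetition_ratio"
] = character_repetition_ratios

# New wrapping: subscript stays intact, value is parenthesized. Same statement.
stats_document["character_repetition_ratio"] = (
    character_repetition_ratios
)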

ac_dc/visualization/visualization.py (+20, -20)

@@ -290,16 +290,16 @@ def get_cond(key, cutoff, max_cutoff):
                    "stopwords_ratio"
                ]
                for i in range(len(self.docs["stopwords_ratio"])):
-                    self.docs["stopwords_ratio"].iloc[
-                        i
-                    ] = Filtering.compute_stopwords_ratio(
-                        self.docs["text"].iloc[i],
-                        self.sentencepiece_model_tok,
-                        self.param["strip_characters"],
-                        self.param["cond_words_augmentation"],
-                        self.param["words_augmentation_group_sizes"],
-                        self.param["words_augmentation_join_char"],
-                        new_stopwords,
+                    self.docs["stopwords_ratio"].iloc[i] = (
+                        Filtering.compute_stopwords_ratio(
+                            self.docs["text"].iloc[i],
+                            self.sentencepiece_model_tok,
+                            self.param["strip_characters"],
+                            self.param["cond_words_augmentation"],
+                            self.param["words_augmentation_group_sizes"],
+                            self.param["words_augmentation_join_char"],
+                            new_stopwords,
+                        )
                    )
                cutoff_def = "If the stop words ratio of a document is lower than this number, the document is removed."
                cutoff_stopwords_ratio = st.slider(
@@ -326,16 +326,16 @@ def get_cond(key, cutoff, max_cutoff):
                    "flagged_words_ratio"
                ]
                for i in range(len(self.docs["flagged_words_ratio"])):
-                    self.docs["flagged_words_ratio"].iloc[
-                        i
-                    ] = Filtering.compute_flagged_words_ratio(
-                        self.docs["text"].iloc[i],
-                        self.sentencepiece_model_tok,
-                        self.param["strip_characters"],
-                        self.param["cond_words_augmentation"],
-                        self.param["words_augmentation_group_sizes"],
-                        self.param["words_augmentation_join_char"],
-                        new_flagged_words,
+                    self.docs["flagged_words_ratio"].iloc[i] = (
+                        Filtering.compute_flagged_words_ratio(
+                            self.docs["text"].iloc[i],
+                            self.sentencepiece_model_tok,
+                            self.param["strip_characters"],
+                            self.param["cond_words_augmentation"],
+                            self.param["words_augmentation_group_sizes"],
+                            self.param["words_augmentation_join_char"],
+                            new_flagged_words,
+                        )
                    )
                cutoff_def = "If the flagged words ratio of a document is higher than this number, the document is removed."
                max_fwr = np.max(self.docs["flagged_words_ratio"])

bertin/evaluation/run_glue.py (+10, -7)

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# coding=utf-8
 # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -384,19 +383,23 @@ def main():
     # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
     config = AutoConfig.from_pretrained(
-        model_args.config_name
-        if model_args.config_name
-        else model_args.model_name_or_path,
+        (
+            model_args.config_name
+            if model_args.config_name
+            else model_args.model_name_or_path
+        ),
         num_labels=num_labels,
         finetuning_task=data_args.task_name,
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
     tokenizer = AutoTokenizer.from_pretrained(
-        model_args.tokenizer_name
-        if model_args.tokenizer_name
-        else model_args.model_name_or_path,
+        (
+            model_args.tokenizer_name
+            if model_args.tokenizer_name
+            else model_args.model_name_or_path
+        ),
         cache_dir=model_args.cache_dir,
         use_fast=model_args.use_fast_tokenizer,
         revision=model_args.model_revision,
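
Note (illustration, not part of the commit): when a multi-line conditional expression is passed as a call argument, the formatter now wraps it in its own parentheses, which makes the argument boundary and the trailing comma unambiguous. A runnable sketch with placeholder names standing in for the from_pretrained call:

def load(name, cache_dir=None):
    # Stand-in for AutoConfig.from_pretrained(...) in the real script.
    return name, cache_dir

config_name = None
model_name_or_path = "some-model"  # placeholder identifier

config = load(
    (
        config_name
        if config_name
        else model_name_or_path
    ),
    cache_dir=None,
)
assert config == ("some-model", None)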

bertin/evaluation/run_ner.py (+8, -7)

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# coding=utf-8
 # Copyright 2020 The HuggingFace Team All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -364,9 +363,11 @@ def get_label_list(labels):
     # The .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
     config = AutoConfig.from_pretrained(
-        model_args.config_name
-        if model_args.config_name
-        else model_args.model_name_or_path,
+        (
+            model_args.config_name
+            if model_args.config_name
+            else model_args.model_name_or_path
+        ),
         num_labels=num_labels,
         label2id=label_to_id,
         id2label={i: l for l, i in label_to_id.items()},
@@ -636,9 +637,9 @@ def compute_metrics(p):
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
-            kwargs[
-                "dataset"
-            ] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
+            kwargs["dataset"] = (
+                f"{data_args.dataset_name} {data_args.dataset_config_name}"
+            )
        else:
            kwargs["dataset"] = data_args.dataset_name
 

bertin/mc4/mc4.py (+1, -2)

@@ -1,6 +1,5 @@
 """Perplexity Sampled mC4 dataset based on Common Crawl."""
 
-
 import gzip
 import json
 
@@ -404,7 +403,7 @@ def _generate_examples(self, filepaths):
        for filepath in filepaths:
            logger.info("generating examples from = %s", filepath)
            if filepath.endswith("jsonl"):
-                with open(filepath, "r", encoding="utf-8") as f:
+                with open(filepath, encoding="utf-8") as f:
                    for line in f:
                        if line:
                            example = json.loads(line)

bertin/run_mlm_flax.py (-1)

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# coding=utf-8
 # Copyright 2021 The HuggingFace Team All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

bertin/run_mlm_flax_stream.py (+1, -2)

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# coding=utf-8
 # Copyright 2021 The HuggingFace Team All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -446,7 +445,7 @@ def restore_checkpoint(save_dir, state):
     args = joblib.load(os.path.join(save_dir, "training_args.joblib"))
     data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
 
-    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
+    with open(os.path.join(save_dir, "training_state.json")) as f:
         training_state = json.load(f)
     step = training_state["step"]
 
bertin/utils/dataset_perplexity.py (+1, -1)

@@ -17,7 +17,7 @@ def get_perplexity(doc):
 
 
 with open("mc4-es-train-50M-stats.csv", "w") as csv:
-    with open("mc4-es-train-50M-steps.jsonl", "r") as data:
+    with open("mc4-es-train-50M-steps.jsonl") as data:
        for line in tqdm(data):
            text = json.loads(line)["text"]
            csv.write(f"{len(text.split())},{get_perplexity(text)}\n")

cc_pseudo_crawl/python_scripts/deeper.py (+1)

@@ -1,6 +1,7 @@
 """
 Generate list of urls to query for next depth. We then need to use Athena to make a fancy query.
 """
+
 import csv
 import re
 import subprocess

cc_pseudo_crawl/python_scripts/download_warc.py (+2, -2)

@@ -143,9 +143,9 @@ def get_warcs(batch):
        existing_compressed_warcs,
    )
 
-    batch["compressed_warc"], batch["download_exception"] = [
+    batch["compressed_warc"], batch["download_exception"] = (
        list(l) for l in zip(*warcs_or_exceptions)
-    ]
+    )
    return batch
 
 
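
Note (illustration, not part of the commit): switching the brackets to parentheses turns the two-element list comprehension into a generator expression, but the unpacking assignment still works because the left-hand side accepts any iterable that yields exactly two items. A toy equivalence check with made-up values:

warcs_or_exceptions = [(b"warc-1", None), (b"warc-2", "timeout")]

# Before: list comprehension.  After: generator expression.  Same result.
warcs_a, errors_a = [list(l) for l in zip(*warcs_or_exceptions)]
warcs_b, errors_b = (list(l) for l in zip(*warcs_or_exceptions))

assert warcs_a == warcs_b and errors_a == errors_b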

cc_pseudo_crawl/python_scripts/exact_deduplicates.py (+1)

@@ -1,4 +1,5 @@
 """Taken from Teven and Leandro"""
+
 import gzip
 import os
 import shutil

cc_pseudo_crawl/python_scripts/load_all_seed_ids.py (+1, -1)

@@ -21,7 +21,7 @@ def main():
 
    seed_ids = []
    for seed_path in args.seed_paths:
-        with open(seed_path, "r") as fi:
+        with open(seed_path) as fi:
            data = csv.reader(fi)
            # First line is all the headers that we remove.
            seed_ids += [row[0] for row_id, row in enumerate(data) if row_id > 0]

cc_pseudo_crawl/python_scripts/pseudo_crawl_seed_to_lm_dset_v2.py (+1, -1)

@@ -126,7 +126,7 @@ def process_batch(batch, skip_set):
 # looks at up to the first 10K pages for a seed and
 # records lines that appear in at least 1% of the unique pages
 def get_lines_to_skip(dset, n_records, pourcentage_threshold, min_repetition_threshold):
-    line_counts = defaultdict(lambda: 0)
+    line_counts = defaultdict(int)
     seen_pages = set()
 
     seed = SeedSequence(42)
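
Note (illustration, not part of the commit): defaultdict(int) is the idiomatic spelling of defaultdict(lambda: 0), since calling int() with no arguments returns 0; the default factory behaves the same without the extra lambda. A quick sketch:

from collections import defaultdict

line_counts = defaultdict(int)  # int() -> 0, same default as lambda: 0
line_counts["some line"] += 1
line_counts["some line"] += 1

assert line_counts["some line"] == 2
assert line_counts["never seen"] == 0  # missing keys start at zero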

cc_pseudo_crawl/python_scripts/shard_by_seed_id.py (+1)

@@ -1,6 +1,7 @@
 """
 Deduplicating using `datasets` is much harder, we but we forgot to generate an id when building an index, so we're screwed.
 """
+
 import logging
 import subprocess
 from argparse import ArgumentParser

kenlm_training/cc_net/execution.py (+1, -2)

@@ -19,8 +19,7 @@
 
 
 class Executor(Protocol):
-    def __call__(self, function: Callable[..., str], *args: Iterable) -> None:
-        ...
+    def __call__(self, function: Callable[..., str], *args: Iterable) -> None: ...
 
 
 class SubmititRetryOnTimeout(submitit.helpers.Checkpointable):
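
Note (illustration, not part of the commit): this hunk, like the flat_hash_set.py and jsonql.py hunks below, collapses ellipsis-only bodies onto the signature line, the style newer formatter releases use for Protocol stubs. Behaviour is unchanged; a minimal sketch with an invented protocol:

from typing import List, Protocol

class Sink(Protocol):
    # Stub bodies collapsed onto one line, as in the commit.
    def write(self, line: str) -> int: ...
    def close(self) -> None: ...

class ListSink:
    def __init__(self) -> None:
        self.lines: List[str] = []

    def write(self, line: str) -> int:
        self.lines.append(line)
        return len(line)

    def close(self) -> None:
        pass

def emit(sink: Sink) -> None:
    sink.write("hello")
    sink.close()

emit(ListSink())  # ListSink satisfies Sink structurally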

kenlm_training/cc_net/flat_hash_set.py (+6, -12)

@@ -29,23 +29,17 @@ def __repr__(self):
        implementation = type(self).__name__
        return f"[{implementation}, len: {len(self)}"
 
-    def __len__(self) -> int:
-        ...
+    def __len__(self) -> int: ...
 
-    def __contains__(self, values: Sequence[np.uint64]) -> np.ndarray:
-        ...
+    def __contains__(self, values: Sequence[np.uint64]) -> np.ndarray: ...
 
-    def __getitem__(self, values) -> np.ndarray:
-        ...
+    def __getitem__(self, values) -> np.ndarray: ...
 
-    def __setitem__(self, keys, values) -> None:
-        ...
+    def __setitem__(self, keys, values) -> None: ...
 
-    def items(self) -> Iterable[Tuple[np.uint64, np.uint8]]:
-        ...
+    def items(self) -> Iterable[Tuple[np.uint64, np.uint8]]: ...
 
-    def keys(self) -> Iterable[np.uint64]:
-        ...
+    def keys(self) -> Iterable[np.uint64]: ...
 
    def __iter__(self) -> Iterator[np.uint64]:
        return iter(self.keys())

kenlm_training/cc_net/jsonql.py (+7, -12)

@@ -880,8 +880,7 @@ def describe(source, columns=None, weights=None, **kwargs):
            continue
        if "." in k or k == ALL_DOCUMENTS:
            continue
-        for line in display_stats(stats, k, weights=weights, **kwargs):
-            yield line
+        yield from display_stats(stats, k, weights=weights, **kwargs)
 
 
 def shard(lines):
@@ -902,17 +901,13 @@ def get_or_set(dictionary, key, default):
 class SimpleIO(Protocol):
    """A subset of methods from TextIO."""
 
-    def close(self) -> None:
-        ...
+    def close(self) -> None: ...
 
-    def write(self, line: str) -> int:
-        ...
+    def write(self, line: str) -> int: ...
 
-    def __enter__(self) -> "SimpleIO":
-        ...
+    def __enter__(self) -> "SimpleIO": ...
 
-    def __exit__(self, exc_type, exc_value, traceback):
-        ...
+    def __exit__(self, exc_type, exc_value, traceback): ...
 
 
 def open_read(filename: ReadableFileLike) -> Iterable[str]:
@@ -961,7 +956,7 @@ def open_read(filename: ReadableFileLike) -> Iterable[str]:
    if filename.suffix == ".gz":
        file: TextIO = gzip.open(filename, "rt")  # type: ignore
    else:
-        file = open(filename, "rt")
+        file = open(filename)
 
    return _close_when_exhausted(file)
 
@@ -1015,7 +1010,7 @@ def open_write(
    if filename.suffix == ".gz":
        return BlockedGzipWriter(Path(filename), mode, block_size="64M")
 
-    return open(filename, "wt")
+    return open(filename, "w")
 
 
 def parse_size(size):
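
Note (illustration, not part of the commit): the describe() change replaces an explicit re-yield loop with yield from, which delegates to the sub-generator and is equivalent for plain iteration. A tiny sketch with invented generators:

def numbers():
    yield 1
    yield 2

def forward_with_loop():
    for value in numbers():  # before: re-yield each item
        yield value

def forward_with_delegation():
    yield from numbers()  # after: delegate to the sub-generator

assert list(forward_with_loop()) == list(forward_with_delegation()) == [1, 2]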

kenlm_training/tests/test_jsonql.py (+1, -1)

@@ -262,7 +262,7 @@ def do(self, x):
 def acc(values):
    print("acc: started")
    res = 0
-    for (x, _) in values:
+    for x, _ in values:
        res += int(x)
    print("acc: done")
    yield f"acc: result={res}"

perplexity_lenses/perplexity_lenses/data.py (+8, -6)

@@ -34,9 +34,11 @@ def hub_dataset_to_dataframe(
            {
                text_column: sentence,
                "perplexity": model.get_perplexity(sentence),
-                "label": x.get("labels", [])[0]
-                if len(x.get("labels", [])) > 0
-                else "NONE",  # Special case for registry dataset
+                "label": (
+                    x.get("labels", [])[0]
+                    if len(x.get("labels", [])) > 0
+                    else "NONE"
+                ),  # Special case for registry dataset
            }
            for sentence in x[text_column].split("\n")
        ]
@@ -46,9 +48,9 @@ def hub_dataset_to_dataframe(
        lambda x: {
            text_column: x[text_column],
            "perplexity": model.get_perplexity(x[text_column]),
-            "label": x.get("labels", [])[0]
-            if len(x.get("labels", [])) > 0
-            else "NONE",  # Special case for registry dataset
+            "label": (
+                x.get("labels", [])[0] if len(x.get("labels", [])) > 0 else "NONE"
+            ),  # Special case for registry dataset
        }
    )
    instances = []
