-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsolokha_judge.py
247 lines (211 loc) · 7.21 KB
/
solokha_judge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
#!/usr/bin/env python3
# solokha_judge.py - because who's better at judging than Vakula's mother, the witch Solokha?
# Character from Hohol's "Christmas Eve" (Ніч перед Різдвом)
"""
Script for parallel evaluation of translation quality using COMET models.
This module processes translation pairs from a jsonlines file, evaluates their
quality using specified COMET models, and outputs scores using multiple GPUs.
"""
import argparse
import json
import pathlib
import filelock
from multiprocessing import cpu_count
from typing import Dict, Generator, List, Set
import torch
from comet import download_model, load_from_checkpoint
from tqdm import tqdm
import smart_open
def read_processed_hashes(output_file: pathlib.Path) -> Set[str]:
    """Read already processed translation hashes from output file.

    The file is read while holding the ``<output_file>.lock`` file lock,
    which is the same lock path ``write_scores`` acquires before appending.
    Blank lines are skipped so that a stray empty line in the output file
    does not abort a resume with a JSON decoding error.

    Args:
        output_file: Path to the output jsonlines file.

    Returns:
        Set of already processed translation hashes; empty when the output
        file does not exist yet.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"hash"`` field.
    """
    processed_hashes: Set[str] = set()
    if not output_file.exists():
        # Nothing written yet: first run, nothing to resume from.
        return processed_hashes

    lock = filelock.FileLock(str(output_file) + ".lock")
    with lock, smart_open.open(output_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # json.loads("") raises, so tolerate empty lines explicitly.
                continue
            processed_hashes.add(json.loads(line)["hash"])
    return processed_hashes
def read_translations(
    input_file: pathlib.Path,
    batch_size: int,
    src_lang_field: str,
    tgt_lang_field: str,
    processed_hashes: Set[str] = None,
) -> Generator[List[Dict[str, str]], None, None]:
    """Read translation pairs from input file in batches.

    Every yielded record is a dict with the keys ``"hash"``, ``"src"`` and
    ``"mt"``. The hash of each yielded record is added to
    ``processed_hashes``, so a hash that occurs more than once in the input
    is only yielded the first time, and a set passed in by the caller is
    mutated in place.

    Args:
        input_file: Path to the input jsonlines file.
        batch_size: Number of translations to yield at once.
        src_lang_field: Field name for source text in the input file.
        tgt_lang_field: Field name for target text in the input file.
        processed_hashes: Set of hashes to skip (for resuming). When omitted,
            a fresh empty set is used.

    Yields:
        Batches of translation pairs; the last batch may be shorter than
        ``batch_size``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``"hash"`` or one of the text fields.
    """
    # The default used to be dereferenced with `.add()` below, which raised
    # AttributeError on the first record whenever the argument was omitted.
    if processed_hashes is None:
        processed_hashes = set()

    current_batch = []
    with smart_open.open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines instead of failing in json.loads.
                continue
            data = json.loads(line)
            record_hash = data["hash"]
            # Skip if already processed (or already seen earlier in this file)
            if record_hash in processed_hashes:
                continue
            pair = {
                "hash": record_hash,
                "src": data[src_lang_field],
                "mt": data[tgt_lang_field],
            }
            # Only mark the hash as seen once the pair was built successfully.
            processed_hashes.add(record_hash)
            current_batch.append(pair)
            if len(current_batch) >= batch_size:
                yield current_batch
                current_batch = []
    if current_batch:  # Yield remaining items
        yield current_batch
def evaluate_batch(
    model,
    model_name: str,
    batch: List[Dict[str, str]],
    gpus: int,
    eval_batch_size: int = 8,
) -> List[Dict[str, float]]:
    """Score one batch of translation pairs with a loaded COMET model.

    Args:
        model: Loaded COMET model instance; its ``predict`` result is indexed
            with ``"scores"``.
        model_name: Name of the COMET model. The part after the last ``/``
            is used as the prefix of the score key.
        batch: List of translation pairs to evaluate.
        gpus: Number of GPUs to use, forwarded to ``model.predict``.
        eval_batch_size: Batch size for model inference.

    Returns:
        One dictionary per input pair, holding the pair's ``"hash"`` and its
        score under ``"<short model name>_score"``.
    """
    with torch.no_grad():
        prediction = model.predict(
            batch,
            batch_size=eval_batch_size,
            gpus=gpus,
            num_workers=cpu_count(),
        )

    # e.g. "Unbabel/wmt22-cometkiwi-da" -> "wmt22-cometkiwi-da_score"
    score_key = model_name.rsplit("/", 1)[-1] + "_score"

    results = []
    for item, value in zip(batch, prediction["scores"]):
        results.append({"hash": item["hash"], score_key: float(value)})
    return results
def write_scores(output_file: pathlib.Path, scores: List[Dict[str, float]]):
    """Append score records to the output file, one JSON object per line.

    The append is performed while holding the ``<output_file>.lock`` file
    lock.

    Args:
        output_file: Path to the output file.
        scores: List of score dictionaries to write.
    """
    lock_path = str(output_file) + ".lock"
    with filelock.FileLock(lock_path):
        with smart_open.open(output_file, "a", encoding="utf-8") as sink:
            sink.writelines(
                json.dumps(record, ensure_ascii=False) + "\n" for record in scores
            )
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(
        description="Evaluate translation quality using COMET models in parallel"
    )
    parser.add_argument(
        "--input",
        type=pathlib.Path,
        required=True,
        help="Input jsonlines file with translations",
    )
    parser.add_argument(
        "--output",
        type=pathlib.Path,
        required=True,
        help="Output jsonlines file for scores",
    )
    parser.add_argument(
        "--model",
        choices=[
            "Unbabel/wmt23-cometkiwi-da-xxl",
            "Unbabel/wmt22-cometkiwi-da",
            "Unbabel/wmt23-cometkiwi-da-xl",
            "Unbabel/XCOMET-XXL",
        ],
        required=True,
        help="COMET model to use",
    )
    parser.add_argument(
        "--gpus",
        type=int,
        default=0,
        help="Number of GPU devices to use (0 for auto)",
    )
    parser.add_argument(
        "--read-batch-size",
        type=int,
        default=3200,
        help="Batch size for reading translations (default: 3200)",
    )
    parser.add_argument(
        "--eval-batch-size",
        type=int,
        default=8,
        help="Batch size for model evaluation (default: 8)",
    )
    parser.add_argument(
        "--src-field",
        default="en",
        help="Field name for source text in input file (default: en)",
    )
    parser.add_argument(
        "--tgt-field",
        default="uk",
        help="Field name for target text in input file (default: uk)",
    )
    parser.add_argument(
        "--precision",
        choices=["highest", "high", "medium"],
        default="highest",
        help="Float32 matmul precision (default: highest)",
    )
    return parser


def main():
    """Main function to orchestrate translation quality evaluation.

    Parses the command line, collects the hashes already present in the
    output file, loads the requested COMET model, then streams the input in
    batches, scoring each batch and appending the scores to the output file.

    Raises:
        Exception: Any error raised while evaluating or writing a batch is
            reported on stdout and then re-raised unchanged.
    """
    args = _build_parser().parse_args()

    # Set float32 matmul precision
    torch.set_float32_matmul_precision(args.precision)

    # Create output directory if needed
    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Read processed hashes for resuming
    processed_hashes = read_processed_hashes(args.output)
    print(f"Found {len(processed_hashes):,} already processed translations")

    # The CLI documents 0 as "auto", but 0 used to be forwarded verbatim to
    # model.predict().
    # NOTE(review): COMET's predict() appears to treat gpus=0 as CPU
    # inference - confirm against the installed COMET version.
    gpus = args.gpus or torch.cuda.device_count()
    if args.gpus == 0:
        print(f"Auto-detected {gpus} GPU device(s)")

    # Download and prepare model
    print(f"Downloading model {args.model}...")
    model_path = download_model(args.model)
    model = load_from_checkpoint(model_path)
    if gpus > 0:
        # Only move to CUDA when a GPU is requested or detected; an
        # unconditional .cuda() fails on hosts without CUDA.
        model = model.cuda()
    model.eval()

    # Process translations
    translation_iterator = read_translations(
        args.input,
        args.read_batch_size,
        args.src_field,
        args.tgt_field,
        processed_hashes,
    )
    with tqdm(desc="Evaluating translations") as pbar:
        try:
            for batch in translation_iterator:
                scores = evaluate_batch(
                    model=model,
                    model_name=args.model,
                    batch=batch,
                    eval_batch_size=args.eval_batch_size,
                    gpus=gpus,
                )
                # Write scores with locking
                write_scores(args.output, scores)
                pbar.update(len(scores))
        except Exception as e:
            print(f"Error during evaluation: {e}")
            # Bare raise keeps the original traceback intact.
            raise
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()