Skip to content

Commit 8643da7

Browse files
authored
Merge pull request #205 from IINemo/fix-warning
Fix warning
2 parents 1bc1ed0 + 8f1dcbc commit 8643da7

26 files changed

+71
-7
lines changed

examples/configs/polygraph_eval_aeslc.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_babiqa.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_coqa.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_gsm8k.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_mmlu.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_person_bio.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_triviaqa.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_wiki_bio.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_wmt14_deen.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_wmt14_fren.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_wmt19_deen.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

examples/configs/polygraph_eval_xsum.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

scripts/polygraph_eval

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import json
1111

1212
import logging
1313

14-
log = logging.getLogger()
14+
log = logging.getLogger('lm_polygraph')
1515

1616
from lm_polygraph.utils.manager import UEManager
1717
from lm_polygraph.utils.dataset import Dataset
@@ -27,7 +27,6 @@ from lm_polygraph.ue_metrics import *
2727

2828
hydra_config = Path(os.environ["HYDRA_CONFIG"])
2929

30-
3130
@hydra.main(
3231
version_base=None,
3332
config_path=str(hydra_config.parent),
@@ -95,13 +94,19 @@ def main(args):
9594
load_from_disk=args.load_from_disk,
9695
**cache_kwargs
9796
)
97+
log.info("Done with loading eval data.")
9898

99+
log.info("="*100)
100+
log.info("Initializing UE estimators...")
99101
estimators = []
100102
estimators += get_ue_methods(args, model)
101103
density_based_ue_methods = get_density_based_ue_methods(args, model.model_type)
102104
estimators += density_based_ue_methods
105+
log.info("Done loading UE estimators")
103106

104107
if any([not getattr(method, "is_fitted", False) for method in density_based_ue_methods]):
108+
log.info("="*100)
109+
log.info(f"Loading train dataset...")
105110
if (args.train_dataset is not None) and (
106111
args.train_dataset != args.dataset
107112
):
@@ -162,15 +167,14 @@ def main(args):
162167
background_train_dataset.subsample(
163168
args.subsample_background_train_dataset, seed=seed
164169
)
170+
log.info(f"Done loading train data.")
165171
else:
166172
train_dataset = None
167173
background_train_dataset = None
168174

169175
if args.subsample_eval_dataset != -1:
170176
dataset.subsample(args.subsample_eval_dataset, seed=seed)
171177

172-
log.info("Done with loading data.")
173-
174178
generation_metrics = get_generation_metrics(args)
175179

176180
ue_metrics = get_ue_metrics(args)
@@ -339,6 +343,9 @@ def get_ue_methods(args, model):
339343

340344

341345
def get_generation_metrics(args):
346+
log.info("="*100)
347+
log.info("Initializing generation metrics...")
348+
342349
generation_metrics = getattr(args, "generation_metrics", None)
343350
if not generation_metrics:
344351
result = [
@@ -372,6 +379,9 @@ def get_generation_metrics(args):
372379
raise ValueError("BartScoreSeqMetric does not support multiref")
373380
metric_class = globals()[metric_name]
374381
result.append(metric_class(*metric.get("args", [])))
382+
383+
log.info("Done with initializing generation metrics.")
384+
375385
return result
376386

377387

src/lm_polygraph/estimators/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
from abc import ABC, abstractmethod
44
from typing import List, Dict
5+
from lm_polygraph.utils.common import polygraph_module_init
56

67

78
class Estimator(ABC):
89
"""
910
Abstract estimator class, which estimates the uncertainty of a language model.
1011
"""
1112

13+
@polygraph_module_init
1214
def __init__(self, stats_dependencies: List[str], level: str):
1315
"""
1416
Parameters:

src/lm_polygraph/estimators/lexical_similarity.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66

77
from .estimator import Estimator
88

9+
from absl import logging as absl_logging
10+
11+
# Raise absl's logging threshold to WARNING to silence the log spam emitted by the rouge scorer
12+
absl_logging.set_verbosity(absl_logging.WARNING)
13+
914

1015
class LexicalSimilarity(Estimator):
1116
"""

src/lm_polygraph/generation_metrics/generation_metric.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import List, Dict
44
from abc import ABC, abstractmethod
5+
from lm_polygraph.utils.common import polygraph_module_init
56

67

78
class GenerationMetric(ABC):
@@ -11,6 +12,7 @@ class GenerationMetric(ABC):
1112
compared with different estimators' uncertainties in UEManager using ue_metrics.
1213
"""
1314

15+
@polygraph_module_init
1416
def __init__(self, stats_dependencies: List[str], level: str):
1517
"""
1618
Parameters:

src/lm_polygraph/generation_metrics/rouge.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
from typing import List, Dict
55
from .generation_metric import GenerationMetric
66

7+
from absl import logging as absl_logging
8+
9+
# Raise absl's logging threshold to WARNING to silence the log spam emitted by the rouge scorer
10+
absl_logging.set_verbosity(absl_logging.WARNING)
11+
712

813
class RougeMetric(GenerationMetric):
914
"""

src/lm_polygraph/stat_calculators/greedy_alternatives_nli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def _eval_nli_model(nli_queue: List[Tuple[str, str]], deberta: Deberta) -> List[
4242

4343

4444
class GreedyAlternativesNLICalculator(StatCalculator):
45+
4546
def __init__(self, nli_model):
4647
super().__init__(
4748
[

src/lm_polygraph/stat_calculators/model_score.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def _batch_tokens(tokens_list: List[List[int]], model: WhiteboxModel):
1919

2020

2121
class ModelScoreCalculator(StatCalculator):
22+
2223
def __init__(self, prompt: str = 'Paraphrase "{}": ', batch_size: int = 10):
2324
super().__init__(["model_rh"], ["greedy_tokens", "input_tokens"])
2425
self.batch_size = batch_size

src/lm_polygraph/stat_calculators/stat_calculator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import List, Dict
44
from abc import ABC, abstractmethod
55
from lm_polygraph.utils.model import Model
6+
from lm_polygraph.utils.common import polygraph_module_init
67

78

89
class StatCalculator(ABC):
@@ -20,6 +21,7 @@ class StatCalculator(ABC):
2021
Each new StatCalculator needs to be registered at lm_polygraph/stat_calculators/__init__.py to be seen by UEManager.
2122
"""
2223

24+
@polygraph_module_init
2325
def __init__(self, stats: List[str], stat_dependencies: List[str]):
2426
"""
2527
Parameters:

src/lm_polygraph/utils/common.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import functools
import logging

# Package-wide logger shared by the lm_polygraph modules.
log = logging.getLogger("lm_polygraph")


def polygraph_module_init(func):
    """
    Decorator that logs the construction of lm_polygraph components.

    When the decorated function is named ``__init__``, every call emits an
    INFO record ``"Initializing <ClassName>"`` on the ``lm_polygraph`` logger
    before delegating to the original function. ``<ClassName>`` is the runtime
    class of the first positional argument, so a subclass that reaches a
    decorated base-class ``__init__`` through ``super().__init__(...)`` is
    reported under the subclass name. Functions with any other name are
    called without logging.

    Parameters:
        func: the callable to wrap (normally a class ``__init__``).

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its
        result. The wrapper keeps ``func``'s name, docstring and signature
        metadata.
    """
    is_init = func.__name__ == "__init__"

    # functools.wraps keeps __name__/__doc__/__wrapped__ intact; without it
    # every decorated __init__ would show up as "wrapper" with no docstring
    # in help(), inspect.signature() and generated API documentation.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Guard on `args` so an unusual keyword-only invocation cannot fail
        # with IndexError before the wrapped function is even called.
        if is_init and args:
            log.info("Initializing %s", args[0].__class__.__name__)
        # Propagate the result so the decorator is transparent for
        # functions other than __init__ as well.
        return func(*args, **kwargs)

    return wrapper

src/lm_polygraph/utils/model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -520,9 +520,7 @@ def tokenize(self, texts: List[str]) -> Dict[str, torch.Tensor]:
520520
return_token_type_ids=False,
521521
)
522522
else:
523-
tokenized = self.tokenizer(
524-
texts, truncation=True, padding=True, return_tensors="pt"
525-
)
523+
tokenized = self.tokenizer(texts, padding=True, return_tensors="pt")
526524

527525
return tokenized
528526

src/lm_polygraph/utils/register_stat_calculators.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import os
2+
import logging
23

34
from lm_polygraph.stat_calculators import *
45
from lm_polygraph.utils.deberta import Deberta
56
from lm_polygraph.utils.openai_chat import OpenAIChat
67

78
from typing import Dict, List, Optional, Tuple
89

10+
log = logging.getLogger("lm_polygraph")
11+
912

1013
def register_stat_calculators(
1114
deberta_batch_size: int = 10, # TODO: rename to NLI model
@@ -20,7 +23,13 @@ def register_stat_calculators(
2023
stat_calculators: Dict[str, "StatCalculator"] = {}
2124
stat_dependencies: Dict[str, List[str]] = {}
2225

26+
log.info("=" * 100)
27+
log.info("Loading NLI model...")
2328
nli_model = Deberta(batch_size=deberta_batch_size, device=deberta_device)
29+
30+
log.info("=" * 100)
31+
log.info("Initializing stat calculators...")
32+
2433
openai_chat = OpenAIChat(cache_path=cache_path)
2534

2635
def _register(calculator_class: StatCalculator):
@@ -75,4 +84,6 @@ def _register(calculator_class: StatCalculator):
7584
_register(GreedyAlternativesFactPrefNLICalculator(nli_model=nli_model))
7685
_register(ClaimsExtractor(openai_chat=openai_chat))
7786

87+
log.info("Done intitializing stat calculators...")
88+
7889
return stat_calculators, stat_dependencies

test/configs/test_polygraph_eval.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

test/configs/test_polygraph_eval_ensemble.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: default
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

test/configs/test_polygraph_eval_seq_ue.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ hydra:
44

55
defaults:
66
- model: bloomz-560m
7+
- _self_
78

89
cache_path: ./workdir/output
910
save_path: '${hydra:run.dir}'

0 commit comments

Comments
 (0)