4 random sampling #13

Merged: 22 commits, May 21, 2025
Commits
06e2ca5
shifted some functions around to tidy up repo
J-Dymond May 12, 2025
d322190
WIP EDA script
J-Dymond May 14, 2025
9290034
numpy typing
jack89roberts May 15, 2025
a096581
eda script which pulls out the most common words, entropies, and soft…
J-Dymond May 15, 2025
3f11963
Merge remote-tracking branch 'refs/remotes/origin/4-random-sampling' …
J-Dymond May 15, 2025
cc1ca09
wip synthetic model
jack89roberts May 15, 2025
0f1e1a7
added functionality of imbalancing classes the other way in random_sa…
J-Dymond May 16, 2025
023d330
dataset imbalancing now happens in get_reddit_data
J-Dymond May 16, 2025
2d01b48
Merge branch 'main' into 4-random-sampling
J-Dymond May 16, 2025
3f50daa
Add synthetic model functionality
jack89roberts May 16, 2025
86d6438
Delete fit_beta.ipynb
jack89roberts May 16, 2025
1e1080d
Updates to type-hinting, gitignore, added ipykernel to project
J-Dymond May 19, 2025
366d561
refactoring, moving functions to appropriate files
J-Dymond May 19, 2025
aa4e8cb
addressing some comments
J-Dymond May 19, 2025
5ed0bce
some refactoring changes
J-Dymond May 19, 2025
e622bc8
update gitignore
J-Dymond May 20, 2025
f20ab34
merge with 4-random-sampling PR
J-Dymond May 20, 2025
6ddc961
type_checker fix
J-Dymond May 20, 2025
7679692
changes to get_reddit_data for multi-class classification
J-Dymond May 20, 2025
18e0253
Merge pull request #14 from alan-turing-institute/12-synthetic-data-m…
klh5 May 21, 2025
3ff1c31
addressed comments
J-Dymond May 21, 2025
0caca7f
minor change to fix comment
J-Dymond May 21, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -160,3 +160,4 @@ Thumbs.db
# project files
data/*
outputs/
tmp/
6 changes: 6 additions & 0 deletions configs/data/reddit_sport_small.yaml
@@ -0,0 +1,6 @@
data_name: reddit_dataset_12
data_args:
n_rows: 15000
setting: multi-class
target_config: sport
balanced: True
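
This data config is consumed elsewhere in the PR (see scripts/analysis/eval_eda.py below) by loading the YAML and unpacking data_args into get_reddit_data. A minimal sketch of that pattern, mirroring the call in eval_eda.py; the random_seed value here is illustrative:

from arc_tigers.data.reddit_data import get_reddit_data
from arc_tigers.utils import load_yaml

# Mirrors the call in scripts/analysis/eval_eda.py; only the test split and the
# metadata are used there, and tokenizer=None is passed as in that script.
data_config = load_yaml("configs/data/reddit_sport_small.yaml")
_, _, test_data, meta_data = get_reddit_data(
    **data_config["data_args"], random_seed=42, tokenizer=None
)
subreddit_label_map = meta_data["test_target_map"]  # subreddit name -> integer label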
3 changes: 3 additions & 0 deletions configs/model/wiki-tf-idf.yaml
@@ -0,0 +1,3 @@
model_id: "tomaarsen/glove-wikipedia-tf-idf"
model_kwargs:
num_labels: 2
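
How this model config is consumed is not shown in this diff, but the same PR adds sentence-transformers to the project dependencies (see pyproject.toml below). A minimal, speculative sketch assuming the model_id can be loaded as a sentence-transformers embedding model; the example texts are made up, and num_labels would be handled by whatever classifier sits on top of the embeddings (not shown here):

from sentence_transformers import SentenceTransformer

# Load the embedding model named in the config above and embed a couple of texts.
model = SentenceTransformer("tomaarsen/glove-wikipedia-tf-idf")
embeddings = model.encode(["a post about football", "a post about climbing"])
print(embeddings.shape)  # (2, embedding_dim)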
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -37,6 +37,9 @@ dependencies = [
"tqdm>=4.67.1",
"transformers>=4.51.3",
"accelerate>=1.6.0",
"seaborn>=0.13.2",
"ipykernel>=6.29.5",
"sentence-transformers>=4.1.0",
]

[project.optional-dependencies]
125 changes: 125 additions & 0 deletions scripts/analysis/eval_eda.py
@@ -0,0 +1,125 @@
import argparse
import json
import os
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from arc_tigers.data.reddit_data import get_reddit_data
from arc_tigers.utils import load_yaml


def main(args):
counter_dict = {
"all": Counter(),
"correct": {0: Counter(), 1: Counter()},
"incorrect": {0: Counter(), 1: Counter()},
}
experiment_dir = args.experiment_dir

eval_stats = os.path.join(experiment_dir, "stats_full.json")
with open(eval_stats) as f:
eval_data = json.load(f)

accuracy_vector = np.array(eval_data["accuracy"])
entropy_vector = np.array(eval_data["entropy"])
softmax_vector = np.array(eval_data["softmax"])
n_classes = softmax_vector.shape[1]

softmax_probs = np.max(softmax_vector, axis=1)
normalised_entropy = entropy_vector / np.log(n_classes)

correct_indices = np.concatenate(np.argwhere(accuracy_vector == 1))
incorrect_indices = np.concatenate(np.argwhere(accuracy_vector == 0))

correct_softmax = softmax_probs[correct_indices]
incorrect_softmax = softmax_probs[incorrect_indices]
correct_entropy = normalised_entropy[correct_indices]
incorrect_entropy = normalised_entropy[incorrect_indices]

plt.hist(correct_softmax, bins=100, alpha=0.5, label="correct")
plt.hist(incorrect_softmax, bins=100, alpha=0.5, label="incorrect")
plt.yscale("log")
plt.xlabel("Predicted Softmax")
plt.ylabel("Counts")
plt.legend(title="Prediction type")
plt.savefig(experiment_dir + "/softmax_histogram.pdf")
plt.clf()

plt.hist(correct_entropy, bins=100, alpha=0.5, label="correct")
plt.hist(incorrect_entropy, bins=100, alpha=0.5, label="incorrect")
plt.yscale("log")
plt.xlabel("Normalised Predicted Entropy")
plt.ylabel("Counts")
plt.legend(title="Prediction type")
plt.savefig(experiment_dir + "/entropy_histogram.pdf")
plt.clf()

exp_config = os.path.join(experiment_dir, "../experiment_config.json")
with open(exp_config) as f:
exp_config = json.load(f)
data_config = load_yaml(exp_config["data_config_pth"])
_, _, test_data, meta_data = get_reddit_data(
**data_config["data_args"], random_seed=exp_config["seed"], tokenizer=None
)
subreddit_label_map: dict[str, int] = meta_data["test_target_map"]
label_subreddit_map: dict[int, str] = {v: k for k, v in subreddit_label_map.items()}

for input in test_data["text"]:
counter_dict["all"].update(input.split())

correct_inputs = test_data[correct_indices]["text"]
correct_labels = test_data[correct_indices]["label"]

incorrect_inputs = test_data[incorrect_indices]["text"]
incorrect_labels = test_data[incorrect_indices]["label"]

for label, inp in tqdm(zip(incorrect_labels, incorrect_inputs, strict=True)):
counter_dict["incorrect"][label].update(inp.split())

for label, inp in tqdm(zip(correct_labels, correct_inputs, strict=True)):
counter_dict["correct"][label].update(inp.split())

# remove the 100 most common words overall from the per-class counters
most_common_words = counter_dict["all"].most_common(100)
for word, _ in most_common_words:
for counter in counter_dict["correct"].values():
if word in counter:
counter.pop(word)
for counter in counter_dict["incorrect"].values():
if word in counter:
counter.pop(word)

for class_label in range(n_classes):
fig, axes = plt.subplots(2, 1, figsize=(15, 5))
for ax_idx, acc in enumerate(["incorrect", "correct"]):
# print(f"Most common words in {acc} inputs:")
# print(label_subreddit_map[class_label])
# print(counter_dict[acc][class_label].most_common(50))
top_50 = counter_dict[acc][class_label].most_common(50)
words, counts = zip(*top_50, strict=True)
x_pos = np.arange(len(words))
axes[ax_idx].set_title(f"{acc} inputs")
axes[ax_idx].bar(x_pos, counts, align="center")
axes[ax_idx].set_xticks(x_pos, words, rotation=70)
axes[ax_idx].set_ylabel("Counts")
axes[ax_idx].set_xlabel("Words")
fig.tight_layout()
fig.savefig(
experiment_dir
+ f"/{label_subreddit_map[class_label].lstrip('r/')}_word_counts.pdf"
)
fig.clear()


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Evaluate an experiment and save outputs."
)
parser.add_argument(
"experiment_dir", type=str, help="Path to the experiment directory."
)
args = parser.parse_args()
main(args)
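
The script takes a single positional argument, the experiment directory containing stats_full.json (with experiment_config.json in its parent directory). The histograms it produces use a normalised entropy: the predictive entropy is divided by ln(n_classes) so that values lie in [0, 1] regardless of the number of classes. A small standalone check of that property; this helper is illustrative and not part of the PR:

import numpy as np

def normalised_entropy(probs: np.ndarray) -> float:
    # Entropy of a probability vector divided by ln(n_classes): 1.0 for a uniform
    # prediction, ~0.0 for a (near) one-hot prediction.
    probs = np.clip(probs, 1e-12, 1.0)
    return float(-(probs * np.log(probs)).sum() / np.log(len(probs)))

print(normalised_entropy(np.array([0.25, 0.25, 0.25, 0.25])))  # 1.0
print(normalised_entropy(np.array([1.0, 0.0, 0.0, 0.0])))      # ~0.0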
@@ -5,13 +5,13 @@
from datasets import load_dataset
from tqdm import tqdm

from arc_tigers.constants import DATA_DIR


def main(args):
dataset_name = args.dataset_name
max_rows = args.max_rows
target_subreddits = args.target_subreddits
# Ensure the output directory exists
os.makedirs("../data/", exist_ok=True)

# Load the dataset in streaming mode to avoid downloading the entire dataset
unfiltered = load_dataset(dataset_name, streaming=True)["train"]
@@ -37,7 +37,11 @@ def main(args):
break

# Save the filtered data to a JSON file
output_dir = f"../data/{dataset_name.split('/')[-1]}/{max_rows}_rows/"
if target_subreddits == "data/top_subreddits.json":
dataset_name = dataset_name.split("/")[-1]
else:
dataset_name = target_subreddits.split("/")[-1].rstrip(".json")
output_dir = f"{DATA_DIR}/{dataset_name}/{max_rows}_rows/"
os.makedirs(output_dir, exist_ok=True)
save_pth = f"{output_dir}/filtered_rows.json"
json_data = json.dumps(data, indent=2)
@@ -52,12 +56,15 @@ def main(args):
"--max_rows", type=int, required=True, help="Maximum number of rows to process."
)
parser.add_argument(
"--dataset_name", type=str, required=True, help="Name of the dataset to load."
"--dataset_name",
type=str,
default="bit0/reddit_dataset_12",
help="Name of the dataset to load.",
)
parser.add_argument(
"--target_subreddits",
type=str,
default="../data/top_subreddits.json",
default="data/top_subreddits.json",
help="Optional: Path to a JSON file containing a list of target subreddits.",
)
args = parser.parse_args()
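
The change above moves the output location under DATA_DIR and names it after either the dataset or the subreddit-list file. A hypothetical helper restating that branch with made-up arguments; note that str.rstrip(".json") strips trailing characters from the set {., j, s, o, n} rather than removing the literal suffix:

from arc_tigers.constants import DATA_DIR

def output_dir(dataset_name: str, target_subreddits: str, max_rows: int) -> str:
    # Default subreddit list: name the directory after the dataset; otherwise name
    # it after the subreddit-list file.
    if target_subreddits == "data/top_subreddits.json":
        name = dataset_name.split("/")[-1]
    else:
        name = target_subreddits.split("/")[-1].rstrip(".json")
    return f"{DATA_DIR}/{name}/{max_rows}_rows/"

print(output_dir("bit0/reddit_dataset_12", "data/top_subreddits.json", 15000000))
# -> <DATA_DIR>/reddit_dataset_12/15000000_rows/
print(output_dir("bit0/reddit_dataset_12", "data/football.json", 15000000))
# -> <DATA_DIR>/football/15000000_rows/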
@@ -4,7 +4,8 @@
import pandas as pd
from tqdm import tqdm

from arc_tigers.data.utils import ONE_VS_ALL_COMBINATIONS, clean_row, flag_row
from arc_tigers.data.reddit_data import ONE_VS_ALL_COMBINATIONS
from arc_tigers.data.utils import clean_row, flag_row


def main(args):
@@ -42,9 +43,9 @@ def main(args):

# Imbalance the dataset
if args.r is not None:
imbalance_ratio = args.r
if isinstance(imbalance_ratio, int):
n_targets = imbalance_ratio
imbalance_ratio = float(args.r)
if imbalance_ratio > 1:
n_targets = int(imbalance_ratio)
else:
n_targets = int(len(non_targets[0]) * imbalance_ratio)

@@ -89,7 +90,7 @@ def main(args):
parser.add_argument(
"data_dir",
type=str,
default="../data/reddit_dataset_12/15000000_rows/filtered_rows.json",
default="data/reddit_dataset_12/15000000_rows/filtered_rows.json",
help="Path to the data used for generation",
)
parser.add_argument(
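
With the change above, --r accepts either an absolute count or a fraction: after casting to float, values greater than 1 give the number of target-class samples directly, while values of at most 1 are treated as a fraction of the non-target count. A short sketch of that branch in isolation; the helper name and the non-target count are illustrative:

def n_target_samples(r: float, n_non_targets: int) -> int:
    # r > 1: absolute number of target samples; r <= 1: fraction of non-target count.
    imbalance_ratio = float(r)
    if imbalance_ratio > 1:
        return int(imbalance_ratio)
    return int(n_non_targets * imbalance_ratio)

print(n_target_samples(500, 10_000))   # 500
print(n_target_samples(0.05, 10_000))  # 500 (5% of 10,000 non-target samples)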
File renamed without changes.
140 changes: 140 additions & 0 deletions scripts/experiments/random_sampling.py
@@ -0,0 +1,140 @@
import argparse
import json
import os

import numpy as np
import pandas as pd
from datasets import Dataset
from tqdm import tqdm

from arc_tigers.data.utils import sample_dataset_metrics
from arc_tigers.eval.reddit_eval import get_preds
from arc_tigers.eval.utils import evaluate, get_stats


def main(
save_dir: str,
n_repeats: int,
dataset: Dataset,
preds,
init_seed: int,
max_labels: int | None = None,
evaluate_steps: list[int] | None = None,
class_balance: float = 1.0,
):
"""
Iteratively sample a dataset and compute metrics for the labelled subset.

Args:
save_dir: Directory to save the metrics files.
n_repeats: Number of times to repeat the sampling.
dataset: The dataset to sample from
preds: The predictions for the dataset.
init_seed: The initial seed for random sampling (determines the seed used for
each repeat).
max_labels: The maximum number of labels to sample. If None, the whole dataset
will be sampled.
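evaluate_steps: The numbers of labelled samples at which to compute metrics.
class_balance: Class balance used when generating the dataset and predictions;
used here only to name the output directory.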
"""

rng = np.random.default_rng(init_seed)
if class_balance != 1.0:
output_dir = (
f"{save_dir}/imbalanced_random_sampling_outputs_"
f"{str(class_balance).replace('.', '')}/"
)
else:
output_dir = f"{save_dir}/random_sampling_outputs/"
os.makedirs(output_dir, exist_ok=True)

# full dataset stats
metrics = evaluate(dataset, preds)
stats = get_stats(preds, dataset["label"])

with open(f"{output_dir}/metrics_full.json", "w") as f:
json.dump(metrics, f, indent=2)
with open(f"{output_dir}/stats_full.json", "w") as f:
json.dump(stats, f, indent=2)

# iteratively sample dataset and compute metrics, repeated n_repeats times
for _ in tqdm(range(n_repeats)):
seed = rng.integers(1, 2**32 - 1) # Generate a random seed
metrics = sample_dataset_metrics(
dataset, preds, seed, max_labels=max_labels, evaluate_steps=evaluate_steps
)
pd.DataFrame(metrics).to_csv(f"{output_dir}/metrics_{seed}.csv", index=False)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Randomly sample a dataset and compute metrics on the labelled subset"
)
parser.add_argument(
"data_config",
help=(
"path to the data config yaml file, or 'synthetic' to generate a "
"synthetic dataset"
),
)
parser.add_argument(
"model_config",
help=(
"path to the model config yaml file, or 'beta_model' to use a "
"synthetic model. model_adv must be set if using beta_model."
),
)
parser.add_argument(
"save_dir",
type=str,
default=None,
help="Path to save the model and results",
)
parser.add_argument(
"--class_balance",
type=float,
default=1.0,
help="Balance between the classes",
)
parser.add_argument("--n_repeats", type=int, required=True)
parser.add_argument("--max_labels", type=int, required=True)
parser.add_argument("--seed", type=int, required=True)
parser.add_argument(
"--model_adv",
type=float,
default=3.0,
help=(
"Model advantage parameter used to parameterize the performance of the "
"synthetic Beta models"
),
)
parser.add_argument(
"--synthetic_samples",
type=int,
default=10000,
help="Number of samples to generate if using a synthetic dataset",
)

args = parser.parse_args()

synthetic_args = (
{"model_adv": args.model_adv, "synthetic_samples": args.synthetic_samples}
if args.data_config == "synthetic"
else None
)

preds, test_dataset = get_preds(
data_config_path=args.data_config,
save_dir=args.save_dir,
class_balance=args.class_balance,
seed=args.seed,
synthetic_args=synthetic_args,
)

main(
args.save_dir,
args.n_repeats,
test_dataset,
preds,
args.seed,
args.max_labels,
evaluate_steps=np.arange(10, args.max_labels, 200).tolist(),
class_balance=args.class_balance,
)
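
Each repeat writes its metric trajectory to metrics_<seed>.csv inside random_sampling_outputs/ (or imbalanced_random_sampling_outputs_<balance>/ when --class_balance is not 1.0), and with evaluate_steps=np.arange(10, max_labels, 200) metrics are recorded at 10, 210, 410, ... labels. A minimal sketch of collecting those per-seed files afterwards; the save_dir path is a placeholder, and the exact metric columns come from sample_dataset_metrics, which is not shown in this diff:

import glob

import pandas as pd

# Load every per-seed trajectory written by the script above.
runs = [
    pd.read_csv(path)
    for path in sorted(glob.glob("save_dir/random_sampling_outputs/metrics_*.csv"))
]
print(f"{len(runs)} repeats loaded")
print(runs[0].columns.tolist())  # inspect which metrics were recorded at each step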
@@ -12,12 +12,9 @@
TrainingArguments,
)

from arc_tigers.data.reddit_data import get_reddit_data
from arc_tigers.eval.utils import compute_metrics
from arc_tigers.training.utils import (
WeightedLossTrainer,
get_label_weights,
get_reddit_data,
)
from arc_tigers.training.utils import WeightedLossTrainer, get_label_weights
from arc_tigers.utils import get_configs, load_yaml

logger = logging.getLogger(__name__)