Skip to content

4 random sampling #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions configs/data/reddit_sport_small.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
data_name: reddit_dataset_12
data_args:
n_rows: 15000
setting: multi-class
target_config: sport
balanced: True
3 changes: 3 additions & 0 deletions configs/model/wiki-tf-idf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
model_id: "tomaarsen/glove-wikipedia-tf-idf"
model_kwargs:
num_labels: 2
17 changes: 13 additions & 4 deletions scripts/dataset_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from datasets import load_dataset
from tqdm import tqdm

from arc_tigers.constants import DATA_DIR


def main(args):
dataset_name = args.dataset_name
max_rows = args.max_rows
target_subreddits = args.target_subreddits
# Ensure the output directory exists
os.makedirs("../data/", exist_ok=True)
os.makedirs("data/", exist_ok=True)

# Load the dataset in streaming mode to avoid downloading the entire dataset
unfiltered = load_dataset(dataset_name, streaming=True)["train"]
Expand All @@ -37,7 +39,11 @@ def main(args):
break

# Save the filtered data to a JSON file
output_dir = f"../data/{dataset_name.split('/')[-1]}/{max_rows}_rows/"
if target_subreddits == "data/top_subreddits.json":
dataset_name = dataset_name.split("/")[-1]
else:
dataset_name = target_subreddits.split("/")[-1].rstrip(".json")
output_dir = f"{DATA_DIR}/{dataset_name}/{max_rows}_rows/"
os.makedirs(output_dir, exist_ok=True)
save_pth = f"{output_dir}/filtered_rows.json"
json_data = json.dumps(data, indent=2)
Expand All @@ -52,12 +58,15 @@ def main(args):
"--max_rows", type=int, required=True, help="Maximum number of rows to process."
)
parser.add_argument(
"--dataset_name", type=str, required=True, help="Name of the dataset to load."
"--dataset_name",
type=str,
default="bit0/reddit_dataset_12",
help="Name of the dataset to load.",
)
parser.add_argument(
"--target_subreddits",
type=str,
default="../data/top_subreddits.json",
default="data/top_subreddits.json",
help="Optional: Path to a JSON file containing a list of target subreddits.",
)
args = parser.parse_args()
Expand Down
5 changes: 3 additions & 2 deletions scripts/dataset_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pandas as pd
from tqdm import tqdm

from arc_tigers.data.utils import ONE_VS_ALL_COMBINATIONS, clean_row, flag_row
from arc_tigers.data.reddit_data import ONE_VS_ALL_COMBINATIONS
from arc_tigers.data.utils import clean_row, flag_row


def main(args):
Expand Down Expand Up @@ -89,7 +90,7 @@ def main(args):
parser.add_argument(
"data_dir",
type=str,
default="../data/reddit_dataset_12/15000000_rows/filtered_rows.json",
default="data/reddit_dataset_12/15000000_rows/filtered_rows.json",
help="Path to the data used for generation",
)
parser.add_argument(
Expand Down
35 changes: 0 additions & 35 deletions scripts/eval.py

This file was deleted.

125 changes: 125 additions & 0 deletions scripts/eval_eda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import argparse
import json
import os
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from arc_tigers.data.reddit_data import get_reddit_data
from arc_tigers.utils import load_yaml


def main(args):
counter_dict = {
"all": Counter(),
"correct": {0: Counter(), 1: Counter()},
"incorrect": {0: Counter(), 1: Counter()},
}
experiment_dir = args.experiment_dir

eval_stats = os.path.join(experiment_dir, "stats_full.json")
with open(eval_stats) as f:
eval_data = json.load(f)

accuracy_vector = np.array(eval_data["accuracy"])
entropy_vector = np.array(eval_data["entropy"])
softmax_vector = np.array(eval_data["softmax"])
n_classes = softmax_vector.shape[1]

softmax_probs = np.max(softmax_vector, axis=1)
normalised_entropy = entropy_vector / np.log(n_classes)

correct_indices = np.concatenate(np.argwhere(accuracy_vector == 1))
incorrect_indices = np.concatenate(np.argwhere(accuracy_vector == 0))

correct_softmax = softmax_probs[correct_indices]
incorrect_softmax = softmax_probs[incorrect_indices]
correct_entropy = normalised_entropy[correct_indices]
incorrect_entropy = normalised_entropy[incorrect_indices]

plt.hist(correct_softmax, bins=100, alpha=0.5, label="correct")
plt.hist(incorrect_softmax, bins=100, alpha=0.5, label="incorrect")
plt.yscale("log")
plt.xlabel("Predicted Softmax")
plt.ylabel("Counts")
plt.legend(title="Prediction type")
plt.savefig(experiment_dir + "/softmax_histogram.pdf")
plt.clf()

plt.hist(correct_entropy, bins=100, alpha=0.5, label="correct")
plt.hist(incorrect_entropy, bins=100, alpha=0.5, label="incorrect")
plt.yscale("log")
plt.xlabel("Normalised Predicted Entropy")
plt.ylabel("Counts")
plt.legend(title="Prediction type")
plt.savefig(experiment_dir + "/entropy_histogram.pdf")
plt.clf()

exp_config = os.path.join(experiment_dir, "../experiment_config.json")
with open(exp_config) as f:
exp_config = json.load(f)
data_config = load_yaml(exp_config["data_config_pth"])
_, _, test_data, meta_data = get_reddit_data(
**data_config["data_args"], random_seed=exp_config["seed"], tokenizer=None
)
subreddit_label_map: dict[str, int] = meta_data["test_target_map"]
label_subreddit_map: dict[int, str] = {v: k for k, v in subreddit_label_map.items()}

for input in test_data["text"]:
counter_dict["all"].update(input.split())

correct_inputs = test_data[correct_indices]["text"]
correct_labels = test_data[correct_indices]["label"]

incorrect_inputs = test_data[incorrect_indices]["text"]
incorrect_labels = test_data[incorrect_indices]["label"]

for label, inp in tqdm(zip(incorrect_labels, incorrect_inputs, strict=True)):
counter_dict["incorrect"][label].update(inp.split())

for label, inp in tqdm(zip(correct_labels, correct_inputs, strict=True)):
counter_dict["correct"][label].update(inp.split())

# remove the top 50 frequent words from the all counter
most_common_words = counter_dict["all"].most_common(100)
for word, _ in most_common_words:
for counter in counter_dict["correct"].values():
if word in counter:
counter.pop(word)
for counter in counter_dict["incorrect"].values():
if word in counter:
counter.pop(word)

for class_label in range(n_classes):
fig, axes = plt.subplots(2, 1, figsize=(15, 5))
for ax_idx, acc in enumerate(["incorrect", "correct"]):
# print(f"Most common words in {acc} inputs:")
# print(label_subreddit_map[class_label])
# print(counter_dict[acc][class_label].most_common(50))
top_50 = counter_dict[acc][class_label].most_common(50)
words, counts = zip(*top_50, strict=True)
x_pos = np.arange(len(words))
axes[ax_idx].set_title(f"{acc} inputs")
axes[ax_idx].bar(x_pos, counts, align="center")
axes[ax_idx].set_xticks(x_pos, words, rotation=70)
axes[ax_idx].set_ylabel("Counts")
axes[ax_idx].set_xlabel("Words")
fig.tight_layout()
fig.savefig(
experiment_dir
+ f"/{label_subreddit_map[class_label].lstrip('r/')}_word_counts.pdf"
)
fig.clear()


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Evaluate an experiment and save outputs."
)
parser.add_argument(
"experiment_dir", type=str, help="Path to the experiment directory."
)
args = parser.parse_args()
main(args)
42 changes: 34 additions & 8 deletions scripts/random_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,40 @@
TrainingArguments,
)

from arc_tigers.eval.utils import compute_metrics
from arc_tigers.data.reddit_data import get_reddit_data
from arc_tigers.eval.utils import compute_metrics, get_stats
from arc_tigers.sample.random import RandomSampler
from arc_tigers.training.utils import get_reddit_data
from arc_tigers.utils import load_yaml


def imbalance_dataset(dataset, seed, class_balance):
def imbalance_dataset(dataset: Dataset, seed: int, class_balance: float) -> Dataset:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally I think changing the imbalance would be handled in the data scripts, and not the sampling script

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would we want different levels of imbalance when training? And would this just be in the binary case?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically I just think that as far as the sampling script is concerned the dataset logic shouldn't be much more than test_data = load_dataset(config) or similar. We may want the option of imbalance in the training splits too, but that's separate to anything to do with this script (apart from making sure the test data doesn't overlap with the training data).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made a change now to reflect this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In arc_tigers.data.get_reddit_data If balanced is False, it checks for a class_balance argument and uses that to imbalance the train and test splits.

"""
Imbalance the dataset based on the class_balance variable.

Args:
dataset: The dataset to imbalance.
seed: The random seed for sampling.
class_balance: The balance between the classes. A value of 1.0 means
balanced classes, while a value of 0.5 means class 1 is half the size of
class 0. A negative value means class 1 is larger than class 0.

Returns:
The imbalanced dataset.
"""
# Imbalance the dataset based on the class_balance variable
class_labels = np.array(dataset["label"])
class_0_indices = np.where(class_labels == 0)[0]
class_1_indices = np.where(class_labels == 1)[0]

# Calculate the number of samples for each class based on class_balance
n_class_0 = len(class_0_indices)
n_class_1 = int(n_class_0 * class_balance)
assert abs(class_balance) <= 1.0, "class balance must be between -1.0 and 1.0"
if class_balance < 0:
class_balance = -1.0 * class_balance
n_class_1 = len(class_1_indices)
n_class_0 = int(n_class_1 * class_balance)
else:
n_class_0 = len(class_0_indices)
n_class_1 = int(n_class_0 * class_balance)

# Randomly sample indices for each class
rng = np.random.default_rng(seed)
Expand Down Expand Up @@ -156,14 +175,17 @@ def main(
f"{str(class_balance).replace('.', '')}/"
)
else:
output_dir = f"{save_dir}/new_random_sampling_outputs/"
output_dir = f"{save_dir}/random_sampling_outputs/"
os.makedirs(output_dir, exist_ok=True)

# full dataset stats
metrics = evaluate(dataset, preds)
print(metrics)
stats = get_stats(preds, dataset["label"])

with open(f"{output_dir}/metrics_full.json", "w") as f:
json.dump(metrics, f)
json.dump(metrics, f, indent=2)
with open(f"{output_dir}/stats_full.json", "w") as f:
json.dump(stats, f, indent=2)

# iteratively sample dataset and compute metrics, repeated n_repeats times
for _ in tqdm(range(n_repeats)):
Expand Down Expand Up @@ -214,6 +236,10 @@ def main(
_, _, test_dataset, meta_data = get_reddit_data(
**data_config["data_args"], tokenizer=tokenizer
)
# Save meta_data to the save_dir
meta_data_path = os.path.join(args.save_dir, "data_stats.json")
with open(meta_data_path, "w") as meta_file:
json.dump(meta_data, meta_file, indent=2)

if args.class_balance != 1.0:
test_dataset = imbalance_dataset(
Expand Down
7 changes: 2 additions & 5 deletions scripts/train_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
TrainingArguments,
)

from arc_tigers.data.reddit_data import get_reddit_data
from arc_tigers.eval.utils import compute_metrics
from arc_tigers.training.utils import (
WeightedLossTrainer,
get_label_weights,
get_reddit_data,
)
from arc_tigers.training.utils import WeightedLossTrainer, get_label_weights
from arc_tigers.utils import get_configs, load_yaml

logger = logging.getLogger(__name__)
Expand Down
Loading