Skip to content

Command Line Interface

AIF-Gen is primarily intended to be used as a command-line tool:

foo@bar:~$ aif --help

          / _ | /  _/ __/ / ___/ __/ |/ /
         / __ |_/ // _/  / (_ / _//    /
        /_/ |_/___/_/    \___/___/_/|_/

A tool for generating synthetic continual RLHF datasets.

Usage: aif [OPTIONS] COMMAND [ARGS]...

Options:
  --log_file FILE  Optional log file to use.  [default: aif_gen.log]
  --help           Show this message and exit.

Commands:
  generate   Generate a new ContinualAlignmentDataset.
  merge      Merge a set of ContinualAlignmentDatasets.
  preview    Preview a ContinualAlignmentDataset.
  sample     Downsample a ContinualAlignmentDataset.
  transform  Transform a ContinualAlignmentDataset.
  validate   Validate a ContinualAlignmentDataset.

generate

generate(
    data_config_name: Path,
    model: str,
    output_file: Path,
    max_concurrency: int,
    max_tokens_prompt_response: int,
    max_tokens_chosen_rejected_response: int,
    random_seed: int,
    dry_run: bool,
    hf_repo_id: Optional[str],
    include_preference_axes: bool,
    temperature: float,
) -> None

Generate a new ContinualAlignmentDataset.

DATA_CONFIG_NAME: Path to the dataset configuration file to use for dataset generation.
MODEL: vLLM-compatible model to use for data generation.

Source code in aif_gen/cli/commands/generate.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
@click.command(context_settings={'show_default': True})
@click.argument(
    'data_config_name',
    type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path),
)
@click.argument(
    'model',
    type=click.STRING,
)
@click.option(
    '--output_file',
    type=click.Path(dir_okay=False, path_type=pathlib.Path),
    help='Path to write the generated dataset.',
    default=lambda: f'data/{get_run_id(name=click.get_current_context().params["data_config_name"].stem)}/data.json',
)
@click.option(
    '--max_concurrency',
    type=click.IntRange(min=1, max=256, clamp=True),
    help='Max number of concurrent inference requests to send to the vLLM model',
    default=128,
)
@click.option(
    '--max_tokens_prompt_response',
    type=click.IntRange(min=1, max=32768, clamp=True),
    help='Limit the max_tokens on the prompt response from the vLLM model.',
    default=1024,
)
@click.option(
    '--max_tokens_chosen_rejected_response',
    type=click.IntRange(min=1, max=65536, clamp=True),
    help='Limit the max_tokens on the chosen/rejected response pair from the vLLM model.',
    default=2048,
)
@click.option(
    '--random_seed',
    type=int,
    help='Random seed for data generation.',
    default=0,
)
@click.option(
    '-n',
    '--dry-run',
    is_flag=True,
    default=False,
    help='Ignore the input config and generate a dummy sample ensuring the model endpoint is setup.',
)
@click.option(
    '--hf-repo-id',
    type=click.STRING,
    default=None,
    help='If not None, push the generated dataset to a HuggingFace remote repository with the associated repo-id.',
)
@click.option(
    '--include-preference-axes',
    is_flag=True,
    default=False,
    help='Include preference axes in the generated dataset.',
)
@click.option(
    '--temperature',
    type=click.FloatRange(min=0.0, max=2.0, clamp=True),
    default=0.99,
    help='Temperature for sampling from the model.',
)
def generate(
    data_config_name: pathlib.Path,
    model: str,
    output_file: pathlib.Path,
    max_concurrency: int,
    max_tokens_prompt_response: int,
    max_tokens_chosen_rejected_response: int,
    random_seed: int,
    dry_run: bool,
    hf_repo_id: Optional[str],
    include_preference_axes: bool,
    temperature: float,
) -> None:
    r"""Generate a new ContinualAlignmentDataset.

    DATA_CONFIG_NAME: Path to the dataset configuration file to use for dataset generation.
    MODEL: vLLM-compatible model to use for data generation.
    """
    # The docstring above doubles as this command's --help text, so it is kept verbatim.
    # Flow: load the YAML config, (unless dry-run) record the effective config next to
    # the output, generate the dataset asynchronously, write it, optionally upload it.
    logging.info(f'Using data configuration file: {data_config_name}')
    logging.info(f'Using model: {model}')
    logging.info(f'Random seed: {random_seed}')
    seed_everything(random_seed)

    data_config = yaml.safe_load(data_config_name.read_text())
    logging.debug(f'Configuration: {data_config}')

    output_file.parent.mkdir(parents=True, exist_ok=True)

    if not dry_run:
        # Persist the config actually used (plus CLI-only settings) for reproducibility.
        config = copy.deepcopy(data_config)
        config['model'] = model
        config['max_concurrency'] = max_concurrency
        with open(output_file.parent / 'config.json', 'w') as f:
            json.dump(config, f)

    try:
        client = openai.AsyncOpenAI()
    except Exception as e:  # also covers openai.OpenAIError
        logging.exception(f'Could not create openAI client: {e}')
        return

    # Use an explicitly created event loop: asyncio.get_event_loop() is deprecated
    # when no loop is set (Python 3.12+) and raises RuntimeError on Python 3.14+.
    # The loop is registered before the semaphore is constructed so that, on older
    # Python versions where asyncio primitives bind to the current loop when created,
    # the semaphore and the coroutine use the same loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        async_semaphore = asyncio.Semaphore(max_concurrency)
        coro = generate_continual_dataset(
            data_config,
            model,
            client,
            async_semaphore,
            max_tokens_prompt_response,
            max_tokens_chosen_rejected_response,
            dry_run,
            include_preference_axes=include_preference_axes,
            temperature=temperature,
        )
        dataset = loop.run_until_complete(coro)
    finally:
        asyncio.set_event_loop(None)
        loop.close()

    if dataset is not None:
        logging.info(f'Writing {len(dataset)} samples to {output_file}')
        dataset.to_json(output_file)
        logging.info(f'Wrote {len(dataset)} samples to {output_file}')

        if hf_repo_id is not None:
            upload_to_hf(repo_id=hf_repo_id, local_path=output_file)

merge

merge() -> None

Merge a set of ContinualAlignmentDatasets.

Source code in aif_gen/cli/commands/merge.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
@click.command()
def merge() -> None:
    r"""Merge a set of ContinualAlignmentDatasets."""
    # Interactive loop: every dataset the user adds contributes its constituent
    # AlignmentDatasets to a buffer; on [m]erge the buffer becomes one
    # ContinualAlignmentDataset written to a user-chosen path. [q]uit discards it.
    merged: List[AlignmentDataset] = []
    total_samples = 0
    while True:
        click.echo(
            f'{len(merged)} Datasets Buffered ({total_samples} samples)\t\t[a]dd / [m]erge / [q]uit'
        )
        choice = click.getchar().lower()
        click.echo()
        if choice == 'q':
            return
        if choice == 'm':
            break
        if choice != 'a':
            # Unrecognised key: just show the menu again.
            continue

        path = click.prompt('> Path to dataset')
        try:
            dataset = ContinualAlignmentDataset.from_json(path)
        except Exception as e:
            # Keep the session alive on a bad path / malformed file.
            click.secho(f'Failed to read dataset from {path}: {e}', fg='red')
            continue
        click.secho(
            f'Read ContinualAlignmentDataset with {len(dataset.datasets)} constituents'
            f' and {dataset.num_samples} samples',
            fg='green',
        )
        merged.extend(dataset.datasets)
        total_samples += len(dataset)

    if not merged:
        click.secho('No datasets in buffer, skipping writedown', fg='red')
        return

    click.echo(f'Merging {len(merged)} datasets')
    dataset = ContinualAlignmentDataset(merged)
    path = click.prompt('> Path to save merged dataset')
    click.echo(f'Writing dataset to: {path}')
    dataset.to_json(path)
    click.secho(f'Wrote {len(dataset)} samples to: {path}', fg='green')

preview

preview(
    input_data_file: Path,
    shuffle: bool,
    hf_repo_id: Optional[str],
) -> None

Preview a ContinualAlignmentDataset.

INPUT_DATA_FILE: Path to the input dataset.

Source code in aif_gen/cli/commands/preview.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
@click.command(context_settings={'show_default': True})
@click.argument(
    'input_data_file',
    type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path),
)
@click.option(
    '--shuffle/--no-shuffle',
    is_flag=True,
    default=True,
    help='Shuffle the order in which samples are previewed.',
)
@click.option(
    '--hf-repo-id',
    type=click.STRING,
    default=None,
    help='If not None, pull the input dataset from a HuggingFace remote repository with the associated repo-id.',
)
def preview(
    input_data_file: pathlib.Path,
    shuffle: bool,
    hf_repo_id: Optional[str],
) -> None:
    r"""Preview a ContinualAlignmentDataset.

    INPUT_DATA_FILE: Path to the input dataset.
    """
    # Two-level interactive browser: the outer menu picks an AlignmentDataset (task);
    # the inner menu pages through the samples of that task's Train/Test splits.
    if hf_repo_id is not None:
        input_data_file = download_from_hf(hf_repo_id, input_data_file)

    logging.info(f'Reading dataset from: {input_data_file}')
    dataset = ContinualAlignmentDataset.from_json(input_data_file)
    logging.info(f'Read {len(dataset)} samples from: {input_data_file}')

    if len(dataset) == 0:
        logging.error('Dataset is empty!')
        return

    tasks = [
        {
            'AlignmentDataset Index': f'[{i + 1}]/[{len(dataset.datasets)}]',
            'Objective': data.task.objective,
            'Preference': data.task.preference,
        }
        for i, data in enumerate(dataset.datasets)
    ]
    split_names = ['Train', 'Test']

    dataset_idx = 0
    while True:
        pprint.pp(tasks[dataset_idx])
        if len(tasks) == 1:
            click.echo('\n> [y]es / [q]uit')
        elif dataset_idx == 0:
            click.echo('\n> [y]es / [n]ext / [q]uit')
        elif dataset_idx == len(tasks) - 1:
            click.echo('\n> [y]es / [p]revious / [q]uit')
        else:
            click.echo('\n> [y]es / [n]ext / [p]revious / [q]uit')

        c = click.getchar().lower()
        click.echo()
        if c == 'q':
            return
        elif c == 'n' and dataset_idx < len(tasks) - 1:
            dataset_idx += 1
        elif c == 'p' and dataset_idx > 0:
            dataset_idx -= 1
        elif c == 'y':
            splits = [
                dataset.datasets[dataset_idx].train,
                dataset.datasets[dataset_idx].test,
            ]
            if shuffle:
                random.shuffle(splits[0])
                random.shuffle(splits[1])

            if len(splits[0]) == 0 and len(splits[1]) == 0:
                click.echo('No data available!')
                # Go back to the task menu (previously this ended the whole preview,
                # making the remaining tasks unreachable).
                continue

            # Start on the first non-empty split so a task that only has Test data
            # can still be browsed; from here on split_idx always indexes a non-empty split.
            split_idx = 0 if len(splits[0]) > 0 else 1
            sample_idx = 0
            while True:
                current = splits[split_idx]
                sample_dict = {
                    f'{split_names[split_idx]} Index': f'{sample_idx + 1}/{len(current)}'
                }
                sample_dict.update(asdict(current[sample_idx]))

                pprint.pp(tasks[dataset_idx])
                pprint.pp(sample_dict)

                other_idx = 1 - split_idx
                other_split = split_names[other_idx]
                if len(current) == 1:
                    click.echo(f'\n> [s]witch to {other_split} / [b]ack / [q]uit')
                elif sample_idx == 0:
                    click.echo(
                        f'\n> [n]ext / [s]witch to {other_split} / [b]ack / [q]uit'
                    )
                elif sample_idx == len(current) - 1:
                    click.echo(
                        f'\n> [p]revious / [s]witch to {other_split} / [b]ack / [q]uit'
                    )
                else:
                    click.echo(
                        f'\n> [n]ext / [p]revious / [s]witch to {other_split} / [b]ack / [q]uit'
                    )

                c = click.getchar().lower()
                click.echo()
                if c == 'q':
                    return
                elif c == 'b':
                    break
                elif c == 's':
                    if len(splits[other_idx]) == 0:
                        # Stay on the current split instead of dropping out of the view.
                        click.echo(f'No data for {other_split} split!')
                    else:
                        split_idx = other_idx
                        sample_idx = 0
                elif c == 'n' and sample_idx < len(current) - 1:
                    sample_idx += 1
                elif c == 'p' and sample_idx > 0:
                    sample_idx -= 1

sample

sample(
    input_data_file: Path,
    keep_ratio_train: float,
    keep_ratio_test: float,
    keep_amount_train: Optional[int],
    keep_amount_test: Optional[int],
    hf_repo_id: Optional[str],
    hf_repo_id_out: Optional[str],
    output_file: Path,
    random_seed: int,
) -> None

Downsample a ContinualAlignmentDataset.

INPUT_DATA_FILE: Path to the input dataset.
KEEP_RATIO_TRAIN: Ratio of samples to keep in the train dataset.
KEEP_RATIO_TEST: Ratio of samples to keep in the test dataset.

Source code in aif_gen/cli/commands/sample.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
@click.command(context_settings={'show_default': True})
@click.argument(
    'input_data_file',
    type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path),
)
@click.argument(
    'keep_ratio_train',
    type=click.FloatRange(min=0.0, max=1.0, clamp=True),
)
@click.argument(
    'keep_ratio_test',
    type=click.FloatRange(min=0.0, max=1.0, clamp=True),
)
@click.option(
    '--keep_amount_train',
    type=click.IntRange(min=0),
    help='Number of train samples to keep in each task. If not None, overrides KEEP_RATIO_TRAIN.',
    default=None,
)
@click.option(
    '--keep_amount_test',
    type=click.IntRange(min=0),
    help='Number of test samples to keep in each task. If not None, overrides KEEP_RATIO_TEST.',
    default=None,
)
@click.option(
    '--hf-repo-id',
    type=click.STRING,
    default=None,
    help='If not None, pull the input dataset from a HuggingFace remote repository with the associated repo-id.',
)
@click.option(
    '--hf-repo-id-out',
    type=click.STRING,
    default=None,
    help='If not None, push the dataset to a HuggingFace remote repository with the associated repo-id.',
)
@click.option(
    '--output_file',
    type=click.Path(dir_okay=False, path_type=pathlib.Path),
    help='Path to write the generated dataset.',
    default=lambda: f'data/{get_run_id(name=click.get_current_context().params["input_data_file"].stem)}/data.json',
)
@click.option(
    '--random_seed',
    type=int,
    help='Random seed for test data selection.',
    default=0,
)
def sample(
    input_data_file: pathlib.Path,
    keep_ratio_train: float,
    keep_ratio_test: float,
    keep_amount_train: Optional[int],
    keep_amount_test: Optional[int],
    hf_repo_id: Optional[str],
    hf_repo_id_out: Optional[str],
    output_file: pathlib.Path,
    random_seed: int,
) -> None:
    r"""Downsample a ContinualAlignmentDataset.

    INPUT_DATA_FILE: Path to the input dataset.
    KEEP_RATIO_TRAIN: Ratio of samples to keep in the train dataset.
    KEEP_RATIO_TEST: Ratio of samples to keep in the test dataset.
    """
    # Each task is downsampled independently: a random subset of its train split and
    # of its test split is kept, and the task is rebuilt from the reduced samples.
    if hf_repo_id is not None:
        input_data_file = download_from_hf(hf_repo_id, input_data_file)

    logging.info(f'Reading dataset from: {input_data_file}')
    dataset = ContinualAlignmentDataset.from_json(input_data_file)
    logging.info(f'Read {dataset.num_samples} samples from: {input_data_file}')

    if len(dataset) == 0:
        logging.error('Dataset is empty!')
        return

    logging.info(f'Original dataset has {dataset.num_samples} samples.')
    logging.info(f'Original dataset has {dataset.num_datasets} tasks.')
    logging.info(f'Starting sampling with seed {random_seed} for each task.')
    seed_everything(random_seed)

    for i, data in tqdm(enumerate(dataset.datasets), desc='Processing datasets'):
        train, test = data.train, data.test

        # An explicit amount takes precedence over the corresponding ratio.
        if keep_amount_train is not None:
            train_size = keep_amount_train
        else:
            train_size = int(keep_ratio_train * len(train))
        if keep_amount_test is not None:
            test_size = keep_amount_test
        else:
            test_size = int(keep_ratio_test * len(test))

        # random.sample cannot draw more items than the population holds.
        train_size = min(train_size, len(train))
        test_size = min(test_size, len(test))

        new_train = random.sample(train, train_size)
        new_test = random.sample(test, test_size)
        samples = new_train + new_test

        # The kept train/test proportions generally differ from the original ones
        # (e.g. when the two keep ratios differ), so reusing data.train_frac would
        # shift samples across the train/test boundary. Recompute it from the
        # actual sizes.
        # NOTE(review): assumes AlignmentDataset splits `samples` at train_frac with
        # the train samples first, which is what the train+test concatenation implies.
        train_frac = len(new_train) / len(samples) if samples else data.train_frac
        dataset.datasets[i] = AlignmentDataset(
            task=data.task, samples=samples, train_frac=train_frac
        )

    # The default output path points at a fresh run directory; make sure it exists.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    logging.info(f'Writing dataset to: {output_file}')
    dataset.to_json(output_file)
    logging.info(f'Wrote {dataset.num_samples} samples to: {output_file}')

    if hf_repo_id_out is not None:
        upload_to_hf(hf_repo_id_out, output_file)
        logging.info(f'Uploaded dataset to HuggingFace repo: {hf_repo_id_out}')

transform

Functions:

  • preference_swap

    Swap the 'chosen' and 'rejected' responses for each sample in the dataset with probability p.

  • split

    Split a ContinualAlignmentDataset into train and test datasets.

  • transform

    Transform a ContinualAlignmentDataset.

preference_swap

preference_swap(
    input_data_file: Path,
    output_data_file: Path,
    p: float,
    hf_repo_id: Optional[str],
    hf_repo_id_out: Optional[str],
    random_seed: int,
) -> None

Swap the 'chosen' and 'rejected' responses for each sample in the dataset with probability p.

INPUT_DATA_FILE: Path to the input dataset.
OUTPUT_DATA_FILE: Path to the output (transformed) dataset.

Source code in aif_gen/cli/commands/transform.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@transform.command(context_settings={'show_default': True})
@click.argument(
    'input_data_file',
    type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path),
)
@click.argument(
    'output_data_file',
    type=click.Path(dir_okay=False, path_type=pathlib.Path),
)
@click.option(
    '--hf-repo-id',
    type=click.STRING,
    default=None,
    help='If not None, pull the input dataset from a HuggingFace remote repository with the associated repo-id.',
)
@click.option(
    '--hf-repo-id-out',
    type=click.STRING,
    default=None,
    help='If not None, push the dataset to a HuggingFace remote repository with the associated repo-id.',
)
@click.option(
    '--p',
    type=click.FloatRange(min=0, max=1),
    default=1,
    help="Probability with which to swap each 'chosen' and 'rejected' in the dataset",
)
@click.option(
    '--random_seed',
    type=int,
    help='Random seed for the preference swap.',
    default=0,
)
def preference_swap(
    input_data_file: pathlib.Path,
    output_data_file: pathlib.Path,
    p: float,
    hf_repo_id: Optional[str],
    hf_repo_id_out: Optional[str],
    random_seed: int,
) -> None:
    r"""Swap the 'chosen' and 'rejected' responses for each sample in the dataset with probability p.

    INPUT_DATA_FILE: Path to the input dataset.
    OUTPUT_DATA_FILE: Path to the output (transformed) dataset.
    """
    # Pipeline: (optionally) pull input from HF -> apply F.preference_swap_transform
    # with swap_probability=p -> write output -> (optionally) push output to HF.
    if hf_repo_id is not None:
        input_data_file = download_from_hf(hf_repo_id, input_data_file)

    logging.info(f'Reading dataset from: {input_data_file}')
    dataset = ContinualAlignmentDataset.from_json(input_data_file)
    logging.info(f'Read {len(dataset)} samples from: {input_data_file}')

    # Seed right before the transform so the swap decisions are reproducible.
    seed_everything(random_seed)

    logging.info(f'Applying preference swap transform with p={p}')
    transformed_dataset = F.preference_swap_transform(dataset, swap_probability=p)
    logging.info('Transformed dataset.')

    output_data_file.parent.mkdir(parents=True, exist_ok=True)
    logging.info(f'Writing dataset to: {output_data_file}')
    transformed_dataset.to_json(output_data_file)
    logging.info(f'Wrote {len(transformed_dataset)} samples to: {output_data_file}')

    if hf_repo_id_out is not None:
        upload_to_hf(hf_repo_id_out, output_data_file)
        logging.info(f'Uploaded dataset to HuggingFace repo: {hf_repo_id_out}')

split

split(
    input_data_file: Path,
    hf_repo_id: Optional[str],
    hf_repo_id_out: Optional[str],
    output_file: Path,
    test_sample_ratio: float,
    random_seed: int,
) -> None

Split a ContinualAlignmentDataset into train and test datasets.

INPUT_DATA_FILE: Path to the input dataset.

Source code in aif_gen/cli/commands/transform.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
@transform.command(context_settings={'show_default': True})
@click.argument(
    'input_data_file',
    type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path),
)
@click.option(
    '--hf-repo-id',
    type=click.STRING,
    default=None,
    help='If not None, pull the input dataset from a HuggingFace remote repository with the associated repo-id.',
)
@click.option(
    '--hf-repo-id-out',
    type=click.STRING,
    default=None,
    help='If not None, push the dataset to a HuggingFace remote repository with the associated repo-id.',
)
@click.option(
    '--output_file',
    type=click.Path(dir_okay=False, path_type=pathlib.Path),
    help='Path to write the generated dataset.',
    default=lambda: f'data/{get_run_id(name=click.get_current_context().params["input_data_file"].stem)}/data.json',
)
@click.option(
    '--test_sample_ratio',
    type=click.FloatRange(min=0.0, max=1.0, clamp=True),
    help='Ratio of samples to use for testing in each static task of the dataset.',
    default=0.15,
)
@click.option(
    '--random_seed',
    type=int,
    help='Random seed for test data selection.',
    default=0,
)
def split(
    input_data_file: pathlib.Path,
    hf_repo_id: Optional[str],
    hf_repo_id_out: Optional[str],
    output_file: pathlib.Path,
    test_sample_ratio: float,
    random_seed: int,
) -> None:
    r"""Split a ContinualAlignmentDataset into train and test datasets.

    INPUT_DATA_FILE: Path to the input dataset.
    """
    # Pipeline: (optionally) pull input from HF -> F.split_transform with
    # test_ratio=test_sample_ratio -> write output -> (optionally) push to HF.
    if hf_repo_id is not None:
        input_data_file = download_from_hf(hf_repo_id, input_data_file)

    logging.info(f'Reading dataset from: {input_data_file}')
    dataset = ContinualAlignmentDataset.from_json(input_data_file)
    logging.info(f'Read {dataset.num_samples} samples from: {input_data_file}')

    if len(dataset) == 0:
        logging.error('Dataset is empty!')
        return

    # Seed right before the transform so the test-sample selection is reproducible.
    seed_everything(random_seed)
    logging.info(f'Splitting dataset with test_sample_ratio={test_sample_ratio}')
    transformed_dataset = F.split_transform(dataset, test_ratio=test_sample_ratio)

    # The default output path points at a fresh run directory; make sure it exists.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    logging.info(f'Writing dataset to: {output_file}')
    transformed_dataset.to_json(output_file)
    logging.info(f'Wrote {transformed_dataset.num_samples} samples to: {output_file}')

    if hf_repo_id_out is not None:
        upload_to_hf(hf_repo_id_out, output_file)
        logging.info(f'Uploaded dataset to HuggingFace repo: {hf_repo_id_out}')

transform

transform() -> None

Transform a ContinualAlignmentDataset.

Source code in aif_gen/cli/commands/transform.py
16
17
18
@click.group()
def transform() -> None:
    r"""Transform a ContinualAlignmentDataset."""
    # Click command group with no logic of its own: the individual transforms
    # (e.g. `preference_swap` and `split`) attach to it via `@transform.command`.
    # The docstring above is the group's --help text, so it is kept verbatim.

validate

validate(
    input_data_file: Path,
    output_validation_file: Path,
    validate_count: bool,
    validate_entropy: bool,
    validate_llm_judge: bool,
    validate_embedding_diversity: bool,
    model: str,
    embedding_model: str,
    embedding_batch_size: int,
    max_concurrency: int,
    max_tokens_judge_response: int,
    dry_run: bool,
    hf_repo_id: Optional[str],
    random_seed: int,
) -> None

Validate a ContinualAlignmentDataset.

INPUT_DATA_FILE: Path to the input dataset.
OUTPUT_VALIDATION_FILE: Path to the output validation file.

Source code in aif_gen/cli/commands/validate.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
@click.command(context_settings={'show_default': True})
@click.argument(
    'input_data_file',
    type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path),
)
@click.argument(
    'output_validation_file',
    type=click.Path(dir_okay=False, path_type=pathlib.Path),
)
@click.option(
    '--validate-count/--no-validate-count',
    is_flag=True,
    default=True,
    help='Perform basic count validation on the dataset.',
)
@click.option(
    '--validate-entropy/--no-validate-entropy',
    is_flag=True,
    default=True,
    help='Perform entropy validation on the dataset.',
)
@click.option(
    '--validate-llm-judge/--no-validate-llm-judge',
    is_flag=True,
    default=False,
    help='Perform llm judge validation on the dataset.',
)
@click.option(
    '--validate-embedding-diversity/--no-validate-embedding-diversity',
    is_flag=True,
    default=False,
    help='Perform embedding similarity/diversity validation on the dataset.',
)
@click.option(
    '--model',
    type=click.STRING,
    help='vLLM model to use as a judge if doing llm_judge validation',
)
@click.option(
    '--embedding-model',
    type=click.STRING,
    help='vLLM embedding model for computing embedding similarity',
)
@click.option(
    '--embedding-batch-size',
    type=click.IntRange(min=1),
    default=256,
    help='Number of items to embed in each request.',
)
@click.option(
    '--max_concurrency',
    type=click.IntRange(min=1, max=256, clamp=True),
    help='Max number of concurrent inference requests to send to the vLLM model',
    default=128,
)
@click.option(
    '--max_tokens_judge_response',
    type=click.IntRange(min=1, max=1024, clamp=True),
    help='Limit the max_tokens on the judge response from the vLLM model if doing llm_judge validation.',
    default=128,
)
@click.option(
    '-n',
    '--dry-run',
    is_flag=True,
    default=False,
    help='Ignore the dataset and validate a dummy sample to ensure vLLM setup.',
)
@click.option(
    '--hf-repo-id',
    type=click.STRING,
    default=None,
    help='If not None, pull the input dataset from and push the validation results to a HuggingFace remote repository with the associated repo-id.',
)
@click.option(
    '--random_seed',
    type=int,
    help='Random seed for validation.',
    default=0,
)
def validate(
    input_data_file: pathlib.Path,
    output_validation_file: pathlib.Path,
    validate_count: bool,
    validate_entropy: bool,
    validate_llm_judge: bool,
    validate_embedding_diversity: bool,
    model: str,
    embedding_model: str,
    embedding_batch_size: int,
    max_concurrency: int,
    max_tokens_judge_response: int,
    dry_run: bool,
    hf_repo_id: Optional[str],
    random_seed: int,
) -> None:
    r"""Validate a ContinualAlignmentDataset.

    INPUT_DATA_FILE: Path to the input dataset.
    OUTPUT_VALIDATION_FILE: Path to the output validation file.
    """
    # Runs each enabled validator and writes all results as one JSON object keyed
    # by validator name. The LLM-backed validators are best-effort: on failure the
    # exception is logged and their entry is recorded as None.

    # Fail fast (before any validator runs) when a model-backed validator has no model.
    if validate_llm_judge and not model:
        raise click.UsageError('--model is required when --validate-llm-judge is set.')
    if validate_embedding_diversity and not embedding_model:
        raise click.UsageError(
            '--embedding-model is required when --validate-embedding-diversity is set.'
        )

    logging.info(f'Random seed: {random_seed}')
    seed_everything(random_seed)

    if hf_repo_id is not None:
        input_data_file = download_from_hf(hf_repo_id, input_data_file)

    logging.info(f'Reading dataset from: {input_data_file}')
    dataset = ContinualAlignmentDataset.from_json(input_data_file)
    logging.info(f'Read {len(dataset)} samples from: {input_data_file}')

    results: Dict[str, Any] = {}
    if validate_count:
        logging.info('Performing count validation')
        results['count_validation'] = count_validation(dataset)
        logging.info('Finished count validation')

    if validate_entropy:
        logging.info('Performing entropy validation')
        results['entropy_validation'] = entropy_validation(dataset)
        logging.info('Finished entropy validation')

    # The async validators below each run on their own explicitly created event
    # loop: asyncio.get_event_loop() is deprecated when no loop is set (Python 3.12+)
    # and raises RuntimeError on Python 3.14+. The loop is registered before the
    # semaphore is built so older Pythons bind the semaphore to the same loop.
    if validate_llm_judge:
        logging.info(f'Performing LLM judge validation with model: {model}')

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            client = openai.AsyncOpenAI()
            async_semaphore = asyncio.Semaphore(max_concurrency)
            coro = llm_judge_validation(
                dataset,
                model,
                client,
                async_semaphore,
                max_tokens_judge_response,
                dry_run,
            )
            result = loop.run_until_complete(coro)
        except Exception as e:  # also covers openai.OpenAIError
            logging.exception(f'Error occurred trying to validate data with vLLM: {e}')
            result = None
        finally:
            asyncio.set_event_loop(None)
            loop.close()

        results['llm_judge_validation'] = result
        logging.info('Finished LLM judge validation')

    if validate_embedding_diversity:
        logging.info(
            f'Performing embedding diversity validation with model: {embedding_model}'
        )
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            client = openai.AsyncOpenAI()
            async_semaphore = asyncio.Semaphore(max_concurrency)
            coro = llm_embedding_diversity(
                dataset=dataset,
                model_name=embedding_model,
                client=client,
                batch_size=embedding_batch_size,
                async_semaphore=async_semaphore,
                dry_run=dry_run,
            )
            result = loop.run_until_complete(coro)
        except Exception as e:  # also covers openai.OpenAIError
            logging.exception(f'Error occurred trying to embed data with vLLM: {e}')
            result = None
        finally:
            asyncio.set_event_loop(None)
            loop.close()

        results['llm_embedding_diversity'] = result
        logging.info('Finished embedding similarity')

    if not results:
        logging.warning('No validation measure was specified, skipping writedown.')
        return

    logging.info(f'Writing validation results to: {output_validation_file}')
    # Path.open does not create missing parent directories.
    output_validation_file.parent.mkdir(parents=True, exist_ok=True)
    with output_validation_file.open('w', encoding='utf-8') as f:
        json.dump(results, f)

    if hf_repo_id is not None:
        upload_to_hf(hf_repo_id, output_validation_file)