Validation

Basic Statistics

counts

Functions:

  • count_validation

    Count the number of 'unique' samples and average phrase length in the dataset.

count_validation

count_validation(
    dataset: Dataset, remove_stop_words: bool = False
) -> List[Dict[str, int | float]]

Count the number of 'unique' samples and average phrase length in the dataset.

Parameters:

  • dataset (Union[ContinualAlignmentDataset, AlignmentDataset]) –

    The dataset to validate.

  • remove_stop_words (bool, default: False ) –

    If True, applies stop word removal before computing dataset counts.

Returns:

  • List[Dict[str, int | float]]

    For every AlignmentDataset, returns a dictionary with the following entries:

    'samples' -> int: The total number of samples in the AlignmentDataset.

    'unique_samples' -> int: The number of unique samples in the AlignmentDataset.

    'unique_prompts' -> int: The number of unique prompts in the AlignmentDataset.

    'unique_chosen' -> int: The number of unique chosen responses in the AlignmentDataset.

    'unique_rejected' -> int: The number of unique rejected responses in the AlignmentDataset.

    'avg_prompt_length' -> float: The average length of prompts in the AlignmentDataset.

    'avg_chosen_length' -> float: The average length of chosen responses in the AlignmentDataset.

    'avg_rejected_length' -> float: The average length of rejected responses in the AlignmentDataset.

Note

If the input dataset is a single (non-continual) AlignmentDataset, this function returns a one-element list containing that dataset's statistics.
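
As a rough illustration of what these entries measure, the sketch below computes the same keys over plain (prompt, chosen, rejected) tuples. The private helper that does the real work is not shown on this page, so `count_stats`, the tuple representation, and the choice to measure length in whitespace-separated words are all assumptions made for illustration, not the library's implementation.

```python
from statistics import mean
from typing import Dict, List, Tuple, Union

# Hypothetical stand-in for a preference sample: (prompt, chosen, rejected).
Sample = Tuple[str, str, str]


def count_stats(samples: List[Sample]) -> Dict[str, Union[int, float]]:
    """Count duplicates and average lengths over a list of samples."""
    prompts = [s[0] for s in samples]
    chosen = [s[1] for s in samples]
    rejected = [s[2] for s in samples]

    def avg_len(texts: List[str]) -> float:
        # Assumption: "length" means the number of whitespace-separated words.
        if not texts:
            return 0.0
        return float(mean([len(t.split()) for t in texts]))

    return {
        'samples': len(samples),
        'unique_samples': len(set(samples)),
        'unique_prompts': len(set(prompts)),
        'unique_chosen': len(set(chosen)),
        'unique_rejected': len(set(rejected)),
        'avg_prompt_length': avg_len(prompts),
        'avg_chosen_length': avg_len(chosen),
        'avg_rejected_length': avg_len(rejected),
    }
```

A large gap between 'samples' and 'unique_samples' in such a report would point to duplicated preference pairs.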

Source code in aif_gen/validation/counts.py
def count_validation(
    dataset: Dataset, remove_stop_words: bool = False
) -> List[Dict[str, int | float]]:
    r"""Count the number of 'unique' samples and average phrase length in the dataset.

    Args:
        dataset (Union[ContinualAlignmentDataset, AlignmentDataset]): The dataset to validate.
        remove_stop_words (bool): If true, applies stop word removal before computing dataset counts.

    Returns:
        List[Dict[str, int | float]]: For every AlignmentDataset, returns a dictionary with the following entries:

        'samples'               -> int: The total number of samples in the AlignmentDataset.
        'unique_samples'        -> int: The number of unique samples in the AlignmentDataset.
        'unique_prompts'        -> int: The number of unique prompts in the AlignmentDataset.
        'unique_chosen'         -> int: The number of unique chosen responses in the AlignmentDataset.
        'unique_rejected'       -> int: The number of unique rejected responses in the AlignmentDataset.
        'avg_prompt_length'     -> float: The average length of prompts in the AlignmentDataset.
        'avg_chosen_length'     -> float: The average length of chosen responses in the AlignmentDataset.
        'avg_rejected_length'   -> float: The average length of rejected responses in the AlignmentDataset.

    Note:
        If the input dataset is an AlignmentDataset (non-continual), this function
        returns a 1 element list with the relevant statistics.
    """
    if isinstance(dataset, AlignmentDataset):
        datasets = [dataset]
    else:
        # This assert is here to make mypy happy
        assert isinstance(dataset, ContinualAlignmentDataset)
        datasets = dataset.datasets

    results = []
    for dataset in datasets:
        results.append(_count_validation(dataset, remove_stop_words))
    return results

entropy

Functions:

  • entropy_validation

    Report various entropy measures on tokens in the dataset samples.

entropy_validation

entropy_validation(
    dataset: Dataset, remove_stop_words: bool = False
) -> List[Dict[str, float]]

Report various entropy measures on tokens in the dataset samples.

Parameters:

  • dataset (Union[ContinualAlignmentDataset, AlignmentDataset]) –

    The dataset to validate.

  • remove_stop_words (bool, default: False ) –

    If True, applies stop word removal before computing the entropy measures.

Returns:

  • List[Dict[str, float]]

    For every AlignmentDataset, returns a dictionary with the following entries:

    'token_entropy' -> float: The entropy across tokens (prompts and responses combined) for all samples in the AlignmentDataset.

    'prompt_entropy' -> float: The entropy across prompts in samples of the AlignmentDataset.

    'chosen_entropy' -> float: The entropy across chosen responses in samples of the AlignmentDataset.

    'rejected_entropy' -> float: The entropy across rejected responses in samples of the AlignmentDataset.

Note

If the input dataset is a single (non-continual) AlignmentDataset, this function returns a one-element list containing that dataset's statistics.
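
The page does not say which entropy definition or tokenizer is used. The sketch below assumes Shannon entropy in bits over lower-cased, whitespace-separated tokens; `shannon_entropy` is a hypothetical name, and the library's helper may tokenize or normalize differently.

```python
import math
from collections import Counter
from typing import Iterable


def shannon_entropy(texts: Iterable[str]) -> float:
    """Shannon entropy (base 2) of the token distribution across all texts."""
    # Assumption: tokens are lower-cased, whitespace-separated words.
    tokens = [tok for text in texts for tok in text.lower().split()]
    if not tokens:
        return 0.0
    counts = Counter(tokens)
    total = len(tokens)
    # H = -sum(p * log2(p)) over the empirical token frequencies.
    return -sum((c / total) * math.log2(c / total) for c in counts.values())
```

Under this definition, a corpus that repeats a single token has zero entropy, and the value grows as the vocabulary becomes larger and more evenly used.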

Source code in aif_gen/validation/entropy.py
def entropy_validation(
    dataset: Dataset, remove_stop_words: bool = False
) -> List[Dict[str, float]]:
    r"""Report various entropy measures on tokens in the dataset samples.

    Args:
        dataset (Union[ContinualAlignmentDataset, AlignmentDataset]): The dataset to validate.
        remove_stop_words (bool): If true, applies stop word removal before computing the entropy measures.

    Returns:
        List[Dict[str, float]]: For every AlignmentDataset, returns a dictionary with the following entries:

        'token_entropy'     -> float: The entropy across tokens (prompts and responses combined) for all samples in the AlignmentDataset.
        'prompt_entropy'    -> float: The entropy across prompts in samples of the AlignmentDataset.
        'chosen_entropy'    -> float: The entropy across chosen responses in samples of the AlignmentDataset.
        'rejected_entropy'  -> float: The entropy across rejected responses in the samples of the AlignmentDataset.

    Note:
        If the input dataset is an AlignmentDataset (non-continual), this function
        returns a 1 element list with the relevant statistics.
    """
    if isinstance(dataset, AlignmentDataset):
        datasets = [dataset]
    else:
        # This assert is here to make mypy happy
        assert isinstance(dataset, ContinualAlignmentDataset)
        datasets = dataset.datasets

    results = []
    for dataset in datasets:
        results.append(_entropy_validation(dataset, remove_stop_words))
    return results

Diversity

embedding_diversity

Functions:

  • llm_embedding_diversity

    Use the cosine distance of embeddings from an embedding model as a proxy for dataset diversity.

llm_embedding_diversity async

llm_embedding_diversity(
    dataset: Dataset,
    model_name: str,
    client: AsyncOpenAI,
    batch_size: int,
    async_semaphore: Semaphore,
    dry_run: bool = False,
) -> Optional[List[Optional[Dict[str, float]]]]

Use the cosine distance of embeddings from an embedding model as a proxy for dataset diversity.

Parameters:

  • dataset (Union[ContinualAlignmentDataset, AlignmentDataset]) –

    The dataset to validate.

  • model_name (str) –

    The vLLM-compatible model alias to use for embedding texts.

  • client (AsyncOpenAI) –

    Handle to the OpenAI client.

  • batch_size (int) –

    Number of items to submit at a time.

  • async_semaphore (Semaphore) –

    Semaphore that manages number of concurrent API requests.

  • dry_run (bool, default: False ) –

    If True, validate a dummy sample to ensure the model is set up correctly.

Returns:

  • Optional[List[Optional[Dict[str, float]]]]

    For every AlignmentDataset, returns a dictionary with entries of the form '{metric}_{stat}':

      - Stat is one of ['mean', 'median', 'min', 'max'].
      - Metric denotes the embedding diversity for either ['prompt', 'chosen', 'rejected'] tokens.

Note
  • If a dataset is empty, its entry in the returned list is None instead of a dictionary.
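
For readers unfamiliar with the underlying measure, here is a minimal sketch of cosine distance between embedding vectors. Computing the distance over every pair of embeddings is an assumption; the page does not show how the library pairs or aggregates embeddings, and `cosine_distance` and `pairwise_distances` are hypothetical names.

```python
import math
from itertools import combinations
from typing import List, Sequence


def cosine_distance(u: Sequence[float], v: Sequence[float]) -> float:
    """Return 1 - cosine similarity for two non-zero vectors of equal length."""
    if len(u) != len(v):
        raise ValueError('vectors must have the same length')
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError('cosine distance is undefined for zero vectors')
    return 1.0 - dot / (norm_u * norm_v)


def pairwise_distances(embeddings: Sequence[Sequence[float]]) -> List[float]:
    """Cosine distance for every unordered pair of embeddings (assumed pairing)."""
    return [cosine_distance(u, v) for u, v in combinations(embeddings, 2)]
```

Identical directions give a distance of 0 and orthogonal directions give 1, so larger values over a set of texts suggest more varied content.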
Source code in aif_gen/validation/embedding_diversity.py
async def llm_embedding_diversity(
    dataset: Dataset,
    model_name: str,
    client: openai.AsyncOpenAI,
    batch_size: int,
    async_semaphore: asyncio.Semaphore,
    dry_run: bool = False,
) -> Optional[List[Optional[Dict[str, float]]]]:
    r"""Use the cosine distance of embeddings from an embedding model as a proxy for dataset diversity.

    Args:
        dataset (Union[ContinualAlignmentDataset, AlignmentDataset]): The dataset to validate.
        model_name (str): The vLLM-compatible model alias to use for embedding texts.
        client (openai.AsyncOpenAI): Handle to the OpenAI client.
        batch_size (int): Number of items to submit at a time.
        async_semaphore (asyncio.Semaphore): Semaphore that manages number of concurrent API requests.
        dry_run (bool): If True, validate a dummy sample to ensure the model is set up correctly.

    Returns:
        Optional[List[Optional[Dict[str, float]]]]: For every AlignmentDataset, returns a dictionary with entries of the form '{metric}_{stat}':
            - Stat is one of ['mean', 'median', 'min', 'max']
            - Metric denotes the embedding diversity for either ['prompt', 'chosen', 'rejected'] tokens.

    Note:
        - If the dataset is empty, we put None in place of the dictionary.
    """
    if dry_run:
        logging.info(f'Doing dry-run data validation on a single sample...')
        coro = _batch_embed(
            [''],
            client=client,
            model_name=model_name,
            async_semaphore=async_semaphore,
            extra_data='prompt',
        )
        try:
            _ = await coro
        except BaseException as e:
            logging.exception(f'Exception on dry-run, skipping validation: {e}')
            raise e
        logging.info('Dry run was a success.')
        return None

    if isinstance(dataset, AlignmentDataset):
        datasets = [dataset]
    else:
        # This assert is here to make mypy happy
        assert isinstance(dataset, ContinualAlignmentDataset)
        datasets = dataset.datasets

    futures = []
    for dataset_idx, dataset in enumerate(datasets):
        dataset_size = len(dataset)
        logging.info(f'Validating Dataset ({dataset_size} samples)')

        for batch in _batch_iterable(dataset.samples, batch_size=batch_size):
            texts = {
                'prompt': [sample.prompt for sample in batch],
                'chosen': [sample.chosen for sample in batch],
                'rejected': [sample.rejected for sample in batch],
            }
            for text_type, text in texts.items():
                coro = _batch_embed(
                    text,
                    client=client,
                    model_name=model_name,
                    async_semaphore=async_semaphore,
                    extra_data={'dataset_idx': dataset_idx, 'text_type': text_type},
                )
                futures.append(asyncio.create_task(coro))
    try:
        results: List[Dict[str, List]] = [
            defaultdict(list) for _ in range(len(datasets))
        ]
        for fut in tqdm.as_completed(futures, total=len(futures)):
            result = await fut
            if result is None:
                continue
            embeddings, extra_data = result
            dataset_idx, text_type = extra_data['dataset_idx'], extra_data['text_type']
            results[dataset_idx][text_type].extend(embeddings)

        aggregated_results: List[Optional[Dict[str, float]]] = []
        for i, dataset in enumerate(datasets):
            if not len(dataset):
                logging.warning('Skipping Embedding Diversity Eval for empty dataset')
                aggregated_results.append(None)
                continue

            for metric_name, metric_values in results[i].items():
                if len(metric_values) != len(dataset):
                    logging.warning(
                        f'Dataset {i} {metric_name} validation coverage: {len(metric_values)} / {len(dataset)}'
                    )
                if len(metric_values) == 0:
                    raise RuntimeError(f'No samples could be parsed in dataset {i}')

            aggregated_results.append(_compute_statistics(results[i]))
        return aggregated_results

    except BaseException as e:
        logging.exception(f'Exception occurred while validating dataset: {e}')
        for fut in futures:
            fut.cancel()
        await tqdm.gather(*futures)
        return None
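
The function above submits texts in batches and uses a semaphore to bound the number of requests in flight. The sketch below reproduces that pattern in isolation. `batch_iterable`, `fake_embed`, and `embed_all` are hypothetical stand-ins: `fake_embed` makes no API call and returns a made-up one-dimensional "embedding" (the text length), so only the concurrency structure is meaningful here.

```python
import asyncio
from typing import Iterator, List, Sequence, Tuple


def batch_iterable(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


async def fake_embed(
    texts: Sequence[str], semaphore: asyncio.Semaphore, tag: str
) -> Tuple[List[List[float]], str]:
    """Stand-in for an embedding request; the semaphore bounds concurrency."""
    async with semaphore:
        await asyncio.sleep(0)  # placeholder for network latency
        return [[float(len(t))] for t in texts], tag


async def embed_all(
    texts: Sequence[str], batch_size: int, max_concurrency: int
) -> List[List[float]]:
    """Embed all texts batch by batch with at most max_concurrency in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)
    tasks = [
        asyncio.create_task(fake_embed(batch, semaphore, 'prompt'))
        for batch in batch_iterable(texts, batch_size)
    ]
    embeddings: List[List[float]] = []
    # gather preserves task order, so embeddings line up with the input texts.
    for embedded, _tag in await asyncio.gather(*tasks):
        embeddings.extend(embedded)
    return embeddings
```

Calling `asyncio.run(embed_all(['a', 'bb', 'ccc'], batch_size=2, max_concurrency=1))` runs two batches one after the other; raising `max_concurrency` lets them overlap.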

LLM Judge

llm_judge

Functions:

  • llm_judge_validation

    Use an LLM to judge the quality of the dataset.

llm_judge_validation async

llm_judge_validation(
    dataset: Dataset,
    model_name: str,
    client: AsyncOpenAI,
    async_semaphore: Semaphore,
    max_tokens_judge_response: int = 32,
    dry_run: bool = False,
) -> Optional[List[Optional[Dict[str, float]]]]

Use an LLM to judge the quality of the dataset.

Parameters:

  • dataset (Union[ContinualAlignmentDataset, AlignmentDataset]) –

    The dataset to validate.

  • model_name (str) –

    The vLLM-compatible model alias to use for validating the data.

  • client (AsyncOpenAI) –

    Handle to the OpenAI client.

  • async_semaphore (Semaphore) –

    Semaphore that manages number of concurrent API requests.

  • max_tokens_judge_response (int, default: 32 ) –

    Configurable limit on the max_tokens for the generated judge response.

  • dry_run (bool, default: False ) –

    If True, validate a dummy sample to ensure the model is set up correctly.

Returns:

  • Optional[List[Optional[Dict[str, float]]]]

    For every AlignmentDataset, returns a dictionary with entries of the form '{metric}_{stat}':

      - Stat is one of ['mean', 'median', 'min', 'max'].
      - Metric is one of:
          'alignment' -> Whether the chosen response is more aligned with the prompt compared to the rejected response.
          'coherence_chosen' -> The coherence in the chosen response, as determined by the LLM.
          'coherence_rejected' -> The coherence in the rejected response, as determined by the LLM.

Note
  • If a dataset is empty, its entry in the returned list is None instead of a dictionary.
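
To make the '{metric}_{stat}' key layout concrete, the sketch below flattens per-metric score lists into such a dictionary. The aggregation helper used by the library is not shown on this page, so `summarize_scores` is a hypothetical name and the scores used with it are made-up values.

```python
from statistics import mean, median
from typing import Dict, List


def summarize_scores(scores_by_metric: Dict[str, List[float]]) -> Dict[str, float]:
    """Flatten per-metric score lists into '{metric}_{stat}' entries."""
    summary: Dict[str, float] = {}
    for metric, scores in scores_by_metric.items():
        if not scores:
            # Assumption: metrics without any scores are left out of the summary.
            continue
        summary[f'{metric}_mean'] = float(mean(scores))
        summary[f'{metric}_median'] = float(median(scores))
        summary[f'{metric}_min'] = float(min(scores))
        summary[f'{metric}_max'] = float(max(scores))
    return summary
```

With three metrics and four statistics, a complete summary would hold twelve entries, for example 'alignment_mean' or 'coherence_rejected_max'.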
Source code in aif_gen/validation/llm_judge.py
async def llm_judge_validation(
    dataset: Dataset,
    model_name: str,
    client: openai.AsyncOpenAI,
    async_semaphore: asyncio.Semaphore,
    max_tokens_judge_response: int = 32,
    dry_run: bool = False,
) -> Optional[List[Optional[Dict[str, float]]]]:
    r"""Use an LLM to judge the quality of the dataset.

    Args:
        dataset (Union[ContinualAlignmentDataset, AlignmentDataset]): The dataset to validate.
        model_name (str): The vLLM-compatible model alias to use for validating the data.
        client (openai.AsyncOpenAI): Handle to the OpenAI client.
        async_semaphore (asyncio.Semaphore): Semaphore that manages number of concurrent API requests.
        max_tokens_judge_response (int): Configurable limit on the max_tokens for the generated judge response.
        dry_run (bool): If True, validate a dummy sample to ensure the model is set up correctly.

    Returns:
        Optional[List[Optional[Dict[str, float]]]]: For every AlignmentDataset, returns a dictionary with entries of the form '{metric}_{stat}':
            - Stat is one of ['mean', 'median', 'min', 'max']
            - Metric is one of:
                'alignment'           -> Whether the chosen response is more aligned with the prompt compared to the rejected response.
                'coherence_chosen'    -> The coherence in the chosen response, as determined by the LLM.
                'coherence_rejected'  -> The coherence in the rejected response, as determined by the LLM.

    Note:
        - If the dataset is empty, we put None in place of the dictionary.
    """
    cache = await AsyncElasticsearchCache.maybe_from_env_var(
        f'CACHE_VALIDATION_{model_name}'
    )

    if dry_run:
        logging.info(f'Doing dry-run data validation on a single sample...')
        mock_sample = AlignmentDatasetSample('Mock', 'Mock', 'Mock')
        _prompt = _get_judge_prompt(
            mock_sample.prompt, mock_sample.chosen, mock_sample.rejected, 'foo'
        )
        coro = _get_score(
            _prompt,
            client,
            model_name,
            async_semaphore,
            max_tokens_judge_response,
            dataset_idx=-1,
            metric_name='',
            cache=cache,
        )
        try:
            _ = await coro
        except BaseException as e:
            logging.exception(f'Exception on dry-run, skipping validation: {e}')
            raise e
        finally:
            if cache is not None:
                await cache.close()

        logging.info('Dry run was a success.')
        return None

    if isinstance(dataset, AlignmentDataset):
        datasets = [dataset]
    else:
        # This assert is here to make mypy happy
        assert isinstance(dataset, ContinualAlignmentDataset)
        datasets = dataset.datasets

    futures = []
    for dataset_idx, dataset in enumerate(datasets):
        dataset_size = len(dataset)
        logging.info(f'Validating Dataset ({dataset_size} samples)')
        preference = dataset.task.preference
        for sample in dataset.samples:
            prompts = {
                'alignment': _get_judge_prompt(
                    sample.prompt, sample.chosen, sample.rejected, preference
                ),
                'coherence_chosen': _get_coherence_prompt(sample.chosen),
                'coherence_rejected': _get_coherence_prompt(sample.rejected),
            }
            for metric_name, prompt in prompts.items():
                coro = _get_score(
                    prompt,
                    client,
                    model_name,
                    async_semaphore,
                    max_tokens_judge_response,
                    dataset_idx=dataset_idx,
                    metric_name=metric_name,
                    cache=cache,
                )
                futures.append(asyncio.create_task(coro))
    try:
        results: List[Dict[str, List]] = [
            defaultdict(list) for _ in range(len(datasets))
        ]
        for fut in tqdm.as_completed(futures, total=len(futures)):
            result = await fut
            if result is None:
                continue
            score, dataset_idx, metric_name = result
            if score is not None:
                results[dataset_idx][metric_name].append(score)

        aggregated_results: List[Optional[Dict[str, float]]] = []
        for i, dataset in enumerate(datasets):
            if not len(dataset):
                logging.warning('Skipping LLM judge for empty dataset')
                aggregated_results.append(None)
                continue

            for metric_name, metric_values in results[i].items():
                if len(metric_values) != len(dataset):
                    logging.warning(
                        f'Dataset {i} {metric_name} validation coverage: {len(metric_values)} / {len(dataset)}'
                    )
                if len(metric_values) == 0:
                    raise RuntimeError(f'No samples could be parsed in dataset {i}')

            aggregated_results.append(_compute_statistics(results[i]))
        return aggregated_results

    except BaseException as e:
        logging.exception(f'Exception occurred while validating dataset: {e}')
        for fut in futures:
            fut.cancel()
        await tqdm.gather(*futures)
        return None

    finally:
        if cache is not None:
            await cache.close()
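
The collection loop above consumes results in completion order, so each result carries its dataset index and metric name, and scores that come back as None are dropped. A minimal sketch of that bookkeeping follows, using the standard library's `asyncio.as_completed` instead of the tqdm wrapper. `fake_score` and `collect_scores` are hypothetical stand-ins, and no LLM is called.

```python
import asyncio
from collections import defaultdict
from typing import DefaultDict, List, Optional, Sequence, Tuple

# (score, dataset_idx, metric_name); a None score models an unparseable reply.
Job = Tuple[Optional[float], int, str]


async def fake_score(score: Optional[float], dataset_idx: int, metric_name: str) -> Job:
    """Stand-in for a judge request that echoes back its tagged result."""
    await asyncio.sleep(0)  # placeholder for network latency
    return score, dataset_idx, metric_name


async def collect_scores(
    jobs: Sequence[Job], num_datasets: int
) -> List[DefaultDict[str, List[float]]]:
    """Group scores by dataset and metric as tasks finish, skipping None."""
    tasks = [asyncio.create_task(fake_score(*job)) for job in jobs]
    results: List[DefaultDict[str, List[float]]] = [
        defaultdict(list) for _ in range(num_datasets)
    ]
    for fut in asyncio.as_completed(tasks):
        score, dataset_idx, metric_name = await fut
        if score is not None:
            results[dataset_idx][metric_name].append(score)
    return results
```

Because results arrive in completion order, the lists are not guaranteed to follow the sample order. Comparing a list's length with the dataset size gives the coverage figure that the function above logs as a warning.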