Skip to content

Generation

Generation

LLM Inference Engine Service

engine

Functions:

generate_continual_dataset async

generate_continual_dataset(
    data_config: Dict[str, Any],
    model_name: str,
    client: AsyncOpenAI,
    async_semaphore: Semaphore,
    max_tokens_prompt_response: int = 1024,
    max_tokens_chosen_rejected_response: int = 2048,
    dry_run: bool = False,
    include_preference_axes: bool = False,
    temperature: float = 1.0,
) -> Optional[ContinualAlignmentDataset]

Generate a ContinualAlignmentDataset from the given alignment task specifications and model.

Parameters:

  • data_config (Dict[str, Any]) –

    Configuration file storing tasks specifications and model info.

  • model_name (str) –

    The vLLM-compatible model alias to use for generating synthetic samples.

  • client (AsyncOpenAI) –

    Handle to the OpenAI client.

  • async_semaphore (Semaphore) –

    Semaphore that manages number of concurrent API requests.

  • max_tokens_prompt_response (int, default: 1024 ) –

    Configurable limit on the max_tokens for the generated prompt response.

  • max_tokens_chosen_rejected_response (int, default: 2048 ) –

    Configurable limit on the max_tokens for the generated chosen and rejected response.

  • dry_run (bool, default: False ) –

    If True, ignore the config and generate a dummy sample to ensure the model is set up correctly.

  • include_preference_axes (bool, default: False ) –

    If True, include the preference axes in the prompt for response mapper.

  • temperature (float, default: 1.0 ) –

    Temperature for the model.

Returns:

Source code in aif_gen/generate/engine.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
async def generate_continual_dataset(
    data_config: Dict[str, Any],
    model_name: str,
    client: openai.AsyncOpenAI,
    async_semaphore: asyncio.Semaphore,
    max_tokens_prompt_response: int = 1024,
    max_tokens_chosen_rejected_response: int = 2048,
    dry_run: bool = False,
    include_preference_axes: bool = False,
    temperature: float = 1.0,
) -> Optional[ContinualAlignmentDataset]:
    r"""Generate a ContinualAlignmentDataset from the task specifications in ``data_config``.

    One generation task is scheduled per requested sample. When
    ``include_preference_axes`` is set, every generated sample is additionally
    sent to an LLM judge and its chosen/rejected responses are swapped if the
    judge prefers the second response.

    Args:
        data_config (Dict[str, Any]): Configuration storing task specifications and model info.
            Must contain a ``'task_specs'`` list whose entries have ``'alignment_task'`` and
            ``'num_samples'`` keys.
        model_name (str): The vLLM-compatible model alias to use for generating synthetic samples.
        client (openai.AsyncOpenAI): Handle to the OpenAI client.
        async_semaphore (asyncio.Semaphore): Semaphore that manages number of concurrent API requests.
        max_tokens_prompt_response (int): Configurable limit on the max_tokens for the generated prompt response.
        max_tokens_chosen_rejected_response (int): Configurable limit on the max_tokens for the generated chosen and rejected response.
        dry_run (bool): If True, generate a single dummy sample from the first task spec (no cache)
            to ensure the model is set up correctly, then return ``None``.
        include_preference_axes (bool): If True, include the preference axes in the prompt for response mapper
            and run the judge pass afterwards.
        temperature (float): Temperature for the model.

    Returns:
        Optional[ContinualAlignmentDataset]: The synthetically generated dataset. ``None`` after a
        dry run, or when generation or judging failed (the failure is logged).

    Raises:
        BaseException: Whatever the dry-run sample raised is logged and re-raised unchanged.
    """
    prompt_mapper = PromptMapper()
    response_mapper = ResponseMapper()
    task_specs = data_config['task_specs']

    if dry_run:
        logging.info('Doing dry-run data generation on a single sample...')
        mock_task = AlignmentTask.from_dict(task_specs[0]['alignment_task'])
        coro = _generate_sample(
            mock_task,
            client,
            model_name,
            prompt_mapper,
            response_mapper,
            async_semaphore,
            max_tokens_prompt_response,
            max_tokens_chosen_rejected_response,
            dataset_idx=-1,
            prompt_idx=-1,
            cache=None,
            include_preference_axes=include_preference_axes,
            temperature=temperature,
        )
        try:
            await coro
        except BaseException as e:
            logging.exception(f'Exception on dry-run, skipping generation: {e}')
            raise
        logging.info('Dry run was a success.')
        return None

    cache = await AsyncElasticsearchCache.maybe_from_env_var(
        index_name=f'CACHE_DATA_GENERATION_{model_name}'
    )
    futures, tasks, dataset_sizes = [], [], []
    for dataset_idx, task_spec in enumerate(task_specs):
        task = AlignmentTask.from_dict(task_spec['alignment_task'])
        dataset_size = task_spec['num_samples']
        logging.info(f'Generating Dataset ({dataset_size} samples) {task}')

        tasks.append(task)
        dataset_sizes.append(dataset_size)
        for _sample_idx in range(dataset_size):
            coro = _generate_sample(
                task,
                client,
                model_name,
                prompt_mapper,
                response_mapper,
                async_semaphore,
                max_tokens_prompt_response,
                max_tokens_chosen_rejected_response,
                dataset_idx=dataset_idx,
                prompt_idx=_sample_idx,
                cache=cache,
                include_preference_axes=include_preference_axes,
                temperature=temperature,
            )
            futures.append(asyncio.create_task(coro))

    try:
        samples: List[List[AlignmentDatasetSample]] = [[] for _ in dataset_sizes]
        for fut in tqdm.as_completed(futures, total=len(futures)):
            result = await fut
            if result is not None:
                sample, dataset_idx = result
                samples[dataset_idx].append(sample)

        continual_dataset = ContinualAlignmentDataset(datasets=[])
        for i, (task, requested, generated) in enumerate(
            zip(tasks, dataset_sizes, samples)
        ):
            if len(generated) != requested:
                logging.warning(
                    f'Dataset {i} requested {requested} samples but LM generated {len(generated)}'
                )
            continual_dataset.append(AlignmentDataset(task, generated))

        if not include_preference_axes:
            return continual_dataset

        # Preference axes included: use the judge to pick chosen/rejected responses.
        from aif_gen.validation.llm_judge import (
            _get_judge_prompt,
            _get_score,
        )

        cache_judge = await AsyncElasticsearchCache.maybe_from_env_var(
            index_name=f'CACHE_DATA_GENERATION_JUDGE_{model_name}'
        )

        async def _score_sample(d_idx: int, s_idx: int, judge_prompt: Any) -> Any:
            # Carry the sample position alongside the judge result. Tasks
            # finish in arbitrary order, so the position is the only reliable
            # way to match a score back to the sample it was computed for.
            outcome = await _get_score(
                judge_prompt,
                client,
                model_name,
                async_semaphore,
                max_tokens_judge_response=64,
                dataset_idx=d_idx,
                metric_name='alignment_generation',
                cache=cache_judge,
            )
            return d_idx, s_idx, outcome

        datasets = continual_dataset.datasets
        judge_futures = []
        try:
            for dataset_idx, dataset in enumerate(datasets):
                logging.info(f'Judging dataset ({len(dataset)} samples) {dataset.task}')
                preference = dataset.task.preference
                for sample_idx, sample in enumerate(dataset.samples):
                    judge_prompt = _get_judge_prompt(
                        sample.prompt, sample.chosen, sample.rejected, preference
                    )
                    judge_futures.append(
                        asyncio.create_task(
                            _score_sample(dataset_idx, sample_idx, judge_prompt)
                        )
                    )

            # scores[dataset_idx][sample_idx] -> judge score for that sample.
            scores: List[Dict[int, float]] = [{} for _ in datasets]
            for fut in tqdm.as_completed(judge_futures, total=len(judge_futures)):
                dataset_idx, sample_idx, outcome = await fut
                if outcome is None:
                    continue
                score = outcome[0]  # outcome is (score, dataset_idx, metric_name)
                if score is not None:
                    scores[dataset_idx][sample_idx] = score

            for dataset_idx, dataset in enumerate(datasets):
                if not len(dataset):
                    logging.warning(f'Dataset {dataset_idx} empty, skipping judge.')
                    continue

                dataset_scores = scores[dataset_idx]
                for sample_idx, sample in enumerate(dataset.samples):
                    # guard against missing / malformed scores
                    if sample_idx not in dataset_scores:
                        logging.warning(
                            f'No judge score for sample {sample_idx} in dataset {dataset_idx}, skipping.'
                        )
                        continue
                    score = dataset_scores[sample_idx]
                    if score not in (0.0, 1.0):
                        logging.warning(
                            f'Bad judge score {score!r} for sample {sample_idx}, skipping.'
                        )
                        continue
                    if score == 0.0:  # swap if judge says response 2 is better
                        sample.chosen, sample.rejected = (
                            sample.rejected,
                            sample.chosen,
                        )
            logging.info('Judging preference completed.')

        except BaseException as e:
            logging.exception(f'Exception while judging preference: {e}')
            for fut in judge_futures:
                fut.cancel()
            await asyncio.gather(*judge_futures, return_exceptions=True)
            return None
        finally:
            # Only the judge cache is released here; the generation cache is
            # closed exactly once by the outer ``finally``.
            if cache_judge is not None:
                await cache_judge.close()
        return continual_dataset

    except BaseException as e:
        logging.exception(f'Exception occurred while generating dataset: {e}')
        for fut in futures:
            fut.cancel()
        # return_exceptions=True: awaiting the just-cancelled tasks must not
        # re-raise CancelledError out of this handler.
        await asyncio.gather(*futures, return_exceptions=True)
        return None

    finally:
        if cache is not None:
            await cache.close()

Prompt Mapper

PromptMapperBase

Bases: ABC

Methods:

  • generate_prompt

    Generate a prompt that, when given to a language model, produces a prompt for a given AlignmentTask.

generate_prompt abstractmethod

generate_prompt(task: AlignmentTask) -> str

Generate a prompt that, when given to a language model, produces a prompt for a given AlignmentTask.

Parameters:

  • task (AlignmentTask) –

    The alignment task containing the domain, objective, and preferences.

Returns:

  • str ( str ) –

    A structured prompt string for the LLM.

Source code in aif_gen/generate/mappers/base.py
 9
10
11
12
13
14
15
16
17
18
@abstractmethod
def generate_prompt(self, task: AlignmentTask) -> str:
    r"""Generate a prompt that, when given to a language model, produces a prompt for a given AlignmentTask.

    This declaration is abstract and has no implementation of its own;
    concrete prompt mappers are expected to override it.

    Args:
        task (AlignmentTask): The alignment task containing the domain, objective, and preferences.

    Returns:
        str: A structured prompt string for the LLM.
    """

PromptMapper

PromptMapper(
    max_seed_word_samples: int = 2,
    suffix_context: Optional[str] = None,
)

Bases: PromptMapperBase

Generate a prompt that, when given to a language model, produces a prompt for a given AlignmentTask.

Samples domain component seed words (without replacement) from the AlignmentTask to contextualize the prompt. The sampling is parameterized by the weight of each component of the domain.

Parameters:

  • max_seed_word_samples (int, default: 2 ) –

    Maximum number of seed words to sample across all domain components (default=2)

  • suffix_context (Optional[str]=None, default: None ) –

    Optional suffix text to add at the end of the generated prompt.

Attributes:

Source code in aif_gen/generate/mappers/prompt_mapper.py
22
23
24
25
26
27
28
def __init__(
    self, max_seed_word_samples: int = 2, suffix_context: Optional[str] = None
) -> None:
    """Store the seed-word sampling budget and the optional prompt suffix.

    Args:
        max_seed_word_samples (int): Maximum number of seed words to sample
            across all domain components. Must be strictly positive.
        suffix_context (Optional[str]): Optional suffix text to add at the
            end of the generated prompt.

    Raises:
        ValueError: If ``max_seed_word_samples`` is zero or negative.
    """
    # Zero is rejected as well as negatives, so the message says "non-positive".
    if max_seed_word_samples <= 0:
        raise ValueError(
            f'Got non-positive seed word samples: {max_seed_word_samples}'
        )
    self._max_seed_word_samples = max_seed_word_samples
    self._suffix_context = suffix_context

max_seed_word_samples property

max_seed_word_samples: int

Maximum number of seed words to sample across all domain components.

Response Mapper

ResponseMapperBase

Bases: ABC

Methods:

  • generate_prompt

    Generate a prompt that, when given to a language model, produces a (chosen, rejected) response pair for the task_prompt and AlignmentTask.

generate_prompt abstractmethod

generate_prompt(
    task: AlignmentTask, task_prompt: str
) -> str

Generate a prompt that, when given to a language model, produces a (chosen, rejected) response pair for the task_prompt and AlignmentTask.

Parameters:

  • task (AlignmentTask) –

    The alignment task containing the domain, objective, and preferences.

  • task_prompt (str) –

    The task prompt to generate responses for.

Returns:

  • str ( str ) –

    A structured prompt string for the LLM.

Source code in aif_gen/generate/mappers/base.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@abstractmethod
def generate_prompt(
    self,
    task: AlignmentTask,
    task_prompt: str,
) -> str:
    r"""Generate a prompt that, when given to a language model, produces a (chosen, rejected)
    response pair for the task_prompt and AlignmentTask.

    This declaration is abstract and has no implementation of its own;
    concrete response mappers are expected to override it.

    Args:
        task (AlignmentTask): The alignment task containing the domain, objective, and preferences.
        task_prompt (str): The task prompt to generate responses for.

    Returns:
        str: A structured prompt string for the LLM.
    """

ResponseMapper

ResponseMapper(suffix_context: Optional[str] = None)

Bases: ResponseMapperBase

Generate a prompt that, when given to a language model, produces a winning and losing response to the task_prompt.

Parameters:

  • suffix_context (Optional[str]=None, default: None ) –

    Optional suffix text to add at the end of the generated prompt.

Source code in aif_gen/generate/mappers/response_mapper.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def __init__(self, suffix_context: Optional[str] = None) -> None:
    """Set up the mapper with an optional prompt suffix and the built-in preference axes.

    Args:
        suffix_context (Optional[str]): Optional suffix text to add at the
            end of the generated prompt.
    """
    # Each axis is a pair of opposing style labels.
    # TODO could be added to the config - or finalized
    # NOTE(review): 'distance' looks like it may be meant as 'distant';
    # left untouched because the label is a runtime string - confirm.
    axes = (
        ('short', 'long'),
        ('formal', 'casual'),
        ('helpful', 'harmful'),
        ('expert', 'eli5'),
        ('direct', 'hinted'),
        ('authoritative', 'tentative'),
        ('friendly', 'distance'),
        ('optimistic', 'pessimistic'),
        ('serious', 'humorous'),
        ('respectful', 'disrespectful'),
        ('complex', 'simple'),
        ('neutral', 'biased'),
        ('detailed', 'abstract'),
    )
    self._suffix_context = suffix_context
    self._preference_axes = list(axes)

Caching

AsyncElasticsearchCache

AsyncElasticsearchCache(
    es: AsyncElasticsearch, index_name: str
)

Methods:

  • close

    Close Elasticsearch connection.

  • get

    Try reading response from cache.

  • maybe_from_env_var

    Initialize from env var. Returns None if any of the required env vars are missing.

  • set

    Set/Update cache.

Source code in aif_gen/generate/caching.py
12
13
14
15
16
17
18
19
def __init__(self, es: AsyncElasticsearch, index_name: str) -> None:
    """Bind the cache to an Elasticsearch client and index.

    Args:
        es (AsyncElasticsearch): Client used for all cache reads and writes.
        index_name (str): Name of the index that stores the cache entries.
    """
    self.es = es
    self.index_name = index_name
    logging.info(f'Elastic Index Name: {self.index_name}')

    # Any non-empty FORCE_CACHE_REFRESH value turns the flag on.
    force_refresh = os.environ.get('FORCE_CACHE_REFRESH')
    self.is_refresh_required = bool(force_refresh)
    if not self.is_refresh_required:
        return
    logging.warning('FORCE_CACHE_REFRESH is enabled. All queries will miss.')

close async

close() -> None

Close Elasticsearch connection.

Source code in aif_gen/generate/caching.py
82
83
84
async def close(self) -> None:
    """Close the underlying Elasticsearch client held in ``self.es``."""
    client = self.es
    await client.close()

get async

get(query: str, nonce: Optional[str] = None) -> str | None

Try reading response from cache.

Parameters:

  • query (str) –

    The query to fetch from cache.

  • nonce (str, default: None ) –

    An optional nonce to differentiate cache entries.

Returns:

  • str | None

    str | None: Cached result if available.

Source code in aif_gen/generate/caching.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
async def get(self, query: str, nonce: Optional[str] = None) -> str | None:
    """Try reading response from cache.

    Args:
        query (str): The query to fetch from cache.
        nonce (str, optional): An optional nonce to differentiate cache entries.

    Returns:
        str | None: Cached result if available.
    """
    if self.is_refresh_required:
        return None

    # Cache lookup
    query_hash = self._get_query_hash(query=query, nonce=nonce)
    try:
        response = await self.es.get(index=self.index_name, id=query_hash)
        if response.get('found'):
            logging.info(f'Cache hit: {query_hash}')
            return response['_source']['result']
    except Exception:
        logging.debug(f'Cache miss: {query_hash}')

    return None  # Cache miss or index doesn't exist

maybe_from_env_var async staticmethod

maybe_from_env_var(
    index_name: str,
) -> AsyncElasticsearchCache | None

Initialize from env var. Returns None if any of the required env vars are missing.

Source code in aif_gen/generate/caching.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
@staticmethod
async def maybe_from_env_var(index_name: str) -> AsyncElasticsearchCache | None:
    """Initialize from env var. Returns None if any of the required env vars are missing."""
    required_env_keys = ['ELASTIC_SEARCH_HOST', 'ELASTIC_SEARCH_API_KEY']
    if not all((_key in os.environ) for _key in required_env_keys):
        logging.warning(
            'All of these are required to enable ElasticsearchCache: '
            f'{required_env_keys}.'
            ' Not enabling ElasticsearchCache since some keys are not set.'
        )
        return None

    es = AsyncElasticsearch(
        os.environ['ELASTIC_SEARCH_HOST'],
        api_key=os.environ['ELASTIC_SEARCH_API_KEY'],
        request_timeout=None,
    )

    # Ensure the index exists at startup (parse the name if '/' is present)
    index_name = index_name.lower().replace('/', '_')
    if not await es.indices.exists(index=index_name):
        await es.indices.create(index=index_name)
    return AsyncElasticsearchCache(es=es, index_name=index_name)

set async

set(
    query: str, value: str, nonce: Optional[str] = None
) -> None

Set/Update cache.

Parameters:

  • query (str) –

    The query whose result is to be cached.

  • value (str) –

    The value to store in cache.

  • nonce (str, default: None ) –

    An optional nonce to differentiate cache entries.

Source code in aif_gen/generate/caching.py
70
71
72
73
74
75
76
77
78
79
80
async def set(self, query: str, value: str, nonce: Optional[str] = None) -> None:
    """Store ``value`` as the cached result for ``query``, replacing any previous entry.

    Args:
        query (str): The query whose result is to be cached.
        value (str): The value to store in cache.
        nonce (str, optional): An optional nonce to differentiate cache entries.
    """
    # The document id is the hash of (query, nonce), so the same pair
    # always writes to the same id.
    doc_id = self._get_query_hash(query=query, nonce=nonce)
    await self.es.index(
        index=self.index_name,
        id=doc_id,
        document={'query': query, 'result': value, 'nonce': nonce},
    )