diff --git a/dpsynth/text/bulk_inference.py b/dpsynth/text/bulk_inference.py index 13a436e..f167ff9 100644 --- a/dpsynth/text/bulk_inference.py +++ b/dpsynth/text/bulk_inference.py @@ -17,7 +17,9 @@ from collections.abc import Sequence import dataclasses import enum +import functools import re +import time from typing import Protocol from absl import logging @@ -63,10 +65,11 @@ def annotate( Args: texts: Input texts to annotate. - schema: Pydantic model class defining the features to extract. The model's - field names, ``Literal`` type annotations, and field descriptions guide - the LLM. This same class is used as the ``response_schema`` for - constrained decoding in supported backends. + schema: Pydantic model class defining the features to extract. Fields may + use ``Literal`` type annotations for constrained decoding (the model is + forced to pick from the allowed values) or plain types such as ``str`` + for open-ended annotation where the model can produce any value. Field + names and descriptions guide the LLM. system_prompt: System-level instructions for the LLM describing how to annotate the texts. @@ -90,7 +93,7 @@ def generate(self, prompts: Sequence[str]) -> list[str]: ... -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class GenAIBackend: """TextGenerationBackend using the google.genai API. @@ -100,19 +103,26 @@ class GenAIBackend: Attributes: model: Model name string (e.g., ``'gemini-2.5-flash-lite'``). Accepts any ``ModelName`` enum value or arbitrary string for unlisted models. - api_key: API key for authentication. If None, uses Application Default - Credentials (ADC). + api_key: API key for authentication. + poll_interval_seconds: How often to poll for batch job completion. + chunk_size: Number of texts per batch job. + max_concurrent_jobs: Maximum number of active parallel batch jobs. 
""" model: str = ModelName.GEMINI_2_5_FLASH_LITE api_key: str | None = None + poll_interval_seconds: int = 60 + chunk_size: int = 100 + max_concurrent_jobs: int = 2 - def _make_client(self) -> genai.Client: - """Creates a genai client.""" - kwargs: dict[str, object] = { - 'http_options': types.HttpOptions(api_version='v1alpha'), - } - if self.api_key is not None: + # NOTE: client is cached on first access. Do not mutate attributes + # (model, api_key) after the client has + # been created — the cached instance will not reflect the changes. + @functools.cached_property + def client(self) -> genai.Client: + """Creates and caches a genai.Client.""" + kwargs = {'http_options': types.HttpOptions(api_version='v1alpha')} + if self.api_key: kwargs['api_key'] = self.api_key return genai.Client(**kwargs) @@ -122,31 +132,40 @@ def annotate( schema: type[pydantic.BaseModel], system_prompt: str, ) -> pd.DataFrame: - """Extract structured features via constrained decoding. + """Extract structured features via google.genai API (sequential). + + Always passes the ``schema`` as the ``response_schema`` to + ``generate_content``. When the schema contains ``Literal`` fields the + model is constrained to the allowed values; schemas with plain types + (e.g. ``str``) still benefit from the structural guidance but allow the + model to produce any value. Args: texts: Input texts to annotate. schema: Pydantic model used as the ``response_schema`` for constrained - decoding. + decoding when it contains ``Literal`` fields. Schemas with plain types + (e.g. ``str``) trigger free-form JSON generation guided by the system + prompt and field descriptions. system_prompt: System-level instructions for the LLM. Returns: DataFrame with exactly ``len(texts)`` rows. Failed rows have ``None``. 
""" - client = self._make_client() + client = self.client field_names = list(schema.model_fields.keys()) null_row = {f: None for f in field_names} rows: list[dict[str, str | None]] = [] + config = types.GenerateContentConfig( + system_instruction=system_prompt, + response_mime_type='application/json', + response_schema=schema, + ) for i, text in enumerate(texts): try: response = client.models.generate_content( model=self.model, contents=text, - config=types.GenerateContentConfig( - system_instruction=system_prompt, - response_mime_type='application/json', - response_schema=schema, - ), + config=config, ) if response.text: cleaned = _strip_markdown_fences(response.text) @@ -155,11 +174,195 @@ def annotate( else: logging.warning('Empty annotation response for text %d.', i) rows.append(null_row) - except Exception: # pylint: disable=broad-except - logging.warning('Annotation failed for text %d.', i) + except Exception as e: # pylint: disable=broad-except + logging.warning( + 'Annotation failed for text %d. Error: %s', i, e, exc_info=True + ) rows.append(null_row) return pd.DataFrame(rows) + def batch_annotate( + self, + texts: Sequence[str], + schema: type[pydantic.BaseModel], + system_prompt: str, + chunk_size: int | None = None, + max_concurrent_jobs: int | None = None, + ) -> pd.DataFrame: + """Extract structured features via the GenAI Batch API. + + Submits texts as inlined requests to the batch prediction endpoint, + polls for completion, and parses the inlined responses. + + Args: + texts: Input texts to annotate. + schema: Pydantic model used as the ``response_schema``. + system_prompt: System-level instructions for the LLM. + chunk_size: Number of texts per batch job. + max_concurrent_jobs: Maximum number of active parallel batch jobs. + + Returns: + DataFrame with exactly ``len(texts)`` rows. Failed rows have ``None``. + + Raises: + RuntimeError: If the batch job fails or is cancelled. 
+ """ + client = self.client + field_names = list(schema.model_fields.keys()) + null_row = {f: None for f in field_names} + + if chunk_size is None: + chunk_size = self.chunk_size + if max_concurrent_jobs is None: + max_concurrent_jobs = self.max_concurrent_jobs + + if chunk_size <= 0: + raise ValueError('chunk_size must be positive.') + if max_concurrent_jobs <= 0: + raise ValueError('max_concurrent_jobs must be positive.') + + jobs = [] + + config = types.GenerateContentConfig( + system_instruction=system_prompt, + response_mime_type='application/json', + response_schema=schema, + ) + + offsets = list(range(0, len(texts), chunk_size)) + num_chunks = len(offsets) + active_jobs = [] + chunk_idx = 0 + + logging.info( + 'Batch annotate: starting processing of %d chunks with concurrency' + ' limit %d...', + num_chunks, + max_concurrent_jobs, + ) + + while chunk_idx < num_chunks or active_jobs: + # Submit new jobs up to the concurrency limit + while len(active_jobs) < max_concurrent_jobs and chunk_idx < num_chunks: + offset = offsets[chunk_idx] + chunk_texts = texts[offset : offset + chunk_size] + logging.info( + 'Batch annotate: submitting inline chunk %d/%d (size=%d)...', + chunk_idx + 1, + num_chunks, + len(chunk_texts), + ) + inlined_requests = [ + types.InlinedRequest(contents=text, config=config) + for text in chunk_texts + ] + batch_job = client.batches.create( + model=self.model, + src=inlined_requests, + ) + logging.info( + 'Batch annotate: job %s created for chunk %d/%d', + batch_job.name, + chunk_idx + 1, + num_chunks, + ) + job_info = { + 'chunk_idx': chunk_idx, + 'chunk_texts': chunk_texts, + 'job_name': batch_job.name, + 'job': batch_job, + } + jobs.append(job_info) + active_jobs.append(job_info) + chunk_idx += 1 + + # Poll active jobs if there are any + if active_jobs: + logging.info( + 'Batch annotate: %d active jobs. 
Polling in %ds...', + len(active_jobs), + self.poll_interval_seconds, + ) + time.sleep(self.poll_interval_seconds) + + for j in active_jobs: + try: + j['job'] = client.batches.get(name=j['job_name']) + except Exception as e: # pylint: disable=broad-except + logging.warning('Failed to poll job %s: %s', j['job_name'], e) + + # Filter out finished jobs + still_active = [] + for j in active_jobs: + if j['job'].done: + logging.info( + 'Batch annotate: job %s completed with state=%s', + j['job_name'], + j['job'].state, + ) + else: + still_active.append(j) + active_jobs = still_active + + logging.info('Batch annotate: all jobs completed. Parsing responses...') + + # Parse responses in job-submission order to keep rows aligned with texts. + all_rows = [] + for j in jobs: + batch_job = j['job'] + chunk_texts = j['chunk_texts'] + job_name = j['job_name'] + + if batch_job.state != types.JobState.JOB_STATE_SUCCEEDED: + error_msg = f'Batch job {job_name} ended with state={batch_job.state}.' + if batch_job.error: + error_msg += f' Error: {batch_job.error}' + raise RuntimeError(error_msg) + + inlined_responses = ( + batch_job.dest.inlined_responses if batch_job.dest else [] + ) or [] + + chunk_rows = [] + for i, inlined_resp in enumerate(inlined_responses): + try: + if inlined_resp.error: + logging.warning( + 'Batch result %d in job %s had error: %s', + i, + job_name, + inlined_resp.error, + ) + chunk_rows.append(null_row) + continue + + response = inlined_resp.response + if response and response.text: + cleaned = _strip_markdown_fences(response.text) + parsed = schema.model_validate_json(cleaned) + chunk_rows.append(parsed.model_dump()) + else: + logging.warning( + 'Empty batch response in job %s for text %d.', job_name, i + ) + chunk_rows.append(null_row) + except Exception as e: # pylint: disable=broad-except + logging.warning( + 'Failed to parse batch result %d in job %s: %s', i, job_name, e + ) + chunk_rows.append(null_row) + + # Ensure index alignment for this chunk + if len(chunk_rows) != len(chunk_texts): + 
raise ValueError( + f'Batch annotate: job {job_name} got {len(chunk_rows)} results for' + f' {len(chunk_texts)} inputs.' + ) + + all_rows.extend(chunk_rows) + + return pd.DataFrame(all_rows) + def generate(self, prompts: Sequence[str]) -> list[str]: """Generate free-form text via google.genai. @@ -169,7 +372,7 @@ def generate(self, prompts: Sequence[str]) -> list[str]: Returns: List of exactly ``len(prompts)`` strings. Empty string on failure. """ - client = self._make_client() + client = self.client results: list[str] = [] for i, prompt in enumerate(prompts): try: @@ -178,8 +381,10 @@ def generate(self, prompts: Sequence[str]) -> list[str]: contents=prompt, ) results.append(response.text or '') - except Exception: # pylint: disable=broad-except - logging.warning('Generation failed for prompt %d.', i) + except Exception as e: # pylint: disable=broad-except + logging.warning( + 'Generation failed for prompt %d. Error: %s', i, e, exc_info=True + ) results.append('') return results diff --git a/tests/text/bulk_inference_test.py b/tests/text/bulk_inference_test.py index 995b647..8c3fe3f 100644 --- a/tests/text/bulk_inference_test.py +++ b/tests/text/bulk_inference_test.py @@ -131,6 +131,222 @@ def test_annotate_handles_markdown_fenced_json(self, mock_client_cls): self.assertEqual(df.iloc[0]['complexity'], 'High') +@mock.patch('google.genai.Client', autospec=True) +class GenAIBackendBatchAnnotateTest(absltest.TestCase): + """Tests for GenAIBackend.batch_annotate.""" + + @staticmethod + def _create_inlined_response( + text: str | None = None, error: str | None = None + ) -> mock.MagicMock: + """Creates a mock inlined batch response.""" + mock_response = mock.MagicMock() + mock_response.text = text + return mock.MagicMock( + error=error, response=mock_response if text else None + ) + + @staticmethod + def _create_mock_job( + state, + inlined_responses: list[mock.MagicMock] | None = None, + done_side_effect: list[bool] | None = None, + inlined_responses_side_effect: 
list[list[mock.MagicMock]] | None = None, + error: str | None = None, + name: str = 'job', + ) -> mock.MagicMock: + """Creates a mock batch job.""" + job = mock.MagicMock() + job.name = name + if done_side_effect is not None: + type(job).done = mock.PropertyMock(side_effect=done_side_effect) + else: + type(job).done = mock.PropertyMock(return_value=True) + job.state = state + job.error = error + + mock_dest = mock.MagicMock() + if inlined_responses_side_effect is not None: + type(mock_dest).inlined_responses = mock.PropertyMock( + side_effect=inlined_responses_side_effect + ) + elif inlined_responses is not None: + mock_dest.inlined_responses = inlined_responses + job.dest = mock_dest + return job + + def test_batch_annotate_success(self, mock_client_cls): + mock_client = mock_client_cls.return_value + inlined_resp = self._create_inlined_response( + '{"topic": "Science", "complexity": "High"}' + ) + mock_batch_job = self._create_mock_job( + state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses=[inlined_resp, inlined_resp], + done_side_effect=[False, True], + ) + + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job + + backend = bulk_inference.GenAIBackend( + api_key='fake', poll_interval_seconds=0 + ) + df = backend.batch_annotate(['text1', 'text2'], SimpleFeatures, 'Sys.') + self.assertLen(df, 2) + self.assertEqual(df.iloc[0]['topic'], 'Science') + self.assertEqual(df.iloc[1]['topic'], 'Science') + + def test_batch_annotate_raises_on_length_mismatch(self, mock_client_cls): + mock_client = mock_client_cls.return_value + inlined_resp = self._create_inlined_response( + '{"topic": "Science", "complexity": "High"}' + ) + mock_batch_job = self._create_mock_job( + state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses=[inlined_resp], + ) + + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job + + backend = 
bulk_inference.GenAIBackend( + api_key='fake', poll_interval_seconds=0 + ) + with self.assertRaisesRegex(ValueError, 'got 1 results for 2 inputs'): + backend.batch_annotate(['text1', 'text2'], SimpleFeatures, 'Sys.') + + def test_batch_annotate_fills_none_on_item_error(self, mock_client_cls): + mock_client = mock_client_cls.return_value + inlined_error = self._create_inlined_response(error='Failed item') + inlined_success = self._create_inlined_response( + '{"topic": "Science", "complexity": "High"}' + ) + + mock_batch_job = self._create_mock_job( + state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses=[inlined_error, inlined_success], + ) + + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job + + backend = bulk_inference.GenAIBackend( + api_key='fake', poll_interval_seconds=0 + ) + df = backend.batch_annotate(['text1', 'text2'], SimpleFeatures, 'Sys.') + self.assertLen(df, 2) + self.assertTrue(pd.isna(df.iloc[0]['topic'])) + self.assertEqual(df.iloc[1]['topic'], 'Science') + + def test_batch_annotate_raises_on_failed_job(self, mock_client_cls): + mock_client = mock_client_cls.return_value + mock_batch_job = self._create_mock_job( + state=bulk_inference.types.JobState.JOB_STATE_FAILED, + error='Something went wrong', + ) + + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job + + backend = bulk_inference.GenAIBackend( + api_key='fake', poll_interval_seconds=0 + ) + with self.assertRaisesRegex( + RuntimeError, + 'Batch job .* ended with state.* Error: Something went wrong', + ): + backend.batch_annotate(['text1'], SimpleFeatures, 'Sys.') + + def test_batch_annotate_respects_class_chunk_size(self, mock_client_cls): + mock_client = mock_client_cls.return_value + inlined_resp = self._create_inlined_response( + '{"topic": "Science", "complexity": "High"}' + ) + mock_batch_job = self._create_mock_job( + 
state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses=[inlined_resp, inlined_resp], + ) + + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job + + backend = bulk_inference.GenAIBackend( + api_key='fake', poll_interval_seconds=0, chunk_size=2 + ) + # 4 texts, chunk_size=2 -> 2 chunks/jobs + backend.batch_annotate(['t1', 't2', 't3', 't4'], SimpleFeatures, 'Sys.') + self.assertEqual(mock_client.batches.create.call_count, 2) + + def test_batch_annotate_respects_override_chunk_size(self, mock_client_cls): + mock_client = mock_client_cls.return_value + inlined_resp = self._create_inlined_response( + '{"topic": "Science", "complexity": "High"}' + ) + + mock_batch_job = self._create_mock_job( + state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses_side_effect=[ + [inlined_resp] * 3, + [inlined_resp] * 1, + ], + ) + + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job + + backend = bulk_inference.GenAIBackend( + api_key='fake', + poll_interval_seconds=0, + chunk_size=10, + max_concurrent_jobs=5, + ) + backend.batch_annotate( + ['t1', 't2', 't3', 't4'], SimpleFeatures, 'Sys.', chunk_size=3 + ) + self.assertEqual(mock_client.batches.create.call_count, 2) + + def test_batch_annotate_respects_max_concurrent_jobs(self, mock_client_cls): + mock_client = mock_client_cls.return_value + create_resp = GenAIBackendBatchAnnotateTest._create_inlined_response + create_job = GenAIBackendBatchAnnotateTest._create_mock_job + + jobs = [] + + def create_side_effect(model, src): + del model # Unused. 
+ inlined_resp = create_resp( + '{"topic": "Science", "complexity": "High"}' + ) + job = create_job( + state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses=[inlined_resp] * len(src), + done_side_effect=[False, True], + name=f'job-{len(jobs)}', + ) + jobs.append(job) + return job + + mock_client.batches.create.side_effect = create_side_effect + + def get_side_effect(name): + idx = int(name.split('-')[1]) + return jobs[idx] + + mock_client.batches.get.side_effect = get_side_effect + + backend = bulk_inference.GenAIBackend( + api_key='fake', + poll_interval_seconds=0, + chunk_size=2, + max_concurrent_jobs=1, + ) + backend.batch_annotate( + ['t1', 't2', 't3', 't4', 't5'], SimpleFeatures, 'Sys.' + ) + self.assertEqual(mock_client.batches.create.call_count, 3) + + class GenAIBackendGenerateTest(absltest.TestCase): @mock.patch('google.genai.Client')