Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
yield self.__build_full_text_response(text)
text = ''
yield llm_response
if (
message.server_content.input_transcription
and message.server_content.input_transcription.text
):
if message.server_content.input_transcription:
llm_response = LlmResponse(
input_transcription=message.server_content.input_transcription,
)
yield llm_response
if (
message.server_content.output_transcription
and message.server_content.output_transcription.text
):
if message.server_content.output_transcription:
llm_response = LlmResponse(
output_transcription=message.server_content.output_transcription
)
Expand Down
50 changes: 50 additions & 0 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,53 @@ async def test_close(gemini_connection, mock_gemini_session):
await gemini_connection.close()

mock_gemini_session.close.assert_called_once()


@pytest.mark.asyncio
@pytest.mark.parametrize('tx_direction', ['input', 'output'])
async def test_receive_transcript_finished(
gemini_connection, mock_gemini_session, tx_direction
):
"""Test receive_transcript_finished for input and output transcription."""

finished_tx = types.Transcription(finished=True)

class Msg:

def __init__(self):
self.server_content = mock.Mock()
sc = self.server_content
sc.model_turn = None
if tx_direction == 'input':
sc.input_transcription = finished_tx
sc.output_transcription = None
else:
sc.input_transcription = None
sc.output_transcription = finished_tx
sc.interrupted = False
sc.turn_complete = False
self.tool_call = None
self.session_resumption_update = None

async def gen():
yield Msg()
Comment on lines +123 to +141

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved readability and conciseness, you can create the mock message object directly by setting attributes on a mock.Mock() instance, instead of defining a local Msg class. This makes the test setup more straightforward and less verbose.

  msg = mock.Mock()
  msg.tool_call = None
  msg.session_resumption_update = None
  msg.server_content.model_turn = None
  msg.server_content.interrupted = False
  msg.server_content.turn_complete = False
  msg.server_content.input_transcription = (
      finished_tx if tx_direction == 'input' else None
  )
  msg.server_content.output_transcription = (
      finished_tx if tx_direction == 'output' else None
  )

  async def gen():
    yield msg


mock_gemini_session.receive = mock.Mock(return_value=gen())

responses = []
async for r in gemini_connection.receive():
responses.append(r)

if tx_direction == 'input':
tx_resps = [r for r in responses if r.input_transcription]
else:
tx_resps = [r for r in responses if r.output_transcription]

if tx_direction == 'input':
assert tx_resps, 'Excpected input transcription response'
assert tx_resps[0].input_transcription.finished is True
assert not tx_resps[0].input_transcription.text
else:
assert tx_resps, 'Expected output transcription response'
assert tx_resps[0].output_transcription.finished is True
assert not tx_resps[0].output_transcription.text
Comment on lines +149 to +161

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The assertion logic is duplicated for the 'input' and 'output' cases. You can refactor this to be more DRY (Don't Repeat Yourself) by using a variable for the attribute name and consolidating the checks. This makes the test easier to read and maintain. This change also corrects a typo in the assertion message ('Excpected' -> 'Expected').

  attr_name = f'{tx_direction}_transcription'
  tx_resps = [r for r in responses if getattr(r, attr_name)]
  assert tx_resps, f'Expected {tx_direction} transcription response'

  transcription = getattr(tx_resps[0], attr_name)
  assert transcription.finished is True
  assert not transcription.text