diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java index 10e0d13f47e..30511d8488f 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImageResponseFormat; /** * The configuration information for a image generation request. @@ -81,7 +82,7 @@ public class AzureOpenAiImageOptions implements ImageOptions { * b64_json. */ @JsonProperty("response_format") - private String responseFormat; + private ImageResponseFormat responseFormat; /** * The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024 for @@ -150,13 +151,21 @@ public void setHeight(Integer height) { @Override public String getResponseFormat() { + return (this.responseFormat != null) ? this.responseFormat.getValue() : null; + } + + public ImageResponseFormat getResponseFormatEnum() { return this.responseFormat; } - public void setResponseFormat(String responseFormat) { + public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } + public void setResponseFormat(String responseFormat) { + this.responseFormat = ImageResponseFormat.fromValue(responseFormat); + } + public String getSize() { if (this.size != null) { return this.size; @@ -279,6 +288,11 @@ public Builder deploymentName(String deploymentName) { return this; } + public Builder responseFormat(ImageResponseFormat responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + public Builder responseFormat(String responseFormat) { this.options.setResponseFormat(responseFormat); return this; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java index 3b294a5b02b..bdebfd38dfb 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImageResponseFormat; /** * OpenAI Image API options. OpenAiImageOptions.java @@ -79,7 +80,7 @@ public class OpenAiImageOptions implements ImageOptions { * b64_json. */ @JsonProperty("response_format") - private String responseFormat; + private ImageResponseFormat responseFormat; /** * The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024 for @@ -159,13 +160,21 @@ public void setQuality(String quality) { @Override public String getResponseFormat() { + return (this.responseFormat != null) ? this.responseFormat.getValue() : null; + } + + public ImageResponseFormat getResponseFormatEnum() { return this.responseFormat; } - public void setResponseFormat(String responseFormat) { + public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } + public void setResponseFormat(String responseFormat) { + this.responseFormat = ImageResponseFormat.fromValue(responseFormat); + } + @Override public Integer getWidth() { if (this.width != null) { @@ -326,6 +335,11 @@ public Builder quality(String quality) { return this; } + public Builder responseFormat(ImageResponseFormat responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + public Builder responseFormat(String responseFormat) { this.options.setResponseFormat(responseFormat); return this; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java index faa266ebbeb..b7506fbd8cf 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java @@ -18,6 +18,8 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.image.ImageResponseFormat; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -34,7 +36,7 @@ void testBuilderWithAllFields() { .N(2) .model("dall-e-3") .quality("hd") - .responseFormat("url") + .responseFormat(ImageResponseFormat.URL) .width(1024) .height(1024) .style("vivid") @@ -45,6 +47,7 @@ void testBuilderWithAllFields() { assertThat(options.getModel()).isEqualTo("dall-e-3"); assertThat(options.getQuality()).isEqualTo("hd"); assertThat(options.getResponseFormat()).isEqualTo("url"); + assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.URL); assertThat(options.getWidth()).isEqualTo(1024); assertThat(options.getHeight()).isEqualTo(1024); assertThat(options.getSize()).isEqualTo("1024x1024"); @@ -58,7 +61,7 @@ void testCopy() { .N(3) .model("dall-e-3") .quality("standard") - .responseFormat("b64_json") + .responseFormat(ImageResponseFormat.B64_JSON) .width(1792) .height(1024) .style("natural") @@ -72,6 +75,7 @@ void testCopy() { assertThat(copied.getModel()).isEqualTo(original.getModel()); assertThat(copied.getQuality()).isEqualTo(original.getQuality()); assertThat(copied.getResponseFormat()).isEqualTo(original.getResponseFormat()); + assertThat(copied.getResponseFormatAsEnum()).isEqualTo(original.getResponseFormatAsEnum()); assertThat(copied.getWidth()).isEqualTo(original.getWidth()); assertThat(copied.getHeight()).isEqualTo(original.getHeight()); assertThat(copied.getSize()).isEqualTo(original.getSize()); @@ -85,6 +89,7 @@ void testCopy() { assertThat(copiedViaMethod.getModel()).isEqualTo(original.getModel()); assertThat(copiedViaMethod.getQuality()).isEqualTo(original.getQuality()); assertThat(copiedViaMethod.getResponseFormat()).isEqualTo(original.getResponseFormat()); + assertThat(copiedViaMethod.getResponseFormatAsEnum()).isEqualTo(original.getResponseFormatAsEnum()); assertThat(copiedViaMethod.getWidth()).isEqualTo(original.getWidth()); assertThat(copiedViaMethod.getHeight()).isEqualTo(original.getHeight()); assertThat(copiedViaMethod.getSize()).isEqualTo(original.getSize()); @@ -99,7 +104,7 @@ void testSetters() { options.setN(4); options.setModel("dall-e-2"); options.setQuality("standard"); - options.setResponseFormat("url"); + options.setResponseFormat(ImageResponseFormat.URL); options.setWidth(512); options.setHeight(512); options.setStyle("vivid"); @@ -109,6 +114,7 @@ void testSetters() { assertThat(options.getModel()).isEqualTo("dall-e-2"); assertThat(options.getQuality()).isEqualTo("standard"); assertThat(options.getResponseFormat()).isEqualTo("url"); + assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.URL); assertThat(options.getWidth()).isEqualTo(512); assertThat(options.getHeight()).isEqualTo(512); assertThat(options.getSize()).isEqualTo("512x512"); @@ -212,7 +218,7 @@ void testFluentApiPattern() { .N(1) .model("dall-e-3") .quality("hd") - .responseFormat("url") + .responseFormat(ImageResponseFormat.URL) .width(1024) .height(1024) .style("vivid") @@ -223,6 +229,7 @@ void testFluentApiPattern() { assertThat(options.getModel()).isEqualTo("dall-e-3"); assertThat(options.getQuality()).isEqualTo("hd"); assertThat(options.getResponseFormat()).isEqualTo("url"); + assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.URL); assertThat(options.getWidth()).isEqualTo(1024); assertThat(options.getHeight()).isEqualTo(1024); assertThat(options.getSize()).isEqualTo("1024x1024"); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java index 37dc7abcdba..e9244918529 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java @@ -23,6 +23,7 @@ import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.observation.conventions.AiOperationType; @@ -61,7 +62,7 @@ void observationForImageOperation() { .model(OpenAiImageApi.ImageModel.DALL_E_3.getValue()) .height(1024) .width(1024) - .responseFormat("url") + .responseFormat(ImageResponseFormat.URL) .style("natural") .build(); diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java index cc4e29b6a1c..aa259390abc 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.ai.stabilityai.StyleEnum; /** @@ -122,7 +123,7 @@ public class StabilityAiImageOptions implements ImageOptions { * accept header. Must be "application/json" or "image/png" */ @JsonProperty("response_format") - private String responseFormat; + private ImageResponseFormat responseFormat; /** * The strictness level of the diffusion process adherence to the prompt text. @@ -329,13 +330,21 @@ public void setHeight(Integer height) { @Override public String getResponseFormat() { + return (this.responseFormat != null) ? this.responseFormat.getValue() : null; + } + + public ImageResponseFormat getResponseFormatEnum() { return this.responseFormat; } - public void setResponseFormat(String responseFormat) { + public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } + public void setResponseFormat(String responseFormat) { + this.responseFormat = ImageResponseFormat.fromValue(responseFormat); + } + public Float getCfgScale() { return this.cfgScale; } @@ -455,6 +464,11 @@ public Builder height(Integer height) { return this; } + public Builder responseFormat(ImageResponseFormat responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + public Builder responseFormat(String responseFormat) { this.options.setResponseFormat(responseFormat); return this; diff --git a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java index 854f022f2f9..402de777b15 100644 --- a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java +++ b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.ai.stabilityai.api.StabilityAiApi; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; @@ -37,7 +38,7 @@ void shouldPreferRuntimeOptionsOverDefaultOptions() { .model("default-model") .width(512) .height(512) - .responseFormat("image/png") + .responseFormat(ImageResponseFormat.IMAGE_PNG) .cfgScale(7.0f) .clipGuidancePreset("FAST_BLUE") .sampler("DDIM") @@ -52,7 +53,7 @@ void shouldPreferRuntimeOptionsOverDefaultOptions() { .model("runtime-model") .width(1024) .height(768) - .responseFormat("application/json") + .responseFormat(ImageResponseFormat.APPLICATION_JSON) .cfgScale(14.0f) .clipGuidancePreset("FAST_GREEN") .sampler("DDPM") @@ -72,6 +73,7 @@ void shouldPreferRuntimeOptionsOverDefaultOptions() { assertThat(options.getWidth()).isEqualTo(1024); assertThat(options.getHeight()).isEqualTo(768); assertThat(options.getResponseFormat()).isEqualTo("application/json"); + assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.APPLICATION_JSON); assertThat(options.getCfgScale()).isEqualTo(14.0f); assertThat(options.getClipGuidancePreset()).isEqualTo("FAST_GREEN"); assertThat(options.getSampler()).isEqualTo("DDPM"); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc index bd8d3a001bd..1e47ff1a7be 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc @@ -91,6 +91,8 @@ public interface ImageOptions extends ModelOptions { String getResponseFormat(); // openai - url or base64 : stability ai byte[] or base64 + default ImageResponseFormat getResponseFormatAsEnum(); // convenience conversion helper + } ---- @@ -112,6 +114,10 @@ public class ImageResponse implements ModelResponse { private final List imageGenerations; + Optional getResultAsBytes(); + + List getResultsAsBytes(); + @Override public ImageGeneration getResult() { // get the first result diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/Image.java b/spring-ai-model/src/main/java/org/springframework/ai/image/Image.java index bf1f683a16a..775b6c26f61 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/Image.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/Image.java @@ -16,7 +16,13 @@ package org.springframework.ai.image; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.Base64; import java.util.Objects; +import java.util.Optional; + +import org.springframework.util.StringUtils; public class Image { @@ -72,4 +78,15 @@ public int hashCode() { return Objects.hash(this.url, this.b64Json); } + public Optional getB64JsonAsBytes() { + if (!StringUtils.hasText(this.b64Json)) { + return Optional.empty(); + } + return Optional.of(Base64.getDecoder().decode(this.b64Json)); + } + + public Optional getB64JsonAsInputStream() { + return getB64JsonAsBytes().map(ByteArrayInputStream::new); + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java index 435f6fc62df..aed2ee7ef2a 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java @@ -18,6 +18,7 @@ import org.springframework.ai.model.ModelOptions; import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; /** * ImageOptions represent the common options, portable across different image generation @@ -40,6 +41,14 @@ public interface ImageOptions extends ModelOptions { @Nullable String getResponseFormat(); + default @Nullable ImageResponseFormat getResponseFormatAsEnum() { + String responseFormat = getResponseFormat(); + if (!StringUtils.hasText(responseFormat)) { + return null; + } + return ImageResponseFormat.fromValue(responseFormat); + } + @Nullable String getStyle(); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java index 693a4f00f9d..c3c21b196fb 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java @@ -38,11 +38,16 @@ public ImageOptionsBuilder model(String model) { return this; } - public ImageOptionsBuilder responseFormat(String responseFormat) { + public ImageOptionsBuilder responseFormat(ImageResponseFormat responseFormat) { this.options.setResponseFormat(responseFormat); return this; } + public ImageOptionsBuilder responseFormat(String responseFormat) { + this.options.setResponseFormat(ImageResponseFormat.fromValue(responseFormat)); + return this; + } + public ImageOptionsBuilder width(Integer width) { this.options.setWidth(width); return this; @@ -72,7 +77,7 @@ private static class DefaultImageModelOptions implements ImageOptions { private Integer height; - private String responseFormat; + private ImageResponseFormat responseFormat; private String style; @@ -96,10 +101,10 @@ public void setModel(String model) { @Override public String getResponseFormat() { - return this.responseFormat; + return (this.responseFormat != null) ? this.responseFormat.getValue() : null; } - public void setResponseFormat(String responseFormat) { + public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java index c4605d81890..45d9ed20c8f 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Objects; +import java.util.Optional; import org.springframework.ai.model.ModelResponse; import org.springframework.util.CollectionUtils; @@ -91,6 +92,24 @@ public ImageResponseMetadata getMetadata() { return this.imageResponseMetadata; } + public Optional getResultAsBytes() { + ImageGeneration firstGeneration = getResult(); + if (firstGeneration == null || firstGeneration.getOutput() == null) { + return Optional.empty(); + } + return firstGeneration.getOutput().getB64JsonAsBytes().map(byte[]::clone); + } + + public List getResultsAsBytes() { + return this.imageGenerations.stream() + .map(ImageGeneration::getOutput) + .filter(Objects::nonNull) + .map(Image::getB64JsonAsBytes) + .flatMap(Optional::stream) + .map(byte[]::clone) + .toList(); + } + @Override public String toString() { return "ImageResponse [" + "imageResponseMetadata=" + this.imageResponseMetadata + ", imageGenerations=" diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponseFormat.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponseFormat.java new file mode 100644 index 00000000000..d7dd260f509 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponseFormat.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.image; + +import java.util.Arrays; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +import org.springframework.util.StringUtils; + +/** + * Common response formats supported by image generation providers. + * + * @author Kuntal Maity + */ +public enum ImageResponseFormat { + + URL("url"), + + B64_JSON("b64_json"), + + /** + * PNG responses typically returned by providers when requesting raw image bytes. + */ + IMAGE_PNG("image/png"), + + /** + * JSON responses containing additional metadata or base64 encoded payloads. + */ + APPLICATION_JSON("application/json"); + + private final String value; + + ImageResponseFormat(String value) { + this.value = value; + } + + @JsonValue + public String getValue() { + return this.value; + } + + @JsonCreator + public static ImageResponseFormat fromValue(String value) { + if (!StringUtils.hasText(value)) { + return null; + } + return Arrays.stream(values()) + .filter(format -> format.value.equalsIgnoreCase(value)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Unsupported image response format: " + value)); + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java b/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java index 7a2a1e86a3b..cfe5ee07cba 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java @@ -19,6 +19,7 @@ import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.util.StringUtils; /** @@ -84,10 +85,11 @@ public KeyValues getHighCardinalityKeyValues(ImageModelObservationContext contex // Request protected KeyValues requestImageFormat(KeyValues keyValues, ImageModelObservationContext context) { - if (StringUtils.hasText(context.getRequest().getOptions().getResponseFormat())) { + ImageResponseFormat responseFormat = context.getRequest().getOptions().getResponseFormatAsEnum(); + if (responseFormat != null) { return keyValues.and( ImageModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_IMAGE_RESPONSE_FORMAT.asString(), - context.getRequest().getOptions().getResponseFormat()); + responseFormat.getValue()); } return keyValues; } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/image/ImageResponseTests.java b/spring-ai-model/src/test/java/org/springframework/ai/image/ImageResponseTests.java new file mode 100644 index 00000000000..c71551d5bdc --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/image/ImageResponseTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.image; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class ImageResponseTests { + + @Test + void getResultAsBytesReturnsFirstDecodedImage() { + byte[] payload = "hello".getBytes(StandardCharsets.UTF_8); + String base64 = Base64.getEncoder().encodeToString(payload); + Image image = new Image("https://example.test/image.png", base64); + ImageResponse response = new ImageResponse(List.of(new ImageGeneration(image))); + + assertThat(response.getResultAsBytes()).hasValueSatisfying(bytes -> assertThat(bytes).isEqualTo(payload)); + } + + @Test + void getResultsAsBytesSkipsEntriesWithoutPayload() { + byte[] payload = new byte[] { 1, 2, 3 }; + String base64 = Base64.getEncoder().encodeToString(payload); + Image imageWithPayload = new Image(null, base64); + Image imageWithoutPayload = new Image("https://example.test/image.png", null); + ImageResponse response = new ImageResponse( + List.of(new ImageGeneration(imageWithPayload), new ImageGeneration(imageWithoutPayload))); + + assertThat(response.getResultsAsBytes()).hasSize(1) + .first() + .satisfies(bytes -> assertThat(bytes).isEqualTo(payload)); + } + + @Test + void imageProvidesOptionalStreamForBase64Payload() throws IOException { + byte[] payload = { 42, 43, 44 }; + String base64 = Base64.getEncoder().encodeToString(payload); + Image image = new Image(null, base64); + + assertThat(image.getB64JsonAsBytes()).contains(payload); + assertThat(image.getB64JsonAsInputStream()).hasValueSatisfying(stream -> { + try (stream) { + assertThat(stream.readAllBytes()).isEqualTo(payload); + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + } + + @Test + void helpersReturnEmptyWhenPayloadMissing() { + Image image = new Image("https://example.test/image.png", null); + ImageResponse response = new ImageResponse(List.of(new ImageGeneration(image))); + + assertThat(image.getB64JsonAsBytes()).isEmpty(); + assertThat(image.getB64JsonAsInputStream()).isEmpty(); + assertThat(response.getResultAsBytes()).isEmpty(); + assertThat(response.getResultsAsBytes()).isEmpty(); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java b/spring-ai-model/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java index a0c2c4a8305..b3a50e19206 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java @@ -23,6 +23,7 @@ import org.springframework.ai.image.ImageOptions; import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.ai.observation.conventions.AiObservationAttributes; import static org.assertj.core.api.Assertions.assertThat; @@ -90,7 +91,7 @@ void shouldHaveHighCardinalityKeyValuesWhenDefined() { .height(1080) .width(1920) .style("sketch") - .responseFormat("base64") + .responseFormat(ImageResponseFormat.B64_JSON) .build(); ImageModelObservationContext observationContext = ImageModelObservationContext.builder() .imagePrompt(generateImagePrompt(imageOptions)) @@ -98,7 +99,7 @@ void shouldHaveHighCardinalityKeyValuesWhenDefined() { .build(); assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)).contains( - KeyValue.of(AiObservationAttributes.REQUEST_IMAGE_RESPONSE_FORMAT.value(), "base64"), + KeyValue.of(AiObservationAttributes.REQUEST_IMAGE_RESPONSE_FORMAT.value(), "b64_json"), KeyValue.of(AiObservationAttributes.REQUEST_IMAGE_SIZE.value(), "1920x1080"), KeyValue.of(AiObservationAttributes.REQUEST_IMAGE_STYLE.value(), "sketch")); }