Skip to content

Commit 9af0036

Browse files
committed
Introduce TemplateRenderer for prompt templating
- Introduce new TemplateRenderer API providing the logic for rendering an input template. - Update the PromptTemplate API to accept a TemplateRenderer object at construction time. - Move ST logic to StTemplateRenderer implementation, used by default in PromptTemplate. Additionally, make start and end delimiter character configurable. Relates to gh-2655 Signed-off-by: Thomas Vitale <[email protected]>
1 parent 15a5069 commit 9af0036

File tree

8 files changed

+924
-107
lines changed

8 files changed

+924
-107
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,128 +20,135 @@
2020
import java.io.InputStream;
2121
import java.nio.charset.Charset;
2222
import java.util.HashMap;
23-
import java.util.HashSet;
2423
import java.util.List;
2524
import java.util.Map;
2625
import java.util.Map.Entry;
2726
import java.util.Set;
2827

29-
import org.antlr.runtime.Token;
30-
import org.antlr.runtime.TokenStream;
31-
import org.stringtemplate.v4.ST;
32-
import org.stringtemplate.v4.compiler.STLexer;
28+
import org.springframework.ai.template.TemplateRenderer;
29+
import org.springframework.ai.template.st.StTemplateRenderer;
30+
import org.springframework.util.Assert;
3331

3432
import org.springframework.ai.chat.messages.Message;
3533
import org.springframework.ai.chat.messages.UserMessage;
3634
import org.springframework.ai.content.Media;
3735
import org.springframework.core.io.Resource;
3836
import org.springframework.util.StreamUtils;
3937

38+
/**
39+
* A template for creating prompts. It allows you to define a template string with
40+
* placeholders for variables, and then render the template with specific values for those
41+
* variables.
42+
* <p>
43+
* NOTE: This class will be marked as final in the next release. If you subclass this
44+
* class, you should consider using the built-in implementation together with the new
45+
* PromptTemplateRenderer interface, which is designed to give you more flexibility and
46+
* control over the rendering process.
47+
*/
4048
public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {
4149

50+
private static final TemplateRenderer DEFAULT_TEMPLATE_RENDERER = StTemplateRenderer.builder().build();
51+
52+
/**
53+
* @deprecated will become private in the next release. If you're subclassing this
54+
* class, re-consider using the built-in implementation together with the new
55+
* PromptTemplateRenderer interface, designed to give you more flexibility and control
56+
* over the rendering process.
57+
*/
58+
@Deprecated
4259
protected String template;
4360

61+
/**
62+
* @deprecated in favor of {@link TemplateRenderer}
63+
*/
64+
@Deprecated
4465
protected TemplateFormat templateFormat = TemplateFormat.ST;
4566

46-
private ST st;
67+
private final Map<String, Object> variables = new HashMap<>();
4768

48-
private Map<String, Object> dynamicModel = new HashMap<>();
69+
private final TemplateRenderer renderer;
4970

5071
public PromptTemplate(Resource resource) {
51-
try (InputStream inputStream = resource.getInputStream()) {
52-
this.template = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
53-
}
54-
catch (IOException ex) {
55-
throw new RuntimeException("Failed to read resource", ex);
56-
}
57-
try {
58-
this.st = new ST(this.template, '{', '}');
59-
}
60-
catch (Exception ex) {
61-
throw new IllegalArgumentException("The template string is not valid.", ex);
62-
}
72+
this(resource, new HashMap<>(), DEFAULT_TEMPLATE_RENDERER);
6373
}
6474

6575
public PromptTemplate(String template) {
66-
this.template = template;
67-
// If the template string is not valid, an exception will be thrown
68-
try {
69-
this.st = new ST(this.template, '{', '}');
70-
}
71-
catch (Exception ex) {
72-
throw new IllegalArgumentException("The template string is not valid.", ex);
73-
}
76+
this(template, new HashMap<>(), DEFAULT_TEMPLATE_RENDERER);
77+
}
78+
79+
public PromptTemplate(String template, Map<String, Object> variables) {
80+
this(template, variables, DEFAULT_TEMPLATE_RENDERER);
7481
}
7582

76-
public PromptTemplate(String template, Map<String, Object> model) {
83+
public PromptTemplate(Resource resource, Map<String, Object> variables) {
84+
this(resource, variables, DEFAULT_TEMPLATE_RENDERER);
85+
}
86+
87+
PromptTemplate(String template, Map<String, Object> variables, TemplateRenderer renderer) {
88+
Assert.hasText(template, "template cannot be null or empty");
89+
Assert.notNull(variables, "variables cannot be null");
90+
Assert.noNullElements(variables.keySet(), "variables keys cannot be null");
91+
Assert.notNull(renderer, "renderer cannot be null");
92+
7793
this.template = template;
78-
// If the template string is not valid, an exception will be thrown
79-
try {
80-
this.st = new ST(this.template, '{', '}');
81-
for (Entry<String, Object> entry : model.entrySet()) {
82-
add(entry.getKey(), entry.getValue());
83-
}
84-
}
85-
catch (Exception ex) {
86-
throw new IllegalArgumentException("The template string is not valid.", ex);
87-
}
94+
this.variables.putAll(variables);
95+
this.renderer = renderer;
8896
}
8997

90-
public PromptTemplate(Resource resource, Map<String, Object> model) {
98+
PromptTemplate(Resource resource, Map<String, Object> variables, TemplateRenderer renderer) {
99+
Assert.notNull(resource, "resource cannot be null");
100+
Assert.notNull(variables, "variables cannot be null");
101+
Assert.noNullElements(variables.keySet(), "variables keys cannot be null");
102+
Assert.notNull(renderer, "renderer cannot be null");
103+
91104
try (InputStream inputStream = resource.getInputStream()) {
92105
this.template = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
106+
Assert.hasText(template, "template cannot be null or empty");
93107
}
94108
catch (IOException ex) {
95109
throw new RuntimeException("Failed to read resource", ex);
96110
}
97-
// If the template string is not valid, an exception will be thrown
98-
try {
99-
this.st = new ST(this.template, '{', '}');
100-
for (Entry<String, Object> entry : model.entrySet()) {
101-
this.add(entry.getKey(), entry.getValue());
102-
}
103-
}
104-
catch (Exception ex) {
105-
throw new IllegalArgumentException("The template string is not valid.", ex);
106-
}
111+
this.variables.putAll(variables);
112+
this.renderer = renderer;
107113
}
108114

109115
public void add(String name, Object value) {
110-
this.st.add(name, value);
111-
this.dynamicModel.put(name, value);
116+
this.variables.put(name, value);
112117
}
113118

114119
public String getTemplate() {
115120
return this.template;
116121
}
117122

123+
/**
124+
* @deprecated in favor of {@link TemplateRenderer}
125+
*/
126+
@Deprecated
118127
public TemplateFormat getTemplateFormat() {
119128
return this.templateFormat;
120129
}
121130

122-
// Render Methods
131+
// From PromptTemplateStringActions.
132+
123133
@Override
124134
public String render() {
125-
validate(this.dynamicModel);
126-
return this.st.render();
135+
return this.renderer.apply(template, this.variables);
127136
}
128137

129138
@Override
130-
public String render(Map<String, Object> model) {
131-
validate(model);
132-
for (Entry<String, Object> entry : model.entrySet()) {
133-
if (this.st.getAttribute(entry.getKey()) != null) {
134-
this.st.remove(entry.getKey());
135-
}
139+
public String render(Map<String, Object> additionalVariables) {
140+
Map<String, Object> combinedVariables = new HashMap<>(this.variables);
141+
142+
for (Entry<String, Object> entry : additionalVariables.entrySet()) {
136143
if (entry.getValue() instanceof Resource) {
137-
this.st.add(entry.getKey(), renderResource((Resource) entry.getValue()));
144+
combinedVariables.put(entry.getKey(), renderResource((Resource) entry.getValue()));
138145
}
139146
else {
140-
this.st.add(entry.getKey(), entry.getValue());
147+
combinedVariables.put(entry.getKey(), entry.getValue());
141148
}
142-
143149
}
144-
return this.st.render();
150+
151+
return this.renderer.apply(template, combinedVariables);
145152
}
146153

147154
private String renderResource(Resource resource) {
@@ -153,6 +160,8 @@ private String renderResource(Resource resource) {
153160
}
154161
}
155162

163+
// From PromptTemplateMessageActions.
164+
156165
@Override
157166
public Message createMessage() {
158167
return new UserMessage(render());
@@ -164,10 +173,12 @@ public Message createMessage(List<Media> mediaList) {
164173
}
165174

166175
@Override
167-
public Message createMessage(Map<String, Object> model) {
168-
return new UserMessage(render(model));
176+
public Message createMessage(Map<String, Object> additionalVariables) {
177+
return new UserMessage(render(additionalVariables));
169178
}
170179

180+
// From PromptTemplateActions.
181+
171182
@Override
172183
public Prompt create() {
173184
return new Prompt(render(new HashMap<>()));
@@ -179,59 +190,89 @@ public Prompt create(ChatOptions modelOptions) {
179190
}
180191

181192
@Override
182-
public Prompt create(Map<String, Object> model) {
183-
return new Prompt(render(model));
193+
public Prompt create(Map<String, Object> additionalVariables) {
194+
return new Prompt(render(additionalVariables));
184195
}
185196

186197
@Override
187-
public Prompt create(Map<String, Object> model, ChatOptions modelOptions) {
188-
return new Prompt(render(model), modelOptions);
198+
public Prompt create(Map<String, Object> additionalVariables, ChatOptions modelOptions) {
199+
return new Prompt(render(additionalVariables), modelOptions);
189200
}
190201

202+
// Compatibility
203+
204+
/**
205+
* @deprecated in favor of {@link TemplateRenderer}.
206+
*/
207+
@Deprecated
191208
public Set<String> getInputVariables() {
192-
TokenStream tokens = this.st.impl.tokens;
193-
Set<String> inputVariables = new HashSet<>();
194-
boolean isInsideList = false;
195-
196-
for (int i = 0; i < tokens.size(); i++) {
197-
Token token = tokens.get(i);
198-
199-
if (token.getType() == STLexer.LDELIM && i + 1 < tokens.size()
200-
&& tokens.get(i + 1).getType() == STLexer.ID) {
201-
if (i + 2 < tokens.size() && tokens.get(i + 2).getType() == STLexer.COLON) {
202-
inputVariables.add(tokens.get(i + 1).getText());
203-
isInsideList = true;
204-
}
205-
}
206-
else if (token.getType() == STLexer.RDELIM) {
207-
isInsideList = false;
208-
}
209-
else if (!isInsideList && token.getType() == STLexer.ID) {
210-
inputVariables.add(token.getText());
211-
}
212-
}
209+
throw new UnsupportedOperationException(
210+
"The template rendering logic is now provided by PromptTemplateRenderer");
211+
}
213212

214-
return inputVariables;
213+
/**
214+
* @deprecated in favor of {@link TemplateRenderer}.
215+
*/
216+
@Deprecated
217+
protected void validate(Map<String, Object> model) {
218+
throw new UnsupportedOperationException("Validation is now provided by the PromptTemplateRenderer");
215219
}
216220

217-
private Set<String> getModelKeys(Map<String, Object> model) {
218-
Set<String> dynamicVariableNames = new HashSet<>(this.dynamicModel.keySet());
219-
Set<String> modelVariables = new HashSet<>(model.keySet());
220-
modelVariables.addAll(dynamicVariableNames);
221-
return modelVariables;
221+
public Builder mutate() {
222+
return new Builder().template(this.template).variables(this.variables).renderer(this.renderer);
222223
}
223224

224-
protected void validate(Map<String, Object> model) {
225+
// Builder
226+
227+
public static Builder builder() {
228+
return new Builder();
229+
}
230+
231+
public static class Builder {
232+
233+
private String template;
234+
235+
private Resource resource;
225236

226-
Set<String> templateTokens = getInputVariables();
227-
Set<String> modelKeys = getModelKeys(model);
237+
private Map<String, Object> variables = new HashMap<>();
228238

229-
// Check if model provides all keys required by the template
230-
if (!modelKeys.containsAll(templateTokens)) {
231-
templateTokens.removeAll(modelKeys);
232-
throw new IllegalStateException(
233-
"Not all template variables were replaced. Missing variable names are " + templateTokens);
239+
private TemplateRenderer renderer = DEFAULT_TEMPLATE_RENDERER;
240+
241+
private Builder() {
242+
}
243+
244+
public Builder template(String template) {
245+
this.template = template;
246+
return this;
247+
}
248+
249+
public Builder resource(Resource resource) {
250+
this.resource = resource;
251+
return this;
252+
}
253+
254+
public Builder variables(Map<String, Object> variables) {
255+
this.variables = variables;
256+
return this;
257+
}
258+
259+
public Builder renderer(TemplateRenderer renderer) {
260+
this.renderer = renderer;
261+
return this;
262+
}
263+
264+
public PromptTemplate build() {
265+
if (this.template != null && this.resource != null) {
266+
throw new IllegalArgumentException("Only one of template or resource can be set");
267+
}
268+
else if (this.resource != null) {
269+
return new PromptTemplate(this.resource, this.variables, this.renderer);
270+
}
271+
else {
272+
return new PromptTemplate(this.template, this.variables, this.renderer);
273+
}
234274
}
275+
235276
}
236277

237278
}

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,6 +16,12 @@
1616

1717
package org.springframework.ai.chat.prompt;
1818

19+
import org.springframework.ai.template.TemplateRenderer;
20+
21+
/**
22+
* @deprecated in favor of {@link TemplateRenderer}.
23+
*/
24+
@Deprecated
1925
public enum TemplateFormat {
2026

2127
ST("ST");

0 commit comments

Comments
 (0)