Skip to content

Commit a57f585

Browse files
committed
Fix issues with AbstractCodeWriter state stacks
Code writer context inheritance wasn't working in the previous implementation of the stack based inheritance system. We were able to resolve context from the current state and parent state (because of eagerly copying states and flattening them per/state), but we weren't able to look at parent states that used a CodeSection. Methods of a CodeSection state can be accessed in named template labels. However, when another state is pushed after setting a CodeSection value for a state, code writer wasn't crawling the state stack to see if previous CodeSections could provide a value for a named label that wasn't explicitly set as parameter in the context map. This was because instead of looking up values lazily by iterating over the state stack, we were making eager copies of different pieces of states. This matters when using things like for loops in templates that push an implicit state to capture the current loop iteration state. This commit changes our approach to now actually use the stack itself and iterate over states when attempting to resolve context keys, formatters, iterators, and CodeSection values. Some tests were added to ensure the behavior of the writer is the same. For example, when using inheritance based lookups, removing a context key now means we need to set a context key to an explicit null value to emulate the previous behavior. If we didn't do this, and we just called Map#remove(), the context key may have been set by a parent state, making the remove call do nothing (whereas the previous implementation would actually take effect).
1 parent 1b1c7f4 commit a57f585

File tree

6 files changed

+203
-184
lines changed

6 files changed

+203
-184
lines changed

smithy-utils/src/main/java/software/amazon/smithy/utils/AbstractCodeWriter.java

+130-50
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717

1818
import java.lang.reflect.Method;
1919
import java.util.ArrayDeque;
20+
import java.util.ArrayList;
21+
import java.util.Arrays;
2022
import java.util.Deque;
2123
import java.util.HashMap;
2224
import java.util.Iterator;
25+
import java.util.List;
2326
import java.util.Map;
2427
import java.util.Objects;
2528
import java.util.Optional;
@@ -574,6 +577,12 @@
574577
*/
575578
public abstract class AbstractCodeWriter<T extends AbstractCodeWriter<T>> {
576579

580+
// Valid formatter characters that can be registered. Must be sorted for binary search to work.
581+
static final char[] VALID_FORMATTER_CHARS = {
582+
'!', '%', '&', '*', '+', ',', '-', '.', ';', '=', '@',
583+
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S',
584+
'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '_', '`'};
585+
577586
private static final Pattern LINES = Pattern.compile("\\r?\\n");
578587
private static final Map<Character, BiFunction<Object, String, String>> DEFAULT_FORMATTERS = MapUtils.of(
579588
'L', (s, i) -> formatLiteral(s),
@@ -603,22 +612,19 @@ public AbstractCodeWriter() {
603612
/**
604613
* Copies settings from the given AbstractCodeWriter into this AbstractCodeWriter.
605614
*
606-
* <p>The settings of the {@code other} AbstractCodeWriter will overwrite
607-
* both global and state-based settings of this AbstractCodeWriter. Formatters of
608-
* the {@code other} AbstractCodeWriter will be merged with the formatters of this
609-
* AbstractCodeWriter, and in the case of conflicts, the formatters of the
610-
* {@code other} will take precedence.
615+
* <p>The settings of the {@code other} AbstractCodeWriter will overwrite both global and state-based settings
616+
* of this AbstractCodeWriter.
611617
*
612-
* <p>Stateful settings of the {@code other} AbstractCodeWriter are copied into
613-
* the <em>current</em> state of this AbstractCodeWriter. Only the settings of
614-
* the top-most state is copied. Other states, and the contents of the
615-
* top-most state are not copied.
618+
* <p>Stateful settings of the {@code other} AbstractCodeWriter like formatters, interceptors, and context are
619+
* flattened and then copied into the <em>current</em> state of this AbstractCodeWriter. Any conflicts between
620+
* formatters, interceptors, or context of the current writer are overwritten by the other writer. The stack of
621+
* states and the contents written to {@code other} are not copied.
616622
*
617623
* <pre>{@code
618-
* SimpleCodeWritera = new SimpleCodeWriter();
624+
* SimpleCodeWriter a = new SimpleCodeWriter();
619625
* a.setExpressionStart('#');
620626
*
621-
* SimpleCodeWriterb = new SimpleCodeWriter();
627+
* SimpleCodeWriter b = new SimpleCodeWriter();
622628
* b.copySettingsFrom(a);
623629
*
624630
* assert(b.getExpressionStart() == '#');
@@ -634,6 +640,16 @@ public void copySettingsFrom(AbstractCodeWriter<T> other) {
634640

635641
// Copy the current state settings of other into the current state.
636642
currentState.copyStateFrom(other.currentState);
643+
644+
// Flatten containers into the current state. This is done in reverse order to ensure that more recent
645+
// state changes supersede earlier changes.
646+
Iterator<State> reverseOtherStates = other.states.descendingIterator();
647+
while (reverseOtherStates.hasNext()) {
648+
State otherState = reverseOtherStates.next();
649+
currentState.interceptors.addAll(otherState.interceptors);
650+
currentState.formatters.putAll(otherState.formatters);
651+
currentState.context.putAll(otherState.context);
652+
}
637653
}
638654

639655
/**
@@ -686,7 +702,7 @@ public static String formatLiteral(Object value) {
686702
*/
687703
@SuppressWarnings("unchecked")
688704
public T putFormatter(char identifier, BiFunction<Object, String, String> formatFunction) {
689-
this.currentState.formatters.get().putFormatter(identifier, formatFunction);
705+
this.currentState.putFormatter(identifier, formatFunction);
690706
return (T) this;
691707
}
692708

@@ -962,9 +978,15 @@ public T popState() {
962978

963979
// Don't attempt to intercept anonymous sections.
964980
if (!(sectionValue instanceof AnonymousCodeSection)) {
965-
for (CodeInterceptor<CodeSection, T> interceptor : popped.interceptors.peek().get(sectionValue)) {
966-
result = interceptSection(popped, interceptor, result);
981+
// Ensure the remaining parent interceptors are applied in the order they were inserted.
982+
// This is the reverse order used when normally iterating over the states deque.
983+
Iterator<State> insertionOrderedStates = states.descendingIterator();
984+
while (insertionOrderedStates.hasNext()) {
985+
State state = insertionOrderedStates.next();
986+
result = applyPoppedInterceptors(popped, state, sectionValue, result);
967987
}
988+
// Now ensure the popped state's interceptors are applied.
989+
result = applyPoppedInterceptors(popped, popped, sectionValue, result);
968990
}
969991

970992
if (popped.isInline) {
@@ -986,6 +1008,13 @@ public T popState() {
9861008
return (T) this;
9871009
}
9881010

1011+
private String applyPoppedInterceptors(State popped, State state, CodeSection sectionValue, String result) {
1012+
for (CodeInterceptor<CodeSection, T> interceptor : state.getInterceptors(sectionValue)) {
1013+
result = interceptSection(popped, interceptor, result);
1014+
}
1015+
return result;
1016+
}
1017+
9891018
// This method exists because inlining in popSection is impossible due to needing to mutate a result variable.
9901019
@SuppressWarnings("unchecked")
9911020
private String interceptSection(State popped, CodeInterceptor<CodeSection, T> interceptor, String previous) {
@@ -1070,7 +1099,7 @@ private String interceptSection(State popped, CodeInterceptor<CodeSection, T> in
10701099
*/
10711100
@SuppressWarnings("unchecked")
10721101
public T onSection(String sectionName, Consumer<Object> interceptor) {
1073-
currentState.interceptors.get().putInterceptor(CodeInterceptor.forName(sectionName, (w, p) -> {
1102+
currentState.putInterceptor(CodeInterceptor.forName(sectionName, (w, p) -> {
10741103
String trimmedContent = removeTrailingNewline(p);
10751104
interceptor.accept(trimmedContent);
10761105
}));
@@ -1098,7 +1127,7 @@ public T onSection(String sectionName, Consumer<Object> interceptor) {
10981127
*/
10991128
@SuppressWarnings("unchecked")
11001129
public <S extends CodeSection> T onSection(CodeInterceptor<S, T> interceptor) {
1101-
currentState.interceptors.get().putInterceptor(interceptor);
1130+
currentState.putInterceptor(interceptor);
11021131
return (T) this;
11031132
}
11041133

@@ -1857,7 +1886,7 @@ public T unwrite(Object content, Object... args) {
18571886
*/
18581887
@SuppressWarnings("unchecked")
18591888
public T putContext(String key, Object value) {
1860-
currentState.context.get().put(key, value);
1889+
currentState.context.put(key, value);
18611890
return (T) this;
18621891
}
18631892

@@ -1879,40 +1908,45 @@ public T putContext(Map<String, Object> mappings) {
18791908
/**
18801909
* Removes a named key-value pair from the context of the current state.
18811910
*
1911+
* <p>This method has no effect if the parent state defines the context key value pair.
1912+
*
18821913
* @param key Key to add to remove from the current context.
18831914
* @return Returns self.
18841915
*/
18851916
@SuppressWarnings("unchecked")
18861917
public T removeContext(String key) {
1887-
if (currentState.context.peek().containsKey(key)) {
1888-
currentState.context.get().remove(key);
1918+
if (currentState.context.containsKey(key)) {
1919+
currentState.context.remove(key);
1920+
} else {
1921+
// Parent states might have a value for this context key, so explicitly set it to null in this context.
1922+
currentState.context.put(key, null);
18891923
}
18901924
return (T) this;
18911925
}
18921926

18931927
/**
1894-
* Gets a named contextual key-value pair from the current state.
1928+
* Gets a named contextual key-value pair from the current state or any parent states.
18951929
*
18961930
* @param key Key to retrieve.
18971931
* @return Returns the associated value or null if not present.
18981932
*/
18991933
public Object getContext(String key) {
1900-
CodeSection section = currentState.sectionValue;
1901-
Map<String, Object> currentContext = currentState.context.peek();
1902-
if (currentContext.containsKey(key)) {
1903-
return currentContext.get(key);
1904-
} else if (section != null) {
1905-
Method method = findContextMethod(section, key);
1906-
if (method != null) {
1907-
try {
1908-
return method.invoke(section);
1909-
} catch (ReflectiveOperationException e) {
1910-
String message = String.format(
1911-
"Unable to get context '%s' from a matching method of the current CodeSection: %s %s",
1912-
key,
1913-
e.getCause() != null ? e.getCause().getMessage() : e.getMessage(),
1914-
getDebugInfo());
1915-
throw new RuntimeException(message, e);
1934+
for (State state : states) {
1935+
if (state.context.containsKey(key)) {
1936+
return state.context.get(key);
1937+
} else if (state.sectionValue != null) {
1938+
Method method = findContextMethod(state.sectionValue, key);
1939+
if (method != null) {
1940+
try {
1941+
return method.invoke(state.sectionValue);
1942+
} catch (ReflectiveOperationException e) {
1943+
String message = String.format(
1944+
"Unable to get context '%s' from a matching method of the current CodeSection: %s %s",
1945+
key,
1946+
e.getCause() != null ? e.getCause().getMessage() : e.getMessage(),
1947+
getDebugInfo());
1948+
throw new RuntimeException(message, e);
1949+
}
19161950
}
19171951
}
19181952
}
@@ -1984,7 +2018,7 @@ String expandSection(CodeSection section, String previousContent, Consumer<Strin
19842018
// Used only by CodeFormatter to apply formatters.
19852019
@SuppressWarnings("unchecked")
19862020
String applyFormatter(char identifier, Object value) {
1987-
BiFunction<Object, String, String> f = currentState.formatters.peek().getFormatter(identifier);
2021+
BiFunction<Object, String, String> f = resolveFormatter(identifier);
19882022
if (f != null) {
19892023
return f.apply(value, getIndentText());
19902024
} else if (identifier == 'C') {
@@ -2009,14 +2043,24 @@ String applyFormatter(char identifier, Object value) {
20092043
throw new ClassCastException(String.format(
20102044
"Expected value for 'C' formatter to be an instance of %s or %s, but found %s %s",
20112045
Runnable.class.getName(), Consumer.class.getName(),
2012-
value.getClass().getName(), getDebugInfo()));
2046+
value == null ? "null" : value.getClass().getName(), getDebugInfo()));
20132047
}
20142048
} else {
20152049
// Return null if no formatter was found.
20162050
return null;
20172051
}
20182052
}
20192053

2054+
BiFunction<Object, String, String> resolveFormatter(char identifier) {
2055+
for (State state : states) {
2056+
BiFunction<Object, String, String> result = state.getFormatter(identifier);
2057+
if (result != null) {
2058+
return result;
2059+
}
2060+
}
2061+
return null;
2062+
}
2063+
20202064
private final class State {
20212065
private final boolean isRoot;
20222066
private String indentText = " ";
@@ -2030,9 +2074,9 @@ private final class State {
20302074
private boolean needsIndentation;
20312075

20322076
private CodeSection sectionValue;
2033-
private CopyOnWriteRef<Map<String, Object>> context;
2034-
private CopyOnWriteRef<CodeWriterFormatterContainer> formatters;
2035-
private CopyOnWriteRef<CodeInterceptorContainer<T>> interceptors;
2077+
private final Map<String, Object> context = new HashMap<>();
2078+
private final Map<Character, BiFunction<Object, String, String>> formatters = new HashMap<>();
2079+
private final List<CodeInterceptor<CodeSection, T>> interceptors = new ArrayList<>();
20362080

20372081
private StringBuilder builder;
20382082

@@ -2046,11 +2090,7 @@ private final class State {
20462090
State() {
20472091
builder = new StringBuilder();
20482092
isRoot = true;
2049-
CodeWriterFormatterContainer formatterContainer = new CodeWriterFormatterContainer();
2050-
DEFAULT_FORMATTERS.forEach(formatterContainer::putFormatter);
2051-
this.formatters = CopyOnWriteRef.fromOwned(formatterContainer);
2052-
this.context = CopyOnWriteRef.fromOwned(new HashMap<>());
2053-
this.interceptors = CopyOnWriteRef.fromOwned(new CodeInterceptorContainer<>());
2093+
DEFAULT_FORMATTERS.forEach(this::putFormatter);
20542094
}
20552095

20562096
@SuppressWarnings("CopyConstructorMissesField")
@@ -2060,20 +2100,18 @@ private final class State {
20602100
this.builder = copy.builder;
20612101
}
20622102

2103+
// This does not copy context, interceptors, or formatters.
2104+
// State inheritance relies on stacks of States in an AbstractCodeWriter.
20632105
private void copyStateFrom(State copy) {
20642106
this.newline = copy.newline;
20652107
this.expressionStart = copy.expressionStart;
2066-
this.context = copy.context;
20672108
this.indentText = copy.indentText;
20682109
this.leadingIndentString = copy.leadingIndentString;
20692110
this.indentation = copy.indentation;
20702111
this.newlinePrefix = copy.newlinePrefix;
20712112
this.trimTrailingSpaces = copy.trimTrailingSpaces;
20722113
this.disableNewline = copy.disableNewline;
20732114
this.needsIndentation = copy.needsIndentation;
2074-
this.context = CopyOnWriteRef.fromBorrowed(copy.context.peek(), HashMap::new);
2075-
this.formatters = CopyOnWriteRef.fromBorrowed(copy.formatters.peek(), CodeWriterFormatterContainer::new);
2076-
this.interceptors = CopyOnWriteRef.fromBorrowed(copy.interceptors.peek(), CodeInterceptorContainer::new);
20772115
}
20782116

20792117
@Override
@@ -2200,6 +2238,48 @@ private String getSectionName() {
22002238
return sectionValue.sectionName();
22012239
}
22022240
}
2241+
2242+
void putFormatter(Character identifier, BiFunction<Object, String, String> formatFunction) {
2243+
if (Arrays.binarySearch(VALID_FORMATTER_CHARS, identifier) < 0) {
2244+
throw new IllegalArgumentException("Invalid formatter identifier: " + identifier);
2245+
}
2246+
formatters.put(identifier, formatFunction);
2247+
}
2248+
2249+
BiFunction<Object, String, String> getFormatter(char identifier) {
2250+
return formatters.get(identifier);
2251+
}
2252+
2253+
@SuppressWarnings("unchecked")
2254+
void putInterceptor(CodeInterceptor<? extends CodeSection, T> interceptor) {
2255+
interceptors.add((CodeInterceptor<CodeSection, T>) interceptor);
2256+
}
2257+
2258+
/**
2259+
* Gets a list of interceptors that match the given type and for which the
2260+
* result of {@link CodeInterceptor#isIntercepted(CodeSection)} returns true
2261+
* when given {@code forSection}.
2262+
*
2263+
* @param forSection The section that is being intercepted.
2264+
* @param <S> The type of section being intercepted.
2265+
* @return Returns the list of matching interceptors.
2266+
*/
2267+
<S extends CodeSection> List<CodeInterceptor<CodeSection, T>> getInterceptors(S forSection) {
2268+
// Add in parent interceptors.
2269+
List<CodeInterceptor<CodeSection, T>> result = new ArrayList<>();
2270+
// Merge in local interceptors.
2271+
for (CodeInterceptor<CodeSection, T> interceptor : interceptors) {
2272+
// Add the interceptor only if it's the right type.
2273+
if (interceptor.sectionType().isInstance(forSection)) {
2274+
// Only add if the filter passes.
2275+
if (interceptor.isIntercepted(forSection)) {
2276+
result.add(interceptor);
2277+
}
2278+
}
2279+
}
2280+
2281+
return result;
2282+
}
22032283
}
22042284

22052285
String removeTrailingNewline(String value) {

smithy-utils/src/main/java/software/amazon/smithy/utils/CodeFormatter.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ private Operation parseNormalArgument() {
635635
// Parse the formatter and apply it.
636636
int line = parser.line();
637637
int column = parser.column();
638-
char identifier = parser.expect(CodeWriterFormatterContainer.VALID_FORMATTER_CHARS);
638+
char identifier = parser.expect(AbstractCodeWriter.VALID_FORMATTER_CHARS);
639639

640640
// The error message needs to be created here and given to the operation in way that it can
641641
// throw with an appropriate message.

0 commit comments

Comments
 (0)