Skip to content

[WIP] Use force-quoting in R2dbcMappingContext by default #2047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
import org.springframework.data.relational.core.sql.SqlIdentifier;
import org.springframework.data.relational.domain.RowDocument;
import org.springframework.data.util.TypeInformation;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -224,7 +225,10 @@ private void writeSimpleInternal(OutboundRow sink, Object value, RelationalPersi

Object result = getPotentiallyConvertedSimpleWrite(value);

sink.put(property.getColumnName(),
SqlIdentifier columnName = property.getColumnName();

sink.put(SqlIdentifier.unquoted(columnName.getReference()),
// sink.put(property.getColumnName(),
Parameter.fromOrEmpty(result, getPotentiallyConvertedSimpleNullType(property.getType())));
}

Expand All @@ -242,7 +246,10 @@ private void writePropertyInternal(OutboundRow sink, Object value, RelationalPer
}

List<Object> collectionInternal = createCollection(asCollection(value), property);
sink.put(property.getColumnName(), Parameter.from(collectionInternal));
// sink.put(property.getColumnName(), Parameter.from(collectionInternal));
SqlIdentifier columnName = property.getColumnName();
//
sink.put(SqlIdentifier.unquoted(columnName.getReference()), Parameter.from(collectionInternal));
return;
}

Expand Down Expand Up @@ -299,7 +306,7 @@ private List<Object> writeCollectionInternal(Collection<?> source, @Nullable Typ

private void writeNullInternal(OutboundRow sink, RelationalPersistentProperty property) {

sink.put(property.getColumnName(), Parameter.empty(getPotentiallyConvertedSimpleNullType(property.getType())));
sink.put(property.getColumnName().getReference(), Parameter.empty(getPotentiallyConvertedSimpleNullType(property.getType())));
}

private Class<?> getPotentiallyConvertedSimpleNullType(Class<?> type) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ public OutboundRow getOutboundRow(Object object) {

for (RelationalPersistentProperty property : entity) {

Parameter value = row.get(property.getColumnName());
Parameter value = row.get(property.getColumnName().getReference());
if (value != null && shouldConvertArrayValue(property, value)) {

Parameter writeValue = getArrayValue(value, property);
row.put(property.getColumnName(), writeValue);
row.put(property.getColumnName().getReference(), writeValue);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
import java.util.ArrayList;
import java.util.List;

import org.jetbrains.annotations.NotNull;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.r2dbc.convert.R2dbcConverter;
import org.springframework.data.r2dbc.dialect.R2dbcDialect;
import org.springframework.data.r2dbc.query.BoundAssignments;
import org.springframework.data.r2dbc.query.BoundCondition;
import org.springframework.data.r2dbc.query.UpdateMapper;
import org.springframework.data.relational.core.dialect.RenderContextFactory;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
import org.springframework.data.relational.core.query.CriteriaDefinition;
Expand Down Expand Up @@ -53,7 +55,7 @@ class DefaultStatementMapper implements StatementMapper {
private final RenderContext renderContext;
private final UpdateMapper updateMapper;
private final MappingContext<? extends RelationalPersistentEntity<?>, ? extends RelationalPersistentProperty> mappingContext;

private boolean forceQuote;
DefaultStatementMapper(R2dbcDialect dialect, R2dbcConverter converter) {

RenderContextFactory factory = new RenderContextFactory(dialect);
Expand All @@ -62,6 +64,9 @@ class DefaultStatementMapper implements StatementMapper {
this.renderContext = factory.createRenderContext();
this.updateMapper = new UpdateMapper(dialect, converter);
this.mappingContext = converter.getMappingContext();
if(mappingContext instanceof RelationalMappingContext relationalMappingContext){
forceQuote = relationalMappingContext.isForceQuote();
}
}

DefaultStatementMapper(R2dbcDialect dialect, RenderContext renderContext, UpdateMapper updateMapper,
Expand All @@ -70,6 +75,9 @@ class DefaultStatementMapper implements StatementMapper {
this.renderContext = renderContext;
this.updateMapper = updateMapper;
this.mappingContext = mappingContext;
if(mappingContext instanceof RelationalMappingContext relationalMappingContext){
forceQuote = relationalMappingContext.isForceQuote();
}
}

@Override
Expand All @@ -90,7 +98,8 @@ public PreparedOperation<?> getMappedObject(SelectSpec selectSpec) {
private PreparedOperation<Select> getMappedObject(SelectSpec selectSpec,
@Nullable RelationalPersistentEntity<?> entity) {

Table table = selectSpec.getTable();
String tableName = selectSpec.getTable().getName().getReference();
Table table = getTable(tableName);
SelectBuilder.SelectAndFrom selectAndFrom = StatementBuilder.select(getSelectList(selectSpec, entity));

if (selectSpec.isDistinct()) {
Expand Down Expand Up @@ -158,7 +167,7 @@ private PreparedOperation<Insert> getMappedObject(InsertSpec insertSpec,
@Nullable RelationalPersistentEntity<?> entity) {

BindMarkers bindMarkers = this.dialect.getBindMarkersFactory().create();
Table table = Table.create(toSql(insertSpec.getTable()));
Table table = getTable(insertSpec.getTable().getReference());

BoundAssignments boundAssignments = this.updateMapper.getMappedObject(bindMarkers, insertSpec.getAssignments(),
table, entity);
Expand Down Expand Up @@ -191,7 +200,7 @@ private PreparedOperation<Update> getMappedObject(UpdateSpec updateSpec,
@Nullable RelationalPersistentEntity<?> entity) {

BindMarkers bindMarkers = this.dialect.getBindMarkersFactory().create();
Table table = Table.create(toSql(updateSpec.getTable()));
Table table = getTable(updateSpec.getTable().getReference());

if (updateSpec.getUpdate() == null || updateSpec.getUpdate().getAssignments().isEmpty()) {
throw new IllegalArgumentException("UPDATE contains no assignments");
Expand All @@ -210,7 +219,6 @@ private PreparedOperation<Update> getMappedObject(UpdateSpec updateSpec,

CriteriaDefinition criteria = updateSpec.getCriteria();
if (criteria != null && !criteria.isEmpty()) {

BoundCondition boundCondition = this.updateMapper.getMappedObject(bindMarkers, criteria, table, entity);

bindings = bindings.and(boundCondition.getBindings());
Expand All @@ -222,6 +230,10 @@ private PreparedOperation<Update> getMappedObject(UpdateSpec updateSpec,
return new DefaultPreparedOperation<>(update, this.renderContext, bindings);
}

private Table getTable(String tableName) {
return forceQuote ? Table.create(SqlIdentifier.quoted(tableName)) : Table.create(SqlIdentifier.unquoted(tableName));
}

@Override
public PreparedOperation<Delete> getMappedObject(DeleteSpec deleteSpec) {
return getMappedObject(deleteSpec, null);
Expand All @@ -236,7 +248,7 @@ private PreparedOperation<Delete> getMappedObject(DeleteSpec deleteSpec,
@Nullable RelationalPersistentEntity<?> entity) {

BindMarkers bindMarkers = this.dialect.getBindMarkersFactory().create();
Table table = Table.create(toSql(deleteSpec.getTable()));
Table table = getTable(deleteSpec.getTable().getReference());

DeleteBuilder.DeleteWhere deleteBuilder = StatementBuilder.delete(table);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
import io.r2dbc.spi.Statement;
import org.springframework.data.r2dbc.mapping.R2dbcMappingContext;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -354,6 +356,11 @@ private <T> RowsFetchSpec<T> doSelect(Query query, Class<?> entityType, SqlIdent
Class<T> returnType, Function<? super Statement, ? extends Statement> filterFunction) {

StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityType);
boolean forceQuote = false;
if(this.mappingContext instanceof RelationalMappingContext relationalMappingContext){
forceQuote = relationalMappingContext.isForceQuote();
}
tableName = forceQuote ? SqlIdentifier.quoted(tableName.getReference()): SqlIdentifier.unquoted(tableName.getReference());

StatementMapper.SelectSpec selectSpec = statementMapper //
.createSelect(tableName) //
Expand Down Expand Up @@ -548,9 +555,16 @@ private <T> Mono<T> doInsert(T entity, SqlIdentifier tableName, OutboundRow outb
StatementMapper.InsertSpec insert = mapper.createInsert(tableName);

for (SqlIdentifier column : outboundRow.keySet()) {

Parameter settableValue = outboundRow.get(column);
if (settableValue.hasValue()) {
insert = insert.withColumn(column, settableValue);
boolean forceQuote = false;
if(this.mappingContext instanceof R2dbcMappingContext r2dbcMappingContext){
forceQuote = r2dbcMappingContext.isForceQuote();
}


insert = insert.withColumn(forceQuote? SqlIdentifier.quoted(column.getReference()): column, settableValue);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.jetbrains.annotations.TestOnly;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not using this annotation.

import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.r2dbc.convert.R2dbcConverter;
Expand Down Expand Up @@ -128,6 +129,7 @@ interface TypedStatementMapper<T> extends StatementMapper {}
* @param table
* @return the {@link SelectSpec}.
*/
@TestOnly
default SelectSpec createSelect(String table) {
return SelectSpec.create(table);
}
Expand Down Expand Up @@ -250,6 +252,7 @@ protected SelectSpec(Table table, List<String> projectedFields, List<Expression>
* @param table
* @return the {@link SelectSpec}.
*/
@TestOnly
public static SelectSpec create(String table) {
return create(SqlIdentifier.unquoted(table));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ public class R2dbcMappingContext extends RelationalMappingContext {
* Create a new {@link R2dbcMappingContext}.
*/
public R2dbcMappingContext() {
setForceQuote(false);
}

/**
Expand All @@ -42,7 +41,6 @@ public R2dbcMappingContext() {
*/
public R2dbcMappingContext(NamingStrategy namingStrategy) {
super(namingStrategy);
setForceQuote(false);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Map;
import java.util.regex.Pattern;

import org.jetbrains.annotations.NotNull;
import org.springframework.data.domain.Sort;
import org.springframework.data.mapping.MappingException;
import org.springframework.data.mapping.PersistentPropertyPath;
Expand All @@ -31,6 +32,7 @@
import org.springframework.data.r2dbc.convert.R2dbcConverter;
import org.springframework.data.r2dbc.dialect.R2dbcDialect;
import org.springframework.data.relational.core.dialect.Escaper;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
import org.springframework.data.relational.core.query.CriteriaDefinition;
Expand Down Expand Up @@ -64,6 +66,7 @@ public class QueryMapper {
private final R2dbcConverter converter;
private final R2dbcDialect dialect;
private final MappingContext<? extends RelationalPersistentEntity<?>, RelationalPersistentProperty> mappingContext;
private final boolean forceQuote;

/**
* Creates a new {@link QueryMapper} with the given {@link R2dbcConverter}.
Expand All @@ -80,6 +83,11 @@ public QueryMapper(R2dbcDialect dialect, R2dbcConverter converter) {
this.converter = converter;
this.dialect = dialect;
this.mappingContext = (MappingContext) converter.getMappingContext();
if(mappingContext instanceof RelationalMappingContext relationalMappingContext){
forceQuote = relationalMappingContext.isForceQuote();
}else {
forceQuote= false;
}
}

/**
Expand Down Expand Up @@ -107,7 +115,7 @@ public String toSql(SqlIdentifier identifier) {
public List<OrderByField> getMappedSort(Table table, Sort sort, @Nullable RelationalPersistentEntity<?> entity) {

List<OrderByField> mappedOrder = new ArrayList<>();

table = getTable(table);
for (Sort.Order order : sort) {

SqlSort.validate(order);
Expand All @@ -120,13 +128,23 @@ public List<OrderByField> getMappedSort(Table table, Sort sort, @Nullable Relati
return mappedOrder;
}

Table getTable(Table table) {
String tableName = table.getName().getReference();
table = Table.create(forceQuote? SqlIdentifier.quoted(tableName):SqlIdentifier.unquoted(tableName));
return table;
}

private OrderByField createSimpleOrderByField(Table table, RelationalPersistentEntity<?> entity, Sort.Order order) {

if (order instanceof SqlSort.SqlOrder sqlOrder && sqlOrder.isUnsafe()) {
return OrderByField.from(Expressions.just(sqlOrder.getProperty()));
}
boolean forceQuote = false;
if(this.mappingContext instanceof RelationalMappingContext relationalMappingContext){
forceQuote = relationalMappingContext.isForceQuote();
}

Field field = createPropertyField(entity, SqlIdentifier.unquoted(order.getProperty()), this.mappingContext);
Field field = createPropertyField(entity, forceQuote ? SqlIdentifier.quoted(order.getProperty()) : SqlIdentifier.unquoted(order.getProperty()), this.mappingContext);
return OrderByField.from(table.column(field.getMappedColumnName()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public BoundAssignments getMappedObject(BindMarkers markers, Map<SqlIdentifier,
List<Assignment> result = new ArrayList<>();

assignments.forEach((column, value) -> {
Assignment assignment = getAssignment(column, value, bindings, table, entity);
Assignment assignment = getAssignment(column, value, bindings, getTable(table), entity);
result.add(assignment);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.data.domain.Sort;
import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy;
import org.springframework.data.r2dbc.core.StatementMapper;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
import org.springframework.data.relational.core.query.Criteria;
Expand Down Expand Up @@ -163,9 +164,18 @@ private Expression[] getSelectProjection() {
for (String projectedProperty : projectedProperties) {

RelationalPersistentProperty property = entity.getPersistentProperty(projectedProperty);
Column column = table.column(property != null //
? property.getColumnName() //
: SqlIdentifier.unquoted(projectedProperty));
Column column;
if (property != null) {
column = table.column(property.getColumnName());
} else {
boolean forceQuote = false;
if (this.dataAccessStrategy.getConverter().getMappingContext() instanceof RelationalMappingContext relationalMappingContext) {
forceQuote = relationalMappingContext.isForceQuote();
}
column = table.column(forceQuote
? SqlIdentifier.quoted(projectedProperty)
: SqlIdentifier.unquoted(projectedProperty));
}
expressions.add(column);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ void before() {
R2dbcCustomConversions conversions = R2dbcCustomConversions.of(PostgresDialect.INSTANCE, new MoneyConverter(),
new RowConverter(), new RowDocumentConverter(), new PkConverter());

R2dbcMappingContext context = new R2dbcMappingContext();
context.setForceQuote(false);
entityTemplate = new R2dbcEntityTemplate(client, PostgresDialect.INSTANCE,
new MappingR2dbcConverter(new R2dbcMappingContext(), conversions));
new MappingR2dbcConverter(context, conversions));
}

@Test // GH-220
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.junit.jupiter.api.Test;
import org.springframework.data.annotation.ReadOnlyProperty;
import org.springframework.data.r2dbc.dialect.R2dbcDialect;
import org.springframework.data.r2dbc.mapping.OutboundRow;
import org.springframework.data.relational.core.sql.SqlIdentifier;
import org.springframework.r2dbc.core.Parameter;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void shouldDelete() {

StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("DELETE"));

assertThat(statement.getSql()).isEqualTo("DELETE FROM person");
assertThat(statement.getSql()).isEqualTo("DELETE FROM \"person\"");
}

@Test // gh-410
Expand Down Expand Up @@ -104,7 +104,7 @@ void shouldDeleteWithQuery() {

StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("DELETE"));

assertThat(statement.getSql()).isEqualTo("DELETE FROM person WHERE person.THE_NAME = $1");
assertThat(statement.getSql()).isEqualTo("DELETE FROM \"person\" WHERE \"person\".\"THE_NAME\" = $1");
assertThat(statement.getBindings()).hasSize(1).containsEntry(0, Parameter.from("Walter"));
}

Expand All @@ -125,7 +125,7 @@ void shouldDeleteInTable() {

StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("DELETE"));

assertThat(statement.getSql()).isEqualTo("DELETE FROM other_table WHERE other_table.THE_NAME = $1");
assertThat(statement.getSql()).isEqualTo("DELETE FROM other_table WHERE other_table.\"THE_NAME\" = $1");
}

static class Person {
Expand Down
Loading