Skip to content

Commit

Permalink
Add locking in StandardServletAsyncWebRequest
Browse files Browse the repository at this point in the history
The lock protects against race between onError/onComplete notifications
and operations on the ServletOutputStream.

See spring-projectsgh-32342
  • Loading branch information
rstoyanchev committed Mar 1, 2024
1 parent 3b7c435 commit 2aca714
Showing 1 changed file with 58 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;

import javax.servlet.AsyncContext;
Expand All @@ -34,6 +34,7 @@
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.util.DisconnectedClientHelper;

/**
* A Servlet implementation of {@link AsyncWebRequest}.
Expand All @@ -60,9 +61,9 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Nullable
private AsyncContext asyncContext;

private final AtomicReference<State> state;
private State state;

private volatile boolean hasError;
private final ReentrantLock stateLock = new ReentrantLock();


/**
Expand All @@ -87,13 +88,7 @@ public StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletRes

super(request, new LifecycleHttpServletResponse(response));

if (previousRequest != null) {
this.state = previousRequest.state;
this.hasError = previousRequest.hasError;
}
else {
this.state = new AtomicReference<>(State.ACTIVE);
}
this.state = (previousRequest != null ? previousRequest.state : State.ACTIVE);

//noinspection DataFlowIssue
((LifecycleHttpServletResponse) getResponse()).setAsyncWebRequest(this);
Expand Down Expand Up @@ -137,7 +132,7 @@ public boolean isAsyncStarted() {
*/
@Override
public boolean isAsyncComplete() {
return (this.state.get() == State.COMPLETED);
return (this.state == State.COMPLETED);
}

@Override
Expand Down Expand Up @@ -184,20 +179,41 @@ public void onTimeout(AsyncEvent event) throws IOException {

@Override
public void onError(AsyncEvent event) throws IOException {
transitionToErrorState();
this.exceptionHandlers.forEach(consumer -> consumer.accept(event.getThrowable()));
this.stateLock.lock();
try {
transitionToErrorState();
Throwable ex = event.getThrowable();
this.exceptionHandlers.forEach(consumer -> consumer.accept(ex));

// We skip ASYNC dispatches for "disconnected client" errors,
// but can only complete from a Servlet container thread

if (DisconnectedClientHelper.isClientDisconnectedException(ex) && this.state != State.COMPLETED) {
this.asyncContext.complete();
}
}
finally {
this.stateLock.unlock();
}
}

private void transitionToErrorState() {
this.hasError = true;
this.state.compareAndSet(State.ACTIVE, State.ERROR);
if (this.state == State.ACTIVE) {
this.state = State.ERROR;
}
}

@Override
public void onComplete(AsyncEvent event) throws IOException {
this.completionHandlers.forEach(Runnable::run);
this.asyncContext = null;
this.state.set(State.COMPLETED);
this.stateLock.lock();
try {
this.completionHandlers.forEach(Runnable::run);
this.asyncContext = null;
this.state = State.COMPLETED;
}
finally {
this.stateLock.unlock();
}
}


Expand Down Expand Up @@ -256,59 +272,76 @@ public boolean isReady() {

@Override
public void setWriteListener(WriteListener writeListener) {
throw new UnsupportedOperationException();
}

@Override
public void write(int b) throws IOException {
checkState();
obtainLockAndCheckState();
try {
this.delegate.getOutputStream().write(b);
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to write");
}
finally {
releaseLock();
}
}

public void write(byte[] buf, int offset, int len) throws IOException {
checkState();
obtainLockAndCheckState();
try {
this.delegate.getOutputStream().write(buf, offset, len);
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to write");
}
finally {
releaseLock();
}
}

@Override
public void flush() throws IOException {
checkState();
obtainLockAndCheckState();
try {
this.delegate.getOutputStream().flush();
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to flush");
}
finally {
releaseLock();
}
}

@Override
public void close() throws IOException {
checkState();
obtainLockAndCheckState();
try {
this.delegate.getOutputStream().close();
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to close");
}
finally {
releaseLock();
}
}

private void checkState() throws AsyncRequestNotUsableException {
if (this.asyncWebRequest.state.get() != State.ACTIVE) {
private void obtainLockAndCheckState() throws AsyncRequestNotUsableException {
if (!this.asyncWebRequest.stateLock.tryLock() || this.asyncWebRequest.state != State.ACTIVE) {
throw new AsyncRequestNotUsableException("Response not usable after " +
(this.asyncWebRequest.state.get() == State.COMPLETED ?
(this.asyncWebRequest.state == State.COMPLETED ?
"async request completion" : "onError notification") + ".");
}
}

private void releaseLock() {
this.asyncWebRequest.stateLock.unlock();
}

private void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException {
this.asyncWebRequest.transitionToErrorState();
throw new AsyncRequestNotUsableException(msg, ex);
Expand Down

0 comments on commit 2aca714

Please sign in to comment.