Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a race condition in the AudioLoader #125

Merged
merged 2 commits into from
Jul 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,70 +31,61 @@
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicBoolean;

public class AudioLoader implements AudioLoadResultHandler {

private static final Logger log = LoggerFactory.getLogger(AudioLoader.class);
private final AudioPlayerManager audioPlayerManager;

private List<AudioTrack> loadedItems;
private boolean used = false;
private final CompletableFuture<List<AudioTrack>> loadedItems;
private final AtomicBoolean used = new AtomicBoolean(false);

public AudioLoader(AudioPlayerManager audioPlayerManager) {
this.audioPlayerManager = audioPlayerManager;
this.loadedItems = new CompletableFuture<>();
}

List<AudioTrack> loadSync(String identifier) throws InterruptedException {
if(used)
public CompletionStage<List<AudioTrack>> load(String identifier) {
boolean isUsed = this.used.getAndSet(true);
if (isUsed) {
throw new IllegalStateException("This loader can only be used once per instance");

used = true;

audioPlayerManager.loadItem(identifier, this);

synchronized (this) {
this.wait();
}

log.trace("Loading item with identifier {}", identifier);
this.audioPlayerManager.loadItem(identifier, this);

return loadedItems;
}

@Override
public void trackLoaded(AudioTrack audioTrack) {
loadedItems = new ArrayList<>();
loadedItems.add(audioTrack);
log.info("Loaded track " + audioTrack.getInfo().title);
synchronized (this) {
this.notify();
}
ArrayList<AudioTrack> result = new ArrayList<>();
result.add(audioTrack);
this.loadedItems.complete(result);
}

@Override
public void playlistLoaded(AudioPlaylist audioPlaylist) {
log.info("Loaded playlist " + audioPlaylist.getName());
loadedItems = audioPlaylist.getTracks();
synchronized (this) {
this.notify();
}
this.loadedItems.complete(audioPlaylist.getTracks());
}

@Override
public void noMatches() {
log.info("No matches found");
loadedItems = new ArrayList<>();
synchronized (this) {
this.notify();
}
this.loadedItems.complete(Collections.emptyList());
}

@Override
public void loadFailed(FriendlyException e) {
log.error("Load failed", e);
loadedItems = new ArrayList<>();
synchronized (this) {
this.notify();
}
this.loadedItems.complete(Collections.emptyList());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,24 @@
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Controller;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;

@Controller
@RestController
public class AudioLoaderRestHandler {

private static final Logger log = LoggerFactory.getLogger(AudioLoaderRestHandler.class);
Expand All @@ -60,18 +65,18 @@ private void log(HttpServletRequest request) {
log.info("GET " + path);
}

private boolean isAuthorized(HttpServletRequest request, HttpServletResponse response) {
//returns an empty answer if the auth succeeded, or a response to send back immediately
private <T> Optional<ResponseEntity<T>> checkAuthorization(HttpServletRequest request) {
if (request.getHeader("Authorization") == null) {
response.setStatus(403);
return false;
return Optional.of(new ResponseEntity<>(HttpStatus.UNAUTHORIZED));
}

if (!request.getHeader("Authorization").equals(serverConfig.getPassword())) {
log.warn("Authorization failed");
response.setStatus(403);
return false;
return Optional.of(new ResponseEntity<>(HttpStatus.FORBIDDEN));
}
return true;

return Optional.empty();
}

private JSONObject trackToJSON(AudioTrack audioTrack) {
Expand All @@ -88,19 +93,9 @@ private JSONObject trackToJSON(AudioTrack audioTrack) {
.put("position", audioTrack.getPosition());
}

@GetMapping(value = "/loadtracks", produces = "application/json")
@ResponseBody
public String getLoadTracks(HttpServletRequest request, HttpServletResponse response, @RequestParam String identifier)
throws IOException, InterruptedException {
log(request);

if (!isAuthorized(request, response))
return "";

private JSONArray encodeTrackList(List<AudioTrack> trackList) {
JSONArray tracks = new JSONArray();
List<AudioTrack> list = new AudioLoader(audioPlayerManager).loadSync(identifier);

list.forEach(track -> {
trackList.forEach(track -> {
JSONObject object = new JSONObject();
object.put("info", trackToJSON(track));

Expand All @@ -109,33 +104,53 @@ public String getLoadTracks(HttpServletRequest request, HttpServletResponse resp
object.put("track", encoded);
tracks.put(object);
} catch (IOException e) {
throw new RuntimeException();
log.warn("Failed to encode a track {}, skipping", track.getIdentifier(), e);
}
});
return tracks;
}

return tracks.toString();
@GetMapping(value = "/loadtracks", produces = "application/json")
@ResponseBody
public CompletionStage<ResponseEntity<String>> getLoadTracks(HttpServletRequest request, @RequestParam String identifier) {
log(request);

Optional<ResponseEntity<String>> notAuthed = checkAuthorization(request);
if (notAuthed.isPresent()) {
return CompletableFuture.completedFuture(notAuthed.get());
}

return new AudioLoader(audioPlayerManager).load(identifier)
.thenApply(this::encodeTrackList)
.thenApply(tracksArray -> new ResponseEntity<>(tracksArray.toString(), HttpStatus.OK));
}

@GetMapping(value = "/decodetrack", produces = "application/json")
@ResponseBody
public String getDecodeTrack(HttpServletRequest request, HttpServletResponse response, @RequestParam String track) throws IOException {
public ResponseEntity<String> getDecodeTrack(HttpServletRequest request, HttpServletResponse response, @RequestParam String track)
throws IOException {
log(request);

if (!isAuthorized(request, response))
return "";
Optional<ResponseEntity<String>> notAuthed = checkAuthorization(request);
if (notAuthed.isPresent()) {
return notAuthed.get();
}

AudioTrack audioTrack = Util.toAudioTrack(audioPlayerManager, track);

return trackToJSON(audioTrack).toString();
return new ResponseEntity<>(trackToJSON(audioTrack).toString(), HttpStatus.OK);
}

@PostMapping(value = "/decodetracks", consumes = "application/json", produces = "application/json")
@ResponseBody
public String postDecodeTracks(HttpServletRequest request, HttpServletResponse response, @RequestBody String body) throws IOException {
public ResponseEntity<String> postDecodeTracks(HttpServletRequest request, HttpServletResponse response, @RequestBody String body)
throws IOException {
log(request);

if (!isAuthorized(request, response))
return "";
Optional<ResponseEntity<String>> notAuthed = checkAuthorization(request);
if (notAuthed.isPresent()) {
return notAuthed.get();
}

JSONArray requestJSON = new JSONArray(body);
JSONArray responseJSON = new JSONArray();
Expand All @@ -152,6 +167,6 @@ public String postDecodeTracks(HttpServletRequest request, HttpServletResponse r
responseJSON.put(trackJSON);
}

return responseJSON.toString();
return new ResponseEntity<>(responseJSON.toString(), HttpStatus.OK);
}
}