Skip to content

Commit

Permalink
Merge pull request #1204 from ironmansoftware/session
Browse files Browse the repository at this point in the history
Fix issue with sessions.
  • Loading branch information
adamdriscoll authored Sep 30, 2019
2 parents 55d014a + 41bc132 commit 8694913
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 73 deletions.
23 changes: 9 additions & 14 deletions src/UniversalDashboard/Controllers/ComponentController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
using UniversalDashboard.Interfaces;
using UniversalDashboard.Models.Basics;
using System.Security;
using StackExchange.Profiling;
using Microsoft.Extensions.Primitives;

namespace UniversalDashboard.Controllers
{
[Route("api/internal/component")]
public class ComponentController : Controller
{
private static readonly Logger Log = LogManager.GetLogger(nameof(ComponentController));

private readonly IExecutionService _executionService;
private readonly IDashboardService _dashboardService;
private readonly AutoReloader _autoReloader;
Expand Down Expand Up @@ -74,23 +73,19 @@ private async Task<IActionResult> RunScript(Endpoint endpoint, Dictionary<string
ExecutionContext executionContext = new ExecutionContext(endpoint, variables, parameters, HttpContext?.User);
executionContext.NoSerialization = noSerialization;

if (HttpContext.Session.TryGetValue("SessionId", out byte[] sessionIdBytes))
if (HttpContext.Request.Headers.TryGetValue("UDConnectionId", out StringValues connectionId))
{
var sessionId = new Guid(sessionIdBytes);
executionContext.SessionId = sessionId.ToString();
executionContext.ConnectionId = _memoryCache.Get(executionContext.SessionId) as string;
executionContext.SessionId = _memoryCache.Get(connectionId) as string;
executionContext.ConnectionId = connectionId;
}

using (MiniProfiler.Current.Step($"Execute: {endpoint.Name}"))
return await Task.Run(() =>
{
return await Task.Run(() =>
{
var result = _executionService.ExecuteEndpoint(executionContext, endpoint);
var actionResult = ConvertToActionResult(result);
var result = _executionService.ExecuteEndpoint(executionContext, endpoint);
var actionResult = ConvertToActionResult(result);

return actionResult;
});
}
return actionResult;
});
}
catch (Exception ex) {
Log.Warn("RunScript() " + ex.Message + Environment.NewLine + ex.StackTrace);
Expand Down
88 changes: 56 additions & 32 deletions src/UniversalDashboard/Execution/EndpointProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ namespace UniversalDashboard.Execution
{
public class EndpointService : IEndpointService
{
private readonly MemoryCache _endpointCache;
private readonly List<Endpoint> _restEndpoints;
private readonly List<Endpoint> _scheduledEndpoints;
private static readonly Logger logger = LogManager.GetLogger("EndpointService");
Expand All @@ -22,6 +21,9 @@ public class EndpointService : IEndpointService

private static IEndpointService _instance;

public Dictionary<string, Endpoint> Endpoints { get; private set; }
public Dictionary<string, SessionState> Sessions { get; private set; }

public static IEndpointService Instance
{
get
Expand All @@ -37,29 +39,54 @@ public static IEndpointService Instance

private EndpointService()
{
_endpointCache = new MemoryCache(new MemoryCacheOptions());
Endpoints = new Dictionary<string, Endpoint>();
Sessions = new Dictionary<string, SessionState>();

_restEndpoints = new List<Endpoint>();
_scheduledEndpoints = new List<Endpoint>();
}

public MemoryCache EndpointCache => _endpointCache;

public void StartSession(string sessionId)
public void StartSession(string sessionId, string connectionId)
{
lock(sessionLock)
{
if (_sessionLocks.ContainsKey(sessionId)) return;
_endpointCache.Set(Constants.SessionState + sessionId, new SessionState());
_sessionLocks.Add(sessionId, new object());
if (_sessionLocks.ContainsKey(sessionId))
{
lock(_sessionLocks[sessionId])
{
var session = Sessions[sessionId];
session.ConnectionIds.Add(connectionId);
}
}
else
{
Sessions.Add(sessionId, new SessionState {
ConnectionIds = new List<string> {
connectionId
}
});
_sessionLocks.Add(sessionId, new object());
}
}
}

public void EndSession(string sessionId)
public void EndSession(string sessionId, string connectionId)
{
lock(sessionLock)
{
_endpointCache.Remove(Constants.SessionState + sessionId);
_sessionLocks.Remove(sessionId);
var session = Sessions[sessionId];
if (session.ConnectionIds.Count <= 1)
{
Sessions.Remove(sessionId);
_sessionLocks.Remove(sessionId);
}
else
{
lock(_sessionLocks[sessionId])
{
session.ConnectionIds.Remove(connectionId);
}
}
}
}

Expand All @@ -82,24 +109,24 @@ public void Register(Endpoint callback)

if (callback.SessionId == null)
{
_endpointCache.Set(callback.Name, callback);
Endpoints.Add(callback.Name, callback);
}
else
{
lock(sessionLock)
{
if (!_sessionLocks.ContainsKey(callback.SessionId))
{
StartSession(callback.SessionId);
StartSession(callback.SessionId, string.Empty);
}
}

lock(_sessionLocks[callback.SessionId])
{
if (_endpointCache.TryGetValue(Constants.SessionState + callback.SessionId, out SessionState sessionState))
if (Sessions.ContainsKey(callback.SessionId))
{
sessionState.Endpoints.Add(callback);
_endpointCache.Set(Constants.SessionState + callback.SessionId, sessionState);
var session = Sessions[callback.SessionId];
session.Endpoints.Add(callback.Name, callback);
}
}
}
Expand Down Expand Up @@ -130,24 +157,23 @@ public void Unregister(string name, string sessionId)

if (sessionId == null)
{
if (_endpointCache.TryGetValue(name, out object result))
if (Endpoints.ContainsKey(name))
{
Endpoints.Remove(name);
logger.Debug("Endpoint found. Removing endpoint.");
_endpointCache.Remove(name);
}
}
else
{
if (_endpointCache.TryGetValue(Constants.SessionState + sessionId, out SessionState sessionState))
if (Sessions.ContainsKey(sessionId))
{
var endpoint = sessionState.Endpoints.FirstOrDefault(m => m.Name?.Equals(name, StringComparison.OrdinalIgnoreCase) == true);
if (endpoint != null)
var session = Sessions[sessionId];
if (session.Endpoints.ContainsKey(name))
{
logger.Debug("Session endpoint found. Removing endpoint.");
lock(sessionState.SyncRoot)
lock(session.SyncRoot)
{
sessionState.Endpoints.Remove(endpoint);
_endpointCache.Set(Constants.SessionState + sessionId, sessionState);
session.Endpoints.Remove(name);
}
}
}
Expand All @@ -157,25 +183,23 @@ public void Unregister(string name, string sessionId)
public Endpoint Get(string name, string sessionId)
{
logger.Debug($"Get() {name} {sessionId}");

Endpoint callback;
if (sessionId != null)
{
if (_endpointCache.TryGetValue(Constants.SessionState + sessionId, out SessionState sessionState))
if (Sessions.ContainsKey(sessionId))
{
var endpoint = sessionState.Endpoints.FirstOrDefault(m => m.Name?.Equals(name, StringComparison.OrdinalIgnoreCase) == true);
if (endpoint != null)
var session = Sessions[sessionId];
if (session.Endpoints.ContainsKey(name))
{
logger.Debug("Found session endpoint.");
return endpoint;
return session.Endpoints[name];
}
}
}

if (_endpointCache.TryGetValue(name, out callback))
if (Endpoints.ContainsKey(name))
{
logger.Debug("Found endpoint.");
return callback;
return Endpoints[name];
}

return null;
Expand Down
6 changes: 4 additions & 2 deletions src/UniversalDashboard/Interfaces/IEndpointService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ public interface IEndpointService
Endpoint GetByUrl(string url, string method, Dictionary<string, object> matchedVariables);
IEnumerable<Endpoint> GetScheduledEndpoints();
void Register(Endpoint callback);
void StartSession(string sessionId);
void EndSession(string sessionId);
void StartSession(string sessionId, string connectionId);
void EndSession(string sessionId, string connectionId);
Dictionary<string, Endpoint> Endpoints { get; }
Dictionary<string, SessionState> Sessions { get; }
}
}
7 changes: 4 additions & 3 deletions src/UniversalDashboard/Models/SessionState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ public class SessionState
{
public SessionState()
{
Endpoints = new List<Endpoint>();
Endpoints = new Dictionary<string, Endpoint>();
ConnectionIds = new List<string>();
SyncRoot = new object();
}

public object SyncRoot { get; set; }

public List<Endpoint> Endpoints { get; set; }
public List<string> ConnectionIds { get; set; }
public Dictionary<string, Endpoint> Endpoints { get; set; }
}
}
9 changes: 4 additions & 5 deletions src/UniversalDashboard/Server/DashboardHub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public override async Task OnDisconnectedAsync(Exception exception)
if (sessionId != null)
{
_memoryCache.Remove(sessionId);
_dashboardService.EndpointService.EndSession(sessionId as string);
_dashboardService.EndpointService.EndSession(sessionId as string, Context.ConnectionId);
}

_memoryCache.Remove(Context.ConnectionId);
Expand All @@ -154,11 +154,10 @@ public async Task SetSessionId(string sessionId)
{
Log.Debug($"SetSessionId({sessionId})");

await Task.FromResult(0);

_memoryCache.Set(Context.ConnectionId, sessionId);
_memoryCache.Set(sessionId, Context.ConnectionId);
_dashboardService.EndpointService.StartSession(sessionId);
_dashboardService.EndpointService.StartSession(sessionId, Context.ConnectionId);

await Clients.All.SendAsync("setConnectionId", Context.ConnectionId);
}

public Task Reload()
Expand Down
13 changes: 5 additions & 8 deletions src/UniversalDashboard/Server/ServerStartup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ public void ConfigureServices(IServiceCollection services)
options.Cookie.SecurePolicy = CookieSecurePolicy.SameAsRequest;
});

services.AddMiniProfiler();

var serviceDescriptor = services.FirstOrDefault(descriptor => descriptor.ServiceType.Name == "IRegistryPolicyResolver");
services.Remove(serviceDescriptor);
}
Expand Down Expand Up @@ -123,11 +121,11 @@ public void Configure(IApplicationBuilder app, Microsoft.AspNetCore.Hosting.IHos
ContentTypeProvider = provider
});

var dashboardService = app.ApplicationServices.GetService(typeof(IDashboardService)) as IDashboardService;
var dashboardService = app.ApplicationServices.GetService(typeof(IDashboardService)) as IDashboardService;

if (dashboardService?.DashboardOptions?.Certificate != null || dashboardService?.DashboardOptions?.CertificateFile != null) {
app.UseHttpsRedirection();
}
if (dashboardService?.DashboardOptions?.Certificate != null || dashboardService?.DashboardOptions?.CertificateFile != null) {
app.UseHttpsRedirection();
}

if (dashboardService?.DashboardOptions?.PublishedFolders != null) {
foreach(var publishedFolder in dashboardService.DashboardOptions.PublishedFolders) {
Expand All @@ -147,12 +145,11 @@ public void Configure(IApplicationBuilder app, Microsoft.AspNetCore.Hosting.IHos
{
routes.MapHub<DashboardHub>("/dashboardhub");
});

app.UseWebSockets();

app.UseSession();

app.UseMiniProfiler();

app.UseMvc();
}
}
Expand Down
1 change: 0 additions & 1 deletion src/UniversalDashboard/UniversalDashboard.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv" Version="2.1.3" />
<PackageReference Include="MiniProfiler.AspNetCore.Mvc" Version="4.1.0" />
<PackageReference Include="System.Reflection.Emit" Version="4.3.0" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="2.1.2" />
<PackageReference Include="Newtonsoft.Json" Version="11.0.2" />
Expand Down
20 changes: 14 additions & 6 deletions src/client/src/app/services/fetch-service.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ import {getApiPath} from 'config'

export const fetchGet = function(url, success, history) {
fetch(getApiPath() + url, {
credentials: 'include'
credentials: 'include',
headers: {
'UDConnectionId': UniversalDashboard.connectionId
}
})
.then(function(response){
UniversalDashboard.invokeMiddleware('GET', url, history, response);
Expand Down Expand Up @@ -36,7 +39,8 @@ export const fetchPost = function(url, data, success) {
method: 'post',
headers: {
'Accept': 'application/json, text/plain, */*',
'Content-Type': 'application/json'
'Content-Type': 'application/json',
'UDConnectionId': UniversalDashboard.connectionId
},
body: JSON.stringify(data),
credentials: 'include'
Expand All @@ -62,7 +66,8 @@ export const fetchPostFormData = function(url, data, success) {
fetch(getApiPath() + url, {
method: 'post',
headers: {
'Accept': 'application/json, text/plain, */*'//,
'Accept': 'application/json, text/plain, */*',
'UDConnectionId': UniversalDashboard.connectionId
//'Content-Type': 'multipart/form-data'
},
body: data,
Expand Down Expand Up @@ -90,7 +95,8 @@ export const fetchDelete = function(url, data, success) {
method: 'delete',
headers: {
'Accept': 'application/json, text/plain, */*',
'Content-Type': 'application/json'
'Content-Type': 'application/json',
'UDConnectionId': UniversalDashboard.connectionId
},
body: JSON.stringify(data),
credentials: 'include'
Expand All @@ -117,7 +123,8 @@ export const fetchPut = function(url, data, success) {
method: 'put',
headers: {
'Accept': 'application/json, text/plain, */*',
'Content-Type': 'application/json'
'Content-Type': 'application/json',
'UDConnectionId': UniversalDashboard.connectionId
},
body: JSON.stringify(data),
credentials: 'include'
Expand Down Expand Up @@ -145,7 +152,8 @@ export const fetchPostRaw = function(url, data, success) {
method: 'post',
headers: {
'Accept': 'application/json, text/plain, */*',
'Content-Type': 'text/plain'
'Content-Type': 'text/plain',
'UDConnectionId': UniversalDashboard.connectionId
},
body: data,
credentials: 'include'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export const UniversalDashboardService = {
unsubscribe: PubSub.unsubscribe,
publish: PubSub.publishSync,
toaster: toaster,
connectionId: '',
renderComponent: function(component, history, dynamicallyLoaded) {

if (component == null) return <React.Fragment/>;
Expand Down
Loading

0 comments on commit 8694913

Please sign in to comment.