Skip to content

Commit

Permalink
Use bytestring for sql table names everywhere
Browse files Browse the repository at this point in the history
This removes inefficient transforming table names from Text to bytestring and back all the time
  • Loading branch information
mpscholten committed Dec 3, 2020
1 parent fad9e2d commit a94d0a0
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 42 deletions.
2 changes: 1 addition & 1 deletion IHP/AutoRefresh.hs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ instance WSApp AutoRefreshWSApp where
AwaitingSessionID -> pure ()


registerNotificationTrigger :: (?modelContext :: ModelContext) => IORef (Set Text) -> IORef AutoRefreshServer -> IO ()
registerNotificationTrigger :: (?modelContext :: ModelContext) => IORef (Set ByteString) -> IORef AutoRefreshServer -> IO ()
registerNotificationTrigger touchedTablesVar autoRefreshServer = do
touchedTables <- Set.toList <$> readIORef touchedTablesVar
subscribedTables <- (get #subscribedTables) <$> (autoRefreshServer |> readIORef)
Expand Down
4 changes: 2 additions & 2 deletions IHP/AutoRefresh/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ data AutoRefreshSession = AutoRefreshSession
-- | MVar that is filled whenever some table changed
, event :: MVar ()
-- | All tables this auto refresh session watches
, tables :: Set Text
, tables :: Set ByteString
-- | The last rendered html of this action. Initially this is the result of the initial page rendering
, lastResponse :: LByteString
-- | Keep track of the last ping to this session to close it after too much time has passed without anything happening
, lastPing :: UTCTime
}

data AutoRefreshServer = AutoRefreshServer { sessions :: [AutoRefreshSession], subscribedTables :: Set Text }
data AutoRefreshServer = AutoRefreshServer { sessions :: [AutoRefreshSession], subscribedTables :: Set ByteString }

newAutoRefreshServer :: AutoRefreshServer
newAutoRefreshServer = AutoRefreshServer { sessions = [], subscribedTables = mempty }
18 changes: 17 additions & 1 deletion IHP/HaskellSupport.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ module IHP.HaskellSupport (
, includes
, stripTags
, symbolToText
, symbolToByteString
) where

import ClassyPrelude
Expand All @@ -44,7 +45,8 @@ import qualified Data.Attoparsec.ByteString.Char8 as Attoparsec
import Data.String.Conversions (cs)
import qualified Debug.Trace
import qualified Data.Text as Text
import qualified Data.Maybe
import qualified Data.Maybe
import qualified Data.ByteString.Char8 as ByteString

--(|>) :: a -> f -> f a
infixl 8 |>
Expand Down Expand Up @@ -81,9 +83,11 @@ includes = elem

instance Data.Default.Default UUID.UUID where
def = UUID.nil
{-# INLINE def #-}

instance forall name name'. (KnownSymbol name, name' ~ name) => IsLabel name (Proxy name') where
fromLabel = Proxy @name'
{-# INLINE fromLabel #-}

-- | Returns the field value for a field name
--
Expand Down Expand Up @@ -171,6 +175,7 @@ isToday' currentTime timestamp = utcTimeToYearMonthDay currentTime == utcTimeToY
-- | Allows `Just "someThing"` to be written as `"someThing"`
instance IsString string => IsString (Maybe string) where
fromString string = Just (fromString string)
{-# INLINE fromString #-}


-- | Example:
Expand Down Expand Up @@ -255,6 +260,17 @@ symbolToText :: forall symbol. (KnownSymbol symbol) => Text
symbolToText = Text.pack (symbolVal @symbol Proxy)
{-# INLINE symbolToText #-}

-- | Returns the value of a type level symbol as a bytestring
--
-- >>> symbolToByteString @"hello"
-- "hello"
--
-- >>> symbolToByteString @(GetTableName User)
-- "users"
symbolToByteString :: forall symbol. (KnownSymbol symbol) => ByteString
symbolToByteString = ByteString.pack (symbolVal @symbol Proxy)
{-# INLINE symbolToByteString #-}

instance IsString UUID.UUID where
fromString string = case UUID.fromString string of
Just uuid -> uuid
Expand Down
39 changes: 27 additions & 12 deletions IHP/ModelSupport.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ data ModelContext = ModelContext
-- | If True, prints out all SQL queries that are executed. Will be set to True by default in development mode (as configured in Config.hs) and False in production.
, queryDebuggingEnabled :: Bool
-- | A callback that is called whenever a specific table is accessed using a SELECT query
, trackTableReadCallback :: Maybe (Text -> IO ())
, trackTableReadCallback :: Maybe (ByteString -> IO ())
}

-- | Provides a mock ModelContext to be used when a database connection is not available
Expand Down Expand Up @@ -284,7 +284,7 @@ sqlQuery :: (?modelContext :: ModelContext, PG.ToRow q, PG.FromRow r, Show q) =>
sqlQuery theQuery theParameters = do
logQuery theQuery theParameters
withDatabaseConnection \connection -> PG.query connection theQuery theParameters
{-# INLINE sqlQuery #-}
{-# INLINABLE sqlQuery #-}


-- | Runs a sql statement (like a CREATE statement)
Expand All @@ -296,11 +296,11 @@ sqlExec :: (?modelContext :: ModelContext, PG.ToRow q, Show q) => Query -> q ->
sqlExec theQuery theParameters = do
logQuery theQuery theParameters
withDatabaseConnection \connection -> PG.execute connection theQuery theParameters
{-# INLINE sqlExec #-}
{-# INLINABLE sqlExec #-}

withDatabaseConnection :: (?modelContext :: ModelContext) => (Connection -> IO a) -> IO a
withDatabaseConnection block = let ModelContext { connectionPool } = ?modelContext in Pool.withResource connectionPool block
{-# INLINE withDatabaseConnection #-}
{-# INLINABLE withDatabaseConnection #-}

-- | Runs a raw sql query which results in a single scalar value such as an integer or string
--
Expand All @@ -315,7 +315,7 @@ sqlQueryScalar query parameters = do
pure case result of
[PG.Only result] -> result
_ -> error "sqlQueryScalar: Expected a scalar result value"
{-# INLINE sqlQueryScalar #-}
{-# INLINABLE sqlQueryScalar #-}

-- | Returns the table name of a given model.
--
Expand All @@ -328,11 +328,23 @@ tableName :: forall model. (KnownSymbol (GetTableName model)) => Text
tableName = symbolToText @(GetTableName model)
{-# INLINE tableName #-}

-- | Returns the table name of a given model as a bytestring.
--
-- __Example:__
--
-- >>> tableNameByteString @User
-- "users"
--
tableNameByteString :: forall model. (KnownSymbol (GetTableName model)) => ByteString
tableNameByteString = symbolToByteString @(GetTableName model)
{-# INLINE tableNameByteString #-}

logQuery :: (?modelContext :: ModelContext, Show query, Show parameters) => query -> parameters -> IO ()
logQuery query parameters = when queryDebuggingEnabled (putStrLn (tshow (query, parameters)))
where
ModelContext { queryDebuggingEnabled } = ?modelContext
-- Env.isProduction FrameworkConfig.environment
{-# INLINABLE logQuery #-}

-- | Runs a @DELETE@ query for a record.
--
Expand All @@ -349,7 +361,7 @@ deleteRecord model = do
logQuery theQuery theParameters
sqlExec (PG.Query . cs $! theQuery) theParameters
pure ()
{-# INLINE deleteRecord #-}
{-# INLINABLE deleteRecord #-}

-- | Runs a @DELETE@ query for a list of records.
--
Expand All @@ -365,7 +377,7 @@ deleteRecords records = do
else logQuery theQuery theParameters
sqlExec (PG.Query . cs $! theQuery) theParameters
pure ()
{-# INLINE deleteRecords #-}
{-# INLINABLE deleteRecords #-}

-- | Runs a @DELETE@ query to delete all rows in a table.
--
Expand All @@ -377,7 +389,7 @@ deleteAll = do
logQuery theQuery ()
sqlExec (PG.Query . cs $! theQuery) ()
pure ()
{-# INLINE deleteAll #-}
{-# INLINABLE deleteAll #-}

type family Include (name :: GHC.Types.Symbol) model

Expand Down Expand Up @@ -433,12 +445,15 @@ data MetaBag = MetaBag

instance Default MetaBag where
def = MetaBag { annotations = [], touchedFields = [] }
{-# INLINE def #-}

instance SetField "annotations" MetaBag [(Text, Text)] where
setField value meta = meta { annotations = value }
{-# INLINE setField #-}

instance SetField "touchedFields" MetaBag [Text] where
setField value meta = meta { touchedFields = value }
{-# INLINE setField #-}

-- | Returns 'True' if any fields of the record have unsaved changes
--
Expand Down Expand Up @@ -554,7 +569,7 @@ instance (ToJSON (PrimaryKey a)) => ToJSON (Id' a) where

-- | Thrown by 'fetchOne' when the query result is empty
data RecordNotFoundException
= RecordNotFoundException { queryAndParams :: (Text, [Action]) }
= RecordNotFoundException { queryAndParams :: (ByteString, [Action]) }
deriving (Show)

instance Exception RecordNotFoundException
Expand All @@ -573,11 +588,11 @@ instance ToField value => ToField [value] where
instance (FromField value, Typeable value) => FromField [value] where
fromField field value = PG.fromPGArray <$> (fromField field value)

trackTableRead :: (?modelContext :: ModelContext) => Text -> IO ()
trackTableRead :: (?modelContext :: ModelContext) => ByteString -> IO ()
trackTableRead tableName = case get #trackTableReadCallback ?modelContext of
Just callback -> callback tableName
Nothing -> pure ()
{-# INLINE trackTableRead #-}
{-# INLINABLE trackTableRead #-}

-- | Track all tables in SELECT queries executed within the given IO action.
--
Expand All @@ -592,7 +607,7 @@ trackTableRead tableName = case get #trackTableReadCallback ?modelContext of
-- > tables <- readIORef ?touchedTables
-- > -- tables = Set.fromList ["projects", "users"]
-- >
withTableReadTracker :: (?modelContext :: ModelContext) => ((?modelContext :: ModelContext, ?touchedTables :: IORef (Set Text)) => IO ()) -> IO ()
withTableReadTracker :: (?modelContext :: ModelContext) => ((?modelContext :: ModelContext, ?touchedTables :: IORef (Set ByteString)) => IO ()) -> IO ()
withTableReadTracker trackedSection = do
touchedTablesVar <- newIORef Set.empty
let trackTableReadCallback = Just \tableName -> modifyIORef touchedTablesVar (Set.insert tableName)
Expand Down
10 changes: 5 additions & 5 deletions IHP/PGNotify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ import IHP.ModelSupport
-- Now insert something into the @projects@ table. E.g. by running @make psql@ and then running @INSERT INTO projects (id, name) VALUES (DEFAULT, 'New project');@
-- You will see that @"Something changed in the projects table"@ is printed onto the screen.
--
watchInsertOrUpdateTable :: (?modelContext :: ModelContext) => Text -> IO () -> IO (Async ())
watchInsertOrUpdateTable :: (?modelContext :: ModelContext) => ByteString -> IO () -> IO (Async ())
watchInsertOrUpdateTable tableName onInsertOrUpdate = do
sqlExec (PG.Query $ cs $ createNotificationTrigger tableName) ()
sqlExec (PG.Query $ createNotificationTrigger tableName) ()

let listenStatement = "LISTEN " <> PG.Query (cs $ eventName tableName)
let listenStatement = "LISTEN " <> PG.Query (eventName tableName)
async do
forever do
notification <- withDatabaseConnection \databaseConnection -> do
Expand All @@ -47,7 +47,7 @@ watchInsertOrUpdateTable tableName onInsertOrUpdate = do
async onInsertOrUpdate

-- | Returns the sql code to set up a database trigger. Mainly used by 'watchInsertOrUpdateTable'.
createNotificationTrigger :: Text -> Text
createNotificationTrigger :: ByteString -> ByteString
createNotificationTrigger tableName = "CREATE OR REPLACE FUNCTION " <> functionName <> "() RETURNS TRIGGER AS $$"
<> "BEGIN\n"
<> " PERFORM pg_notify('" <> eventName tableName <> "', '');\n"
Expand All @@ -64,5 +64,5 @@ createNotificationTrigger tableName = "CREATE OR REPLACE FUNCTION " <> functionN
deleteTriggerName = "did_delete_" <> tableName

-- | Retuns the event name of the event that the pg notify trigger dispatches
eventName :: Text -> Text
eventName :: ByteString -> ByteString
eventName tableName = "did_change_" <> tableName
44 changes: 23 additions & 21 deletions IHP/QueryBuilder.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE BangPatterns, TypeFamilies, DataKinds, PolyKinds, TypeApplications, ScopedTypeVariables, TypeInType, ConstraintKinds, TypeOperators, GADTs, UndecidableInstances, StandaloneDeriving, FunctionalDependencies, FlexibleContexts, InstanceSigs, AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns, TypeFamilies, DataKinds, PolyKinds, TypeApplications, ScopedTypeVariables, TypeInType, ConstraintKinds, TypeOperators, GADTs, UndecidableInstances, StandaloneDeriving, FunctionalDependencies, FlexibleContexts, InstanceSigs, AllowAmbiguousTypes, DeriveAnyClass #-}
{-|
Module: IHP.QueryBuilder
Description: Tool to build simple sql queries
Expand Down Expand Up @@ -53,6 +53,8 @@ import GHC.OverloadedLabels
import IHP.ModelSupport
import qualified Data.ByteString.Builder as ByteStringBuilder
import IHP.HtmlSupport.ToHtml
import qualified Data.ByteString.Char8 as ByteString
import qualified Control.DeepSeq as DeepSeq

-- | Represent's a @SELECT * FROM ..@ query. It's the starting point to build a query.
-- Used togehter with the other functions to compose a sql query.
Expand Down Expand Up @@ -106,7 +108,7 @@ data QueryBuilder (table :: Symbol) where
OffsetQueryBuilder :: Int -> !(QueryBuilder table) -> QueryBuilder table
UnionQueryBuilder :: !(QueryBuilder table) -> !(QueryBuilder table) -> QueryBuilder table

data Condition = VarCondition !Text !Action | OrCondition !Condition !Condition | AndCondition !Condition !Condition deriving (Show)
data Condition = VarCondition !ByteString !Action | OrCondition !Condition !Condition | AndCondition !Condition !Condition deriving (Show)

deriving instance Show (QueryBuilder a)

Expand All @@ -119,31 +121,31 @@ instance Eq (IHP.QueryBuilder.QueryBuilder table) where a == b = True

data OrderByDirection = Asc | Desc deriving (Eq, Show)
data SQLQuery = SQLQuery {
selectFrom :: !Text,
selectFrom :: !ByteString,
whereCondition :: !(Maybe Condition),
orderByClause :: !([(Text, OrderByDirection)]),
limitClause :: !(Maybe Text),
offsetClause :: !(Maybe Text)
orderByClause :: !([(ByteString, OrderByDirection)]),
limitClause :: !(Maybe ByteString),
offsetClause :: !(Maybe ByteString)
}

{-# INLINE buildQuery #-}
buildQuery :: forall table. (KnownSymbol table) => QueryBuilder table -> SQLQuery
buildQuery !queryBuilder =
case queryBuilder of
NewQueryBuilder ->
let tableName = symbolToText @table
let tableName = symbolToByteString @table
in SQLQuery { selectFrom = cs tableName, whereCondition = Nothing, orderByClause = [], limitClause = Nothing, offsetClause = Nothing }
FilterByQueryBuilder (fieldProxy, operator, value) queryBuilder ->
let
query = buildQuery queryBuilder
condition = VarCondition ((fieldNameToColumnName . cs $ symbolVal fieldProxy) <> " " <> compileOperator fieldProxy operator <> " ?") value
condition = VarCondition ((cs $ fieldNameToColumnName . cs $ symbolVal fieldProxy) <> " " <> compileOperator fieldProxy operator <> " ?") value
in
query { whereCondition = Just $ case whereCondition query of Just c -> AndCondition c condition; Nothing -> condition }
OrderByQueryBuilder (fieldProxy, orderByDirection) queryBuilder ->
let query = buildQuery queryBuilder
in query { orderByClause = (orderByClause query) ++ [(fieldNameToColumnName . cs $ symbolVal fieldProxy, orderByDirection)] } -- although adding to the end of a list is bad form, these lists are very short
LimitQueryBuilder limit queryBuilder -> (buildQuery queryBuilder) { limitClause = Just ("LIMIT " <> tshow limit) }
OffsetQueryBuilder offset queryBuilder -> (buildQuery queryBuilder) { offsetClause = Just ("OFFSET " <> tshow offset) }
in query { orderByClause = (orderByClause query) ++ [(cs $ fieldNameToColumnName . cs $ symbolVal fieldProxy, orderByDirection)] } -- although adding to the end of a list is bad form, these lists are very short
LimitQueryBuilder limit queryBuilder -> (buildQuery queryBuilder) { limitClause = Just ("LIMIT " <> cs (show limit)) }
OffsetQueryBuilder offset queryBuilder -> (buildQuery queryBuilder) { offsetClause = Just ("OFFSET " <> cs (show offset)) }
UnionQueryBuilder firstQueryBuilder secondQueryBuilder ->
let
firstQuery = buildQuery firstQueryBuilder
Expand Down Expand Up @@ -175,15 +177,15 @@ instance (model ~ GetModelByTableName table, KnownSymbol table) => Fetchable (Qu
fetch !queryBuilder = do
let !(theQuery, theParameters) = toSQL' (buildQuery queryBuilder)
logQuery theQuery theParameters
trackTableRead (tableName @model)
trackTableRead (tableNameByteString @model)
sqlQuery (Query $ cs theQuery) theParameters

{-# INLINE fetchOneOrNothing #-}
fetchOneOrNothing :: (?modelContext :: ModelContext) => (PG.FromRow model, KnownSymbol (GetTableName model)) => QueryBuilder table -> IO (Maybe model)
fetchOneOrNothing !queryBuilder = do
let !(theQuery, theParameters) = toSQL' (buildQuery queryBuilder) { limitClause = Just "LIMIT 1"}
logQuery theQuery theParameters
trackTableRead (tableName @model)
trackTableRead (tableNameByteString @model)
results <- sqlQuery (Query $ cs theQuery) theParameters
pure $ listToMaybe results

Expand Down Expand Up @@ -213,7 +215,7 @@ fetchCount !queryBuilder = do
let !(theQuery', theParameters) = toSQL' (buildQuery queryBuilder)
let theQuery = "SELECT COUNT(*) FROM (" <> theQuery' <> ") AS _count_values"
logQuery theQuery theParameters
trackTableRead (symbolToText @table)
trackTableRead (symbolToByteString @table)
[PG.Only count] <- sqlQuery (Query $! cs theQuery) theParameters
pure count
{-# INLINE fetchCount #-}
Expand All @@ -233,7 +235,7 @@ fetchExists !queryBuilder = do
let !(theQuery', theParameters) = toSQL' (buildQuery queryBuilder)
let theQuery = "SELECT EXISTS (" <> theQuery' <> ") AS _exists_values"
logQuery theQuery theParameters
trackTableRead (symbolToText @table)
trackTableRead (symbolToByteString @table)
[PG.Only exists] <- sqlQuery (Query $! cs theQuery) theParameters
pure exists
{-# INLINE fetchExists #-}
Expand Down Expand Up @@ -262,12 +264,12 @@ genericfetchIdsOneOrNothing !ids = query @model |> filterWhereIn (#id, ids) |> f
genericFetchIdsOne :: forall model value table. (KnownSymbol table, PG.FromRow model, ?modelContext :: ModelContext, ToField value, EqOrIsOperator value, HasField "id" model value, model ~ GetModelByTableName table, GetTableName model ~ table) => [value] -> IO model
genericFetchIdsOne !ids = query @model |> filterWhereIn (#id, ids) |> fetchOne

toSQL :: forall table. (KnownSymbol table) => QueryBuilder table -> (Text, [Action])
toSQL :: forall table. (KnownSymbol table) => QueryBuilder table -> (ByteString, [Action])
toSQL queryBuilder = toSQL' (buildQuery queryBuilder)
{-# INLINE toSQL #-}

toSQL' sqlQuery@SQLQuery { selectFrom, orderByClause, limitClause, offsetClause } =
(theQuery, theParams)
(DeepSeq.force theQuery, theParams)
where
!theQuery =
"SELECT " <> selectors <> " FROM "
Expand All @@ -277,9 +279,9 @@ toSQL' sqlQuery@SQLQuery { selectFrom, orderByClause, limitClause, offsetClause
<> limitClause'
<> offsetClause'

selectors :: Text
selectors :: ByteString
selectors = selectFrom <> ".*"
fromClause :: Text
fromClause :: ByteString
fromClause = selectFrom
!theParams =
case whereCondition sqlQuery of
Expand All @@ -293,13 +295,13 @@ toSQL' sqlQuery@SQLQuery { selectFrom, orderByClause, limitClause, offsetClause
orderByClause' =
case orderByClause of
[] -> mempty
xs -> " ORDER BY " <> intercalate "," ((map (\(column,direction) -> column <> (if direction == Desc then " DESC" else mempty)) xs))
xs -> " ORDER BY " <> ByteString.intercalate "," ((map (\(column,direction) -> column <> (if direction == Desc then " DESC" else mempty)) xs))
limitClause' = fromMaybe "" limitClause
offsetClause' = fromMaybe "" offsetClause
{-# INLINE toSQL' #-}

{-# INLINE compileConditionQuery #-}
compileConditionQuery :: Condition -> Text
compileConditionQuery :: Condition -> ByteString
compileConditionQuery (VarCondition var _) = var
compileConditionQuery (OrCondition a b) = "(" <> compileConditionQuery a <> ") OR (" <> compileConditionQuery b <> ")"
compileConditionQuery (AndCondition a b) = "(" <> compileConditionQuery a <> ") AND (" <> compileConditionQuery b <> ")"
Expand Down

0 comments on commit a94d0a0

Please sign in to comment.