From e8819b6793e7296e0d8abfc5b4d9b44482d7ebda Mon Sep 17 00:00:00 2001 From: Curtis Chin Jen Sem Date: Wed, 17 Dec 2025 11:41:44 +0100 Subject: [PATCH] Add `FromRow`/`ToRow` instances for `Solo` --- CHANGELOG.md | 3 +++ src/Database/PostgreSQL/PQTypes/Format.hs | 7 +++++++ src/Database/PostgreSQL/PQTypes/FromRow.hs | 7 +++++++ src/Database/PostgreSQL/PQTypes/ToRow.hs | 7 +++++++ test/Main.hs | 5 +++++ 5 files changed, 29 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e2a7bae..e77c594 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# hpqtypes-1.xx.x.x (xxxx-xx-xx) +* Add `FromRow`/`ToRow` instances for `Solo` + # hpqtypes-1.14.0.0 (2025-12-10) * Make `begin`, `commit` and `rollback` do nothing instead of throwing an error if the on demand connection acquisition mode is active. diff --git a/src/Database/PostgreSQL/PQTypes/Format.hs b/src/Database/PostgreSQL/PQTypes/Format.hs index 9a05ceb..0fbebca 100644 --- a/src/Database/PostgreSQL/PQTypes/Format.hs +++ b/src/Database/PostgreSQL/PQTypes/Format.hs @@ -16,6 +16,7 @@ import Data.Proxy import Data.Text qualified as T import Data.Text.Lazy qualified as TL import Data.Time +import Data.Tuple import Data.UUID.Types import Data.Word @@ -157,6 +158,12 @@ instance pqFormat = pqFormat @t pqVariables = 1 +instance + ( PQFormat t + ) => PQFormat (Solo t) where + pqFormat = pqFormat @t + pqVariables = 1 + instance ( PQFormat t1, PQFormat t2 ) => PQFormat (t1, t2) where diff --git a/src/Database/PostgreSQL/PQTypes/FromRow.hs b/src/Database/PostgreSQL/PQTypes/FromRow.hs index fb63ffc..98f6878 100644 --- a/src/Database/PostgreSQL/PQTypes/FromRow.hs +++ b/src/Database/PostgreSQL/PQTypes/FromRow.hs @@ -6,6 +6,7 @@ module Database.PostgreSQL.PQTypes.FromRow import Control.Exception qualified as E import Data.ByteString.Unsafe qualified as BS import Data.Functor.Identity +import Data.Tuple import Foreign.C import Foreign.Marshal.Alloc import Foreign.Ptr @@ -84,6 +85,12 @@ instance FromSQL t => FromRow (Identity t) where t <- peek p1 >>= convert res i b pure (Identity t) +instance FromSQL t => FromRow (Solo t) where + fromRow res err b i = withFormat $ \fmt -> alloca $ \p1 -> do + verify err =<< c_PQgetf1 res err i fmt b p1 + t <- peek p1 >>= convert res i b + pure (MkSolo t) + instance ( FromSQL t1, FromSQL t2 ) => FromRow (t1, t2) where diff --git a/src/Database/PostgreSQL/PQTypes/ToRow.hs b/src/Database/PostgreSQL/PQTypes/ToRow.hs index f5af3bb..ea1e258 100644 --- a/src/Database/PostgreSQL/PQTypes/ToRow.hs +++ b/src/Database/PostgreSQL/PQTypes/ToRow.hs @@ -5,6 +5,7 @@ module Database.PostgreSQL.PQTypes.ToRow import Data.ByteString.Unsafe qualified as BS import Data.Functor.Identity +import Data.Tuple import Foreign.C import Foreign.Marshal.Alloc import Foreign.Ptr @@ -61,6 +62,12 @@ instance ToSQL t => ToRow (Identity t) where where Identity t = row +instance ToSQL t => ToRow (Solo t) where + toRow row pa param err = withFormat row $ \fmt -> toSQL t pa $ \base -> + verify err =<< c_PQputf1 param err fmt base + where + t = getSolo row + instance ( ToSQL t1, ToSQL t2 ) => ToRow (t1, t2) where diff --git a/test/Main.hs b/test/Main.hs index 6437192..7c2a5f8 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -16,6 +16,7 @@ import Data.Int import Data.Maybe import Data.Text qualified as T import Data.Time +import Data.Tuple import Data.Typeable import Data.UUID.Types qualified as U import Data.Word @@ -112,6 +113,9 @@ instance Arbitrary Interval where instance (Arbitrary a1, Arbitrary a2) => Arbitrary (a1 :*: a2) where arbitrary = (:*:) <$> arbitrary <*> arbitrary +instance Arbitrary a => Arbitrary (Solo a) where + arbitrary = MkSolo <$> arbitrary + instance Arbitrary a => Arbitrary (Composite a) where arbitrary = Composite <$> arbitrary @@ -688,6 +692,7 @@ tests td = , rowTest td (u :: Identity T.Text :*: (Double, Int16)) , rowTest td (u :: (T.Text, Double) :*: Identity Int16) , rowTest td (u :: (Int16, T.Text, Int64, Double) :*: Identity Bool :*: (String0, AsciiChar)) + , rowTest td (u :: Solo Int16) , rowTest td (u :: (Int16, Int32)) , rowTest td (u :: (Int16, Int32, Int64)) , rowTest td (u :: (Int16, Int32, Int64, Float))