44 "context"
55 "errors"
66 "fmt"
7+ "net/url"
8+ "strconv"
9+ "strings"
710 "sync"
811 "time"
912
@@ -58,16 +61,19 @@ const (
5861
5962// New creates a new postgresql database driver.
6063func New (ctx context.Context , opts Options ) (* Database , error ) {
61- var sslMode string
62-
63- // Create database object
64- db := Database {}
65- db .err .mutex = sync.Mutex {}
66-
67- // Setup basic configuration options
64+ // Validate options
65+ if len (opts .Host ) == 0 {
66+ return nil , errors .New ("invalid host" )
67+ }
68+ if len (opts .User ) == 0 {
69+ return nil , errors .New ("invalid user name" )
70+ }
71+ if len (opts .Name ) == 0 {
72+ return nil , errors .New ("invalid database name" )
73+ }
74+ sslMode := "disable"
6875 switch opts .SSLMode {
6976 case SSLModeDisable :
70- sslMode = "disable"
7177 case SSLModeAllow :
7278 sslMode = "prefer"
7379 case SSLModeRequired :
@@ -76,6 +82,10 @@ func New(ctx context.Context, opts Options) (*Database, error) {
7682 return nil , errors .New ("invalid SSL mode" )
7783 }
7884
85+ // Create database object
86+ db := Database {}
87+ db .err .mutex = sync.Mutex {}
88+
7989 connString := fmt .Sprintf (
8090 "host='%s' port=%d user='%s' password='%s' dbname='%s' sslmode=%s" ,
8191 encodeDSN (opts .Host ), opts .Port , encodeDSN (opts .User ), encodeDSN (opts .Password ), encodeDSN (opts .Name ),
@@ -110,6 +120,82 @@ func New(ctx context.Context, opts Options) (*Database, error) {
110120 return & db , nil
111121}
112122
123+ // NewFromURL creates a new postgresql database driver from an URL
124+ func NewFromURL (ctx context.Context , rawUrl string ) (* Database , error ) {
125+ opts := Options {}
126+
127+ u , err := url .ParseRequestURI (rawUrl )
128+ if err != nil {
129+ return nil , errors .New ("invalid url provided" )
130+ }
131+
132+ // Check schema
133+ if u .Scheme != "pg" && u .Scheme != "postgres" && u .Scheme != "postgresql" {
134+ return nil , errors .New ("invalid url schema" )
135+ }
136+
137+ // Check host name and port
138+ opts .Host = u .Hostname ()
139+ if len (opts .Host ) == 0 {
140+ return nil , errors .New ("invalid host" )
141+ }
142+ s := u .Port ()
143+ if len (s ) == 0 {
144+ opts .Port = 5432
145+ } else {
146+ val , err2 := strconv .Atoi (s )
147+ if err2 != nil || val < 1 || val > 65535 {
148+ return nil , errors .New ("invalid port" )
149+ }
150+ opts .Port = uint16 (val )
151+ }
152+
153+ // Check user and password
154+ if u .User == nil {
155+ return nil , errors .New ("invalid user name" )
156+ }
157+ opts .User = u .User .Username ()
158+ if len (opts .User ) == 0 {
159+ return nil , errors .New ("invalid user name" )
160+ }
161+
162+ // Check database name
163+ if len (u .Path ) < 1 || (! strings .HasPrefix (u .Path , "/" )) || strings .Index (u .Path [1 :], "/" ) >= 0 {
164+ return nil , errors .New ("invalid database name" )
165+ }
166+ opts .Name = u .Path [1 :]
167+
168+ // Check ssl mode
169+ opts .SSLMode = SSLModeDisable
170+ switch u .Query ().Get ("sslmode" ) {
171+ case "allow" :
172+ opts .SSLMode = SSLModeAllow
173+
174+ case "required" :
175+ opts .SSLMode = SSLModeRequired
176+
177+ case "disabled" :
178+ fallthrough
179+ case "" :
180+
181+ default :
182+ return nil , errors .New ("invalid SSL mode" )
183+ }
184+
185+ // Check max connections count
186+ s = u .Query ().Get ("maxconn" )
187+ if len (s ) > 0 {
188+ val , err2 := strconv .Atoi (s )
189+ if err2 != nil || val < 0 {
190+ return nil , errors .New ("invalid max connections count" )
191+ }
192+ opts .MaxConns = int32 (val )
193+ }
194+
195+ // Create
196+ return New (ctx , opts )
197+ }
198+
113199// Close shutdown the connection pool
114200func (db * Database ) Close () {
115201 if db .pool != nil {
0 commit comments