diff --git a/mongo.go b/mongo.go index 4d6b24c..db1ed6a 100644 --- a/mongo.go +++ b/mongo.go @@ -15,6 +15,7 @@ import ( "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" + "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" ) type MongoClient struct { @@ -60,8 +61,19 @@ func (ci *ConnectionInfo) SetDatabase(dbname string) *ConnectionInfo { return ci } -func NewMongoClient(ctx context.Context, url string, dbname string) (MongoClient, error) { - return newMongoClient(ctx, NewMongoConnectionInfo(url, dbname)) +var errNoDatabaseNameInMongoUri = errors.New("mongo uri has no database name") + +func NewMongoClient(ctx context.Context, url string) (MongoClient, error) { + connstr, err := connstring.ParseAndValidate(url) + if err != nil { + return MongoClient{}, err + } + + if len(connstr.Database) == 0 { + return MongoClient{}, errNoDatabaseNameInMongoUri + } + + return newMongoClient(ctx, NewMongoConnectionInfo(url, connstr.Database)) } func newMongoClient(ctx context.Context, ci *ConnectionInfo) (MongoClient, error) { diff --git a/session/impl_mongo.go b/session/impl_mongo.go index 9fbaed5..8ecc1d7 100644 --- a/session/impl_mongo.go +++ b/session/impl_mongo.go @@ -2,14 +2,12 @@ package session import ( "context" - "errors" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" - "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "repositories.action2quare.com/ayo/gocommon" "repositories.action2quare.com/ayo/gocommon/logger" ) @@ -25,18 +23,8 @@ type sessionMongo struct { Ts primitive.DateTime `bson:"_ts"` } -var errNoDatabaseNameInMongoUri = errors.New("mongo uri has no database name") - func newProviderWithMongo(ctx context.Context, mongoUrl string, ttl time.Duration) (Provider, error) { - connstr, err := connstring.ParseAndValidate(mongoUrl) - if err != nil { - return nil, err - } - if len(connstr.Database) == 0 { - return nil, errNoDatabaseNameInMongoUri - } - - mc, err := gocommon.NewMongoClient(ctx, mongoUrl, connstr.Database) + mc, err := gocommon.NewMongoClient(ctx, mongoUrl) if err != nil { return nil, err } @@ -105,16 +93,7 @@ type sessionPipelineDocument struct { } func newConsumerWithMongo(ctx context.Context, mongoUrl string, ttl time.Duration) (Consumer, error) { - connstr, err := connstring.ParseAndValidate(mongoUrl) - if err != nil { - return nil, err - } - - if len(connstr.Database) == 0 { - return nil, errNoDatabaseNameInMongoUri - } - - mc, err := gocommon.NewMongoClient(ctx, mongoUrl, connstr.Database) + mc, err := gocommon.NewMongoClient(ctx, mongoUrl) if err != nil { return nil, err }