diff --git a/session/common.go b/session/common.go index f06ae54..25d05f2 100644 --- a/session/common.go +++ b/session/common.go @@ -32,8 +32,8 @@ func (auth *Authorization) ToStrings() []string { } } -func (auth *Authorization) Invalidated() bool { - return len(auth.invalidated) > 0 +func (auth *Authorization) Valid() bool { + return len(auth.invalidated) == 0 && !auth.Account.IsZero() } func MakeAuthrizationFromStringMap(src map[string]string) Authorization { @@ -55,7 +55,7 @@ type Provider interface { } type Consumer interface { - Query(string) (Authorization, error) + Query(string) Authorization Touch(string) (Authorization, error) IsRevoked(primitive.ObjectID) bool Revoke(string) diff --git a/session/impl_mongo.go b/session/impl_mongo.go index 1b69e7c..b6011b3 100644 --- a/session/impl_mongo.go +++ b/session/impl_mongo.go @@ -263,25 +263,25 @@ func (c *consumer_mongo) query_internal(sk storagekey) (*sessionMongo, bool, err return nil, false, nil } -func (c *consumer_mongo) Query(pk string) (Authorization, error) { +func (c *consumer_mongo) Query(pk string) Authorization { c.lock.Lock() defer c.lock.Unlock() sk := publickey_to_storagekey(publickey(pk)) si, _, err := c.query_internal(sk) if err != nil { - return Authorization{}, err + return Authorization{} } if si == nil { - return Authorization{}, nil + return Authorization{} } if time.Now().After(si.Ts.Time().Add(c.ttl)) { - return Authorization{}, nil + return Authorization{} } - return *si.Auth, nil + return *si.Auth } func (c *consumer_mongo) Touch(pk string) (Authorization, error) { diff --git a/session/impl_redis.go b/session/impl_redis.go index edb1e64..fe5e0db 100644 --- a/session/impl_redis.go +++ b/session/impl_redis.go @@ -2,6 +2,7 @@ package session import ( "context" + "errors" "fmt" "time" @@ -243,46 +244,49 @@ func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, error) { expireAt: time.Now().Add(ttl), } - if auth.Invalidated() { - c.stages[0].deleted[sk] = si - } else { + if auth.Valid() { c.add_internal(sk, si) + } else { + c.stages[0].deleted[sk] = si } return si, nil } -func (c *consumer_redis) Query(pk string) (Authorization, error) { +var errRevoked = errors.New("session revoked") +var errExpired = errors.New("session expired") + +func (c *consumer_redis) Query(pk string) Authorization { c.lock.Lock() defer c.lock.Unlock() sk := publickey_to_storagekey(publickey(pk)) if _, deleted := c.stages[0].deleted[sk]; deleted { - return Authorization{}, nil + return Authorization{} } if _, deleted := c.stages[1].deleted[sk]; deleted { - return Authorization{}, nil + return Authorization{} } si, err := c.query_internal(sk) if err != nil { logger.Println("session consumer query :", pk, err) - return Authorization{}, err + return Authorization{} } if si == nil { logger.Println("session consumer query(si nil) :", pk, nil) - return Authorization{}, nil + return Authorization{} } if time.Now().After(si.expireAt) { logger.Println("session consumer query(expired):", pk, nil) - return Authorization{}, nil + return Authorization{} } - return *si.Authorization, nil + return *si.Authorization } func (c *consumer_redis) Touch(pk string) (Authorization, error) { diff --git a/session/session_test.go b/session/session_test.go index c8cf7ec..1f1562b 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -60,11 +60,11 @@ func TestExpTable(t *testing.T) { go func() { for { - q1, err := cs.Query(sk1) - logger.Println("query :", q1, err) + q1 := cs.Query(sk1) + logger.Println("query :", q1) - q2, err := cs.Query(sk2) - logger.Println("query :", q2, err) + q2 := cs.Query(sk2) + logger.Println("query :", q2) time.Sleep(time.Second) } }() @@ -87,7 +87,7 @@ func TestExpTable(t *testing.T) { t.Error(err) } - q2, err := cs2.Query(sk2) - logger.Println("queryf :", q2, err) + q2 := cs2.Query(sk2) + logger.Println("queryf :", q2) time.Sleep(20 * time.Second) } diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index 263eaab..d3fa385 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -683,18 +683,13 @@ func (ws *WebsocketHandler) upgrade_nosession(w http.ResponseWriter, r *http.Req accid := primitive.ObjectID(*raw) sk := r.Header.Get("AS-X-SESSION") - authinfo, err := ws.sessionConsumer.Query(sk) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - if authinfo.Account != accid { + authinfo := ws.sessionConsumer.Query(sk) + if !authinfo.Valid() { w.WriteHeader(http.StatusUnauthorized) return } - if authinfo.Invalidated() { + if authinfo.Account != accid { w.WriteHeader(http.StatusUnauthorized) return } @@ -737,19 +732,8 @@ func (ws *WebsocketHandler) upgrade(w http.ResponseWriter, r *http.Request) { return } - authinfo, err := ws.sessionConsumer.Query(sk) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - logger.Error("authorize query failed :", err) - return - } - - if authinfo.Account.IsZero() { - w.WriteHeader(http.StatusUnauthorized) - return - } - - if authinfo.Invalidated() { + authinfo := ws.sessionConsumer.Query(sk) + if !authinfo.Valid() { w.WriteHeader(http.StatusUnauthorized) return } diff --git a/wshandler/wshandler_peer.go b/wshandler/wshandler_peer.go index 96e3ebc..f1af5a7 100644 --- a/wshandler/wshandler_peer.go +++ b/wshandler/wshandler_peer.go @@ -305,10 +305,12 @@ func (ws *websocketPeerHandler[T]) upgrade_noauth(w http.ResponseWriter, r *http sk := r.Header.Get("AS-X-SESSION") var accid primitive.ObjectID if len(sk) > 0 { - authinfo, err := ws.sessionConsumer.Query(sk) - if err == nil { - accid = authinfo.Account + authinfo := ws.sessionConsumer.Query(sk) + if !authinfo.Valid() { + w.WriteHeader(http.StatusUnauthorized) + return } + accid = authinfo.Account } if accid.IsZero() { @@ -363,14 +365,8 @@ func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Reques }() sk := r.Header.Get("AS-X-SESSION") - authinfo, err := ws.sessionConsumer.Query(sk) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - logger.Error("authorize query failed :", err) - return - } - - if authinfo.Account.IsZero() || authinfo.Invalidated() { + authinfo := ws.sessionConsumer.Query(sk) + if !authinfo.Valid() { w.WriteHeader(http.StatusUnauthorized) return }