session consumer query함수 리턴 값의 애매함을 제거

This commit is contained in:
2025-08-15 23:55:50 +09:00
parent c0ab2afcf4
commit c24d387761
6 changed files with 40 additions and 56 deletions

View File

@ -32,8 +32,8 @@ func (auth *Authorization) ToStrings() []string {
}
}
func (auth *Authorization) Invalidated() bool {
return len(auth.invalidated) > 0
func (auth *Authorization) Valid() bool {
return len(auth.invalidated) == 0 && !auth.Account.IsZero()
}
func MakeAuthrizationFromStringMap(src map[string]string) Authorization {
@ -55,7 +55,7 @@ type Provider interface {
}
type Consumer interface {
Query(string) (Authorization, error)
Query(string) Authorization
Touch(string) (Authorization, error)
IsRevoked(primitive.ObjectID) bool
Revoke(string)

View File

@ -263,25 +263,25 @@ func (c *consumer_mongo) query_internal(sk storagekey) (*sessionMongo, bool, err
return nil, false, nil
}
func (c *consumer_mongo) Query(pk string) (Authorization, error) {
func (c *consumer_mongo) Query(pk string) Authorization {
c.lock.Lock()
defer c.lock.Unlock()
sk := publickey_to_storagekey(publickey(pk))
si, _, err := c.query_internal(sk)
if err != nil {
return Authorization{}, err
return Authorization{}
}
if si == nil {
return Authorization{}, nil
return Authorization{}
}
if time.Now().After(si.Ts.Time().Add(c.ttl)) {
return Authorization{}, nil
return Authorization{}
}
return *si.Auth, nil
return *si.Auth
}
func (c *consumer_mongo) Touch(pk string) (Authorization, error) {

View File

@ -2,6 +2,7 @@ package session
import (
"context"
"errors"
"fmt"
"time"
@ -243,46 +244,49 @@ func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, error) {
expireAt: time.Now().Add(ttl),
}
if auth.Invalidated() {
c.stages[0].deleted[sk] = si
} else {
if auth.Valid() {
c.add_internal(sk, si)
} else {
c.stages[0].deleted[sk] = si
}
return si, nil
}
func (c *consumer_redis) Query(pk string) (Authorization, error) {
var errRevoked = errors.New("session revoked")
var errExpired = errors.New("session expired")
func (c *consumer_redis) Query(pk string) Authorization {
c.lock.Lock()
defer c.lock.Unlock()
sk := publickey_to_storagekey(publickey(pk))
if _, deleted := c.stages[0].deleted[sk]; deleted {
return Authorization{}, nil
return Authorization{}
}
if _, deleted := c.stages[1].deleted[sk]; deleted {
return Authorization{}, nil
return Authorization{}
}
si, err := c.query_internal(sk)
if err != nil {
logger.Println("session consumer query :", pk, err)
return Authorization{}, err
return Authorization{}
}
if si == nil {
logger.Println("session consumer query(si nil) :", pk, nil)
return Authorization{}, nil
return Authorization{}
}
if time.Now().After(si.expireAt) {
logger.Println("session consumer query(expired):", pk, nil)
return Authorization{}, nil
return Authorization{}
}
return *si.Authorization, nil
return *si.Authorization
}
func (c *consumer_redis) Touch(pk string) (Authorization, error) {

View File

@ -60,11 +60,11 @@ func TestExpTable(t *testing.T) {
go func() {
for {
q1, err := cs.Query(sk1)
logger.Println("query :", q1, err)
q1 := cs.Query(sk1)
logger.Println("query :", q1)
q2, err := cs.Query(sk2)
logger.Println("query :", q2, err)
q2 := cs.Query(sk2)
logger.Println("query :", q2)
time.Sleep(time.Second)
}
}()
@ -87,7 +87,7 @@ func TestExpTable(t *testing.T) {
t.Error(err)
}
q2, err := cs2.Query(sk2)
logger.Println("queryf :", q2, err)
q2 := cs2.Query(sk2)
logger.Println("queryf :", q2)
time.Sleep(20 * time.Second)
}

View File

@ -683,18 +683,13 @@ func (ws *WebsocketHandler) upgrade_nosession(w http.ResponseWriter, r *http.Req
accid := primitive.ObjectID(*raw)
sk := r.Header.Get("AS-X-SESSION")
authinfo, err := ws.sessionConsumer.Query(sk)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
if authinfo.Account != accid {
authinfo := ws.sessionConsumer.Query(sk)
if !authinfo.Valid() {
w.WriteHeader(http.StatusUnauthorized)
return
}
if authinfo.Invalidated() {
if authinfo.Account != accid {
w.WriteHeader(http.StatusUnauthorized)
return
}
@ -737,19 +732,8 @@ func (ws *WebsocketHandler) upgrade(w http.ResponseWriter, r *http.Request) {
return
}
authinfo, err := ws.sessionConsumer.Query(sk)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
logger.Error("authorize query failed :", err)
return
}
if authinfo.Account.IsZero() {
w.WriteHeader(http.StatusUnauthorized)
return
}
if authinfo.Invalidated() {
authinfo := ws.sessionConsumer.Query(sk)
if !authinfo.Valid() {
w.WriteHeader(http.StatusUnauthorized)
return
}

View File

@ -305,10 +305,12 @@ func (ws *websocketPeerHandler[T]) upgrade_noauth(w http.ResponseWriter, r *http
sk := r.Header.Get("AS-X-SESSION")
var accid primitive.ObjectID
if len(sk) > 0 {
authinfo, err := ws.sessionConsumer.Query(sk)
if err == nil {
accid = authinfo.Account
authinfo := ws.sessionConsumer.Query(sk)
if !authinfo.Valid() {
w.WriteHeader(http.StatusUnauthorized)
return
}
accid = authinfo.Account
}
if accid.IsZero() {
@ -363,14 +365,8 @@ func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Reques
}()
sk := r.Header.Get("AS-X-SESSION")
authinfo, err := ws.sessionConsumer.Query(sk)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
logger.Error("authorize query failed :", err)
return
}
if authinfo.Account.IsZero() || authinfo.Invalidated() {
authinfo := ws.sessionConsumer.Query(sk)
if !authinfo.Valid() {
w.WriteHeader(http.StatusUnauthorized)
return
}