2 Commits 852f15678b ... d576c1c344

Author SHA1 Message Date
  Adam Pioterek d576c1c344 decrypt auth token (fully authenticate an account) 4 months ago
  Adam Pioterek 9f4b5e642e refactor authentication 4 months ago
2 changed files with 208 additions and 59 deletions
  1. 208 58
      auth.go
  2. 0 1
      go.mod

+ 208 - 58
auth.go

@@ -5,6 +5,7 @@ import (
 	"crypto/sha512"
 	"encoding/base64"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -150,13 +151,16 @@ func decodeBytes(response ...interface{}) (interface{}, error) {
 		Finish()
 }
 
+type passGetter func() (string, error)
+
 type Account struct {
-	username     string
-	password     string
-	twoFactor    string
-	authToken    string
-	refreshToken string
-	authInfo     AuthInfo
+	username       string
+	password       string
+	passGetterFunc passGetter
+	twoFactor      string
+	authToken      string
+	refreshToken   string
+	authInfo       AuthInfo
 }
 
 func NewAccount(username string) Account {
@@ -196,6 +200,21 @@ func (a *Account) SetCredentials(password, totp string) {
 	a.twoFactor = totp
 }
 
+func makeX(values ...interface{}) (interface{}, error) {
+	s := values[0].(*srp)
+	p := [4]interface{}{s.password, s.a.SaltBytes, s.a.ModulusBytes, nil}
+	x, err := gott.NewResult(gott.Tuple(p[:])).
+		Map(makeFullSalt).
+		Bind(bcrypt).
+		Map(replaceBcryptVersion).
+		Map(appendModulus).
+		Map(hash).
+		Map(convertToBig).
+		Finish()
+	s.x = x.(*big.Int)
+	return s, err
+}
+
 func makeFullSalt(values ...interface{}) interface{} {
 	salt := values[1].([]byte)
 	values[1] = append(salt, "proton"...)
@@ -244,19 +263,6 @@ func convertToBig(bytes ...interface{}) interface{} {
 	return x.SetBytes(a)
 }
 
-func makeX(password string, salt, modulus []byte) (*big.Int, error) {
-	p := [4]interface{}{password, salt, modulus, nil}
-	x, err := gott.NewResult(gott.Tuple(p[:])).
-		Map(makeFullSalt).
-		Bind(bcrypt).
-		Map(replaceBcryptVersion).
-		Map(appendModulus).
-		Map(hash).
-		Map(convertToBig).
-		Finish()
-	return x.(*big.Int), err
-}
-
 func convertToBytes(number interface{}) interface{} {
 	n := number.(*big.Int)
 	bytes := n.Bytes()
@@ -270,36 +276,54 @@ func convertToBytes(number interface{}) interface{} {
 	return bytes
 }
 
-func makeMultiplier(generator, modulus *big.Int) *big.Int {
-	modulusBytes := convertToBytes(modulus).([]byte)
-	gB := convertToBytes(generator).([]byte)
+func makeMultiplier(values ...interface{}) interface{} {
+	s := values[0].(*srp)
+	modulusBytes := convertToBytes(s.modulus).([]byte)
+	gB := convertToBytes(s.generator).([]byte)
 	h := hash(append(gB, modulusBytes...))
 	m := convertToBig(h).(*big.Int)
-	return m.Mod(m, modulus)
+	s.multiplier = m.Mod(m, s.modulus)
+	return s
 }
 
-func makeSecret(modulus *big.Int) (*big.Int, error) {
+func makeSecret(values ...interface{}) (interface{}, error) {
+	s := values[0].(*srp)
 	key := [128]byte{}
 	_, err := rand.Read(key[:])
-	return convertToBig(key[:]).(*big.Int), err
+	s.secret = convertToBig(key[:]).(*big.Int)
+	return s, err
+}
+
+func makeEphemeral(values ...interface{}) interface{} {
+	s := values[0].(*srp)
+	ephemeral := big.NewInt(0)
+	s.ephemeral = convertToBytes(ephemeral.Exp(s.generator, s.secret, s.modulus)).([]byte)
+	return s
 }
 
-func makeU(clientEphemeral, serverEphemeral []byte) *big.Int {
-	return convertToBig(hash(append(clientEphemeral, serverEphemeral...)).([]byte)).(*big.Int)
+func makeU(values ...interface{}) interface{} {
+	s := values[0].(*srp)
+	s.u = convertToBig(hash(append(s.ephemeral, s.a.ServerEphemeralBytes...)).([]byte)).(*big.Int)
+	return s
 }
 
-func makeShared(serverEphemeral []byte, multiplier, generator, x, modulus, clientSecret, u *big.Int) []byte {
-	base := convertToBig(serverEphemeral).(*big.Int)
-	base.Sub(base, multiplier.Mul(multiplier, generator.Exp(generator, x, modulus)))
-	exponent := clientSecret.Add(clientSecret, u.Mul(u, x))
-	return convertToBytes(base.Exp(base, exponent, modulus)).([]byte)
+func makeShared(values ...interface{}) interface{} {
+	s := values[0].(*srp)
+	base := convertToBig(s.a.ServerEphemeralBytes).(*big.Int)
+	base.Sub(base, s.multiplier.Mul(s.multiplier, s.generator.Exp(s.generator, s.x, s.modulus)))
+	exponent := s.secret.Add(s.secret, s.u.Mul(s.u, s.x))
+	s.shared = convertToBytes(base.Exp(base, exponent, s.modulus)).([]byte)
+	return s
 }
 
-func makeProofs(clientEphemeral, serverEphemeral, shared []byte) (string, string) {
+func makeProofs(values ...interface{}) interface{} {
+	s := values[0].(*srp)
 	encoder := base64.StdEncoding
-	clientProof := hash(append(append(clientEphemeral, serverEphemeral...), shared...)).([]byte)
-	serverProof := hash(append(append(clientEphemeral, clientProof...), shared...)).([]byte)
-	return encoder.EncodeToString(clientProof), encoder.EncodeToString(serverProof)
+	clientProof := hash(append(append(s.ephemeral, s.a.ServerEphemeralBytes...), s.shared...)).([]byte)
+	serverProof := hash(append(append(s.ephemeral, clientProof...), s.shared...)).([]byte)
+	s.clientProof = encoder.EncodeToString(clientProof)
+	s.serverProofExpected = encoder.EncodeToString(serverProof)
+	return s
 }
 
 func prepareAuthRequest(request ...interface{}) (interface{}, error) {
@@ -308,15 +332,23 @@ func prepareAuthRequest(request ...interface{}) (interface{}, error) {
 		strings.NewReader(string(request[0].([]byte))))
 }
 
-func requestAuth(username, clientEphemeral, clientProof, sessionKey, twoFactor string) (string, error) {
+func unmarshalAuthResponse(values ...interface{}) (interface{}, error) {
+	response := values[0].([]byte)
+	var res authResult
+	err := json.Unmarshal(response, &res)
+	return res, err
+}
+
+func requestAuth(values ...interface{}) (interface{}, error) {
+	s := values[0].(*srp)
 	m := map[string]interface{}{
-		"Username":        username,
+		"Username":        s.username,
 		"ClientID":        "Cobalt",
 		"ClientSecret":    "4957cc9a2e0a2a49d02475c9d013478d",
-		"ClientEphemeral": clientEphemeral,
-		"ClientProof":     clientProof,
-		"SRPSession":      sessionKey,
-		"TwoFactorCode":   twoFactor,
+		"ClientEphemeral": base64.StdEncoding.EncodeToString(s.ephemeral),
+		"ClientProof":     s.clientProof,
+		"SRPSession":      s.sessionKey,
+		"TwoFactorCode":   s.twoFactor,
 	}
 	response, err := gott.NewResult(m).
 		Bind(marshalRequest).
@@ -325,39 +357,157 @@ func requestAuth(username, clientEphemeral, clientProof, sessionKey, twoFactor s
 		Bind(doRequest).
 		Bind(checkResponse).
 		Bind(readResponse).
+		Bind(unmarshalAuthResponse).
 		Finish()
 
+	r := response.(authResult)
+	r.passGetterFunc = s.passGetterFunc
+	r.password = s.password
+
 	if err != nil {
 		return "", err
 	}
-	return string(response.([]byte)), err
+	t := [2]interface{}{
+		r,
+		s.serverProofExpected,
+	}
+	return gott.Tuple(t[:]), err
+}
+
+func getKeyPassword(values ...interface{}) (interface{}, error) {
+	r := values[0].(authResult)
+	if r.PasswordMode != 1 {
+		password, err := r.passGetterFunc()
+		if err != nil {
+			return values, err
+		} else {
+			r.password = password
+		}
+	}
+	values[0] = r
+	return gott.Tuple(values), nil
 }
 
-func (a *Account) Authenticate() error {
-	x, err := makeX(a.password, a.authInfo.SaltBytes, a.authInfo.ModulusBytes)
+func verifyServerProof(values ...interface{}) (interface{}, error) {
+	authInfo := values[0].(authResult)
+	serverProofExpected := values[1].(string)
+	if serverProofExpected != authInfo.ServerProof {
+		return authResult{}, errors.New("Invalid server credentials")
+	} else {
+		return authInfo, nil
+	}
+}
+
+func computePrivateKeyPassword(values ...interface{}) (interface{}, error) {
+	authInfo := values[0].(authResult)
+	salt, err := base64.StdEncoding.DecodeString(authInfo.KeySalt)
 	if err != nil {
-		return err
+		return authInfo, err
 	}
-	generator := big.NewInt(2)
-	modulus := convertToBig(a.authInfo.ModulusBytes).(*big.Int)
-	multiplier := makeMultiplier(generator, modulus)
-	secret, err := makeSecret(modulus)
+	p := [4]interface{}{authInfo.password, salt, nil, nil}
+	hash, err := bcrypt(gott.Tuple(p[:])...)
 	if err != nil {
-		return err
+		return authInfo, err
 	}
-	ephemeral := big.NewInt(0)
-	ephemeralBytes := convertToBytes(ephemeral.Exp(generator, secret, modulus)).([]byte)
-	u := makeU(ephemeralBytes, a.authInfo.ServerEphemeralBytes)
-	shared := makeShared(a.authInfo.ServerEphemeralBytes, multiplier, generator, x, modulus, secret, u)
-	clientProof, serverProofExpected := makeProofs(ephemeralBytes, a.authInfo.ServerEphemeralBytes, shared)
+	hash = replaceBcryptVersion(hash.(gott.Tuple)...).(gott.Tuple)[3].([]byte)[29:]
+	authInfo.password = string(hash.([]byte))
+	return authInfo, nil
+}
+
+func decryptToken(values ...interface{}) (interface{}, error) {
+	authInfo := values[0].(authResult)
+	token, err := helper.DecryptMessageArmored(authInfo.PrivateKey, authInfo.password, authInfo.AccessToken)
+	authInfo.accessTokenDecrypted = token
+	return authInfo, err
+}
+
+type srp struct {
+	password string
+	a        AuthInfo
+
+	x                   *big.Int
+	generator           *big.Int
+	modulus             *big.Int
+	multiplier          *big.Int
+	secret              *big.Int
+	ephemeral           []byte
+	u                   *big.Int
+	shared              []byte
+	clientProof         string
+	serverProofExpected string
+	username            string
+	sessionKey          string
+	twoFactor           string
+
+	passGetterFunc passGetter
+}
+
+type authResult struct {
+	Code                   int                    `json:"Code"`
+	AccessToken            string                 `json:"AccessToken"`
+	ExpiresIn              int                    `json:"ExpiresIn"`
+	TokenType              string                 `json:"TokenType"`
+	Scope                  string                 `json:"Scope"`
+	Uid                   string                 `json:"Uid"`
+	UID                   string                 `json:"UID"`
+	UserID                 string                 `json:"UserID"`
+	RefreshToken           string                 `json:"RefreshToken"`
+	EventID                string                 `json:"EventID"`
+	PasswordMode           int                    `json:"PasswordMode"`
+	TwoFactorOld           int                    `json:"TwoFactor"`
+	TwoFactorNew           map[string]interface{} `json:"2FA"`
+	PrivateKey             string                 `json:"PrivateKey"`
+	EncPrivateKeyEncrypted string                 `json:"EncPrivateKey"`
+	KeySalt                string                 `json:"KeySalt"`
+	ServerProof            string                 `json:"ServerProof"`
+	expiryDate             time.Time
+	accessTokenDecrypted   string
+	passGetterFunc         passGetter
+	password               string
+}
+
+func (a *Account) SetPassGetter(f passGetter) {
+	a.passGetterFunc = f
+}
+
+func (a *Account) Authenticate() error {
+	s := srp{
+		password:       a.password,
+		a:              a.authInfo,
+		generator:      big.NewInt(2),
+		modulus:        convertToBig(a.authInfo.ModulusBytes).(*big.Int),
+		username:       a.username,
+		sessionKey:     a.authInfo.SessionKey,
+		twoFactor:      a.twoFactor,
+		passGetterFunc: a.passGetterFunc,
+	}
+	response, err := gott.NewResult(&s).
+		Bind(makeX).
+		Map(makeMultiplier).
+		Bind(makeSecret).
+		Map(makeEphemeral).
+		Map(makeU).
+		Map(makeShared).
+		Map(makeProofs).
+		Bind(requestAuth).
+		Bind(getKeyPassword).
+		Bind(verifyServerProof).
+		Bind(computePrivateKeyPassword).
+		Bind(decryptToken).
+		Finish()
 
-	response, err := requestAuth(a.username, base64.StdEncoding.EncodeToString(ephemeralBytes), clientProof, a.authInfo.SessionKey, a.twoFactor)
+	a.authToken = response.(authResult).accessTokenDecrypted
+	a.refreshToken = response.(authResult).RefreshToken
 
 	if err != nil {
 		return err
 	}
 
-	fmt.Printf("response:\n'''\n%s\n''',\nExpected: %s\n", response, serverProofExpected)
+
+	a.authInfo = AuthInfo{}
+	a.passGetterFunc = nil
+	a.password = ""
+	a.twoFactor = ""
 
 	return err
 }

+ 0 - 1
go.mod

@@ -6,7 +6,6 @@ require (
 	github.com/ProtonMail/gopenpgp v1.0.0
 	github.com/cruxic/bcrypt v0.0.0-00010101000000-000000000000
 	github.com/stretchr/testify v1.3.0 // indirect
-	golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4 // indirect
 	notabug.org/apiote/gott v1.0.1
 )