store.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # GNU MediaGoblin -- federated, autonomous media hosting
  2. # Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  16. import base64
  17. import time
  18. import six
  19. from openid.association import Association as OIDAssociation
  20. from openid.store.interface import OpenIDStore
  21. from openid.store import nonce
  22. from mediagoblin.plugins.openid.models import Association, Nonce
  23. class SQLAlchemyOpenIDStore(OpenIDStore):
  24. def __init__(self):
  25. self.max_nonce_age = 6 * 60 * 60
  26. def storeAssociation(self, server_url, association):
  27. assoc = Association.query.filter_by(
  28. server_url=server_url, handle=association.handle
  29. ).first()
  30. if not assoc:
  31. assoc = Association()
  32. assoc.server_url = six.text_type(server_url)
  33. assoc.handle = association.handle
  34. # django uses base64 encoding, python-openid uses a blob field for
  35. # secret
  36. assoc.secret = six.text_type(base64.encodestring(association.secret))
  37. assoc.issued = association.issued
  38. assoc.lifetime = association.lifetime
  39. assoc.assoc_type = association.assoc_type
  40. assoc.save()
  41. def getAssociation(self, server_url, handle=None):
  42. assocs = []
  43. if handle is not None:
  44. assocs = Association.query.filter_by(
  45. server_url=server_url, handle=handle
  46. )
  47. else:
  48. assocs = Association.query.filter_by(
  49. server_url=server_url
  50. )
  51. if assocs.count() == 0:
  52. return None
  53. else:
  54. associations = []
  55. for assoc in assocs:
  56. association = OIDAssociation(
  57. assoc.handle, base64.decodestring(assoc.secret),
  58. assoc.issued, assoc.lifetime, assoc.assoc_type
  59. )
  60. if association.getExpiresIn() == 0:
  61. assoc.delete()
  62. else:
  63. associations.append((association.issued, association))
  64. if not associations:
  65. return None
  66. associations.sort()
  67. return associations[-1][1]
  68. def removeAssociation(self, server_url, handle):
  69. assocs = Association.query.filter_by(
  70. server_url=server_url, handle=handle
  71. ).first()
  72. assoc_exists = True if assocs else False
  73. for assoc in assocs:
  74. assoc.delete()
  75. return assoc_exists
  76. def useNonce(self, server_url, timestamp, salt):
  77. if abs(timestamp - time.time()) > nonce.SKEW:
  78. return False
  79. ononce = Nonce.query.filter_by(
  80. server_url=server_url,
  81. timestamp=timestamp,
  82. salt=salt
  83. ).first()
  84. if ononce:
  85. return False
  86. else:
  87. ononce = Nonce()
  88. ononce.server_url = server_url
  89. ononce.timestamp = timestamp
  90. ononce.salt = salt
  91. ononce.save()
  92. return True
  93. def cleanupNonces(self, _now=None):
  94. if _now is None:
  95. _now = int(time.time())
  96. expired = Nonce.query.filter(
  97. Nonce.timestamp < (_now - nonce.SKEW)
  98. )
  99. count = expired.count()
  100. for each in expired:
  101. each.delete()
  102. return count
  103. def cleanupAssociations(self):
  104. now = int(time.time())
  105. assoc = Association.query.all()
  106. count = 0
  107. for each in assoc:
  108. if (each.lifetime + each.issued) <= now:
  109. each.delete()
  110. count = count + 1
  111. return count