Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""LICENSE 

2Copyright 2020 Hermann Krumrey <hermann@krumreyh.com> 

3 

4This file is part of jerrycan. 

5 

6jerrycan is free software: you can redistribute it and/or modify 

7it under the terms of the GNU General Public License as published by 

8the Free Software Foundation, either version 3 of the License, or 

9(at your option) any later version. 

10 

11jerrycan is distributed in the hope that it will be useful, 

12but WITHOUT ANY WARRANTY; without even the implied warranty of 

13MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

14GNU General Public License for more details. 

15 

16You should have received a copy of the GNU General Public License 

17along with jerrycan. If not, see <http://www.gnu.org/licenses/>. 

18LICENSE""" 

19 

20import sys 

21import base64 

22import binascii 

23import logging 

24import sentry_sdk 

25import traceback 

26from logging.handlers import TimedRotatingFileHandler 

27from sqlalchemy.exc import OperationalError 

28from sentry_sdk.integrations.logging import LoggingIntegration 

29from sentry_sdk.integrations.flask import FlaskIntegration 

30from typing import List, Optional, Type, Callable, Tuple, Dict, Any 

31from flask import redirect, url_for, flash, render_template 

32from flask.logging import default_handler 

33from flask.blueprints import Blueprint 

34from werkzeug.exceptions import HTTPException 

35from jerrycan.Config import Config 

36from jerrycan.base import app, login_manager, db 

37from jerrycan.enums import AlertSeverity 

38from jerrycan.db.User import User 

39from jerrycan.db.ApiKey import ApiKey 

40from jerrycan.db.TelegramChatId import TelegramChatId 

41from jerrycan.routes import blueprint_generators \ 

42 as default_blueprint_generators 

43 

44 

# Names of every blueprint that has already been created and registered.
# Needed when running the unit tests with nose: the app may be initialized
# several times in one process, and registering a duplicate blueprint name
# causes errors.
CREATED_BLUEPRINTS = []  # type: List[str]

51 

52 

def init_flask(
        module_name: str,
        sentry_dsn: str,
        root_path: str,
        config: Type[Config],
        models: List[Type[db.Model]],
        blueprint_generators: List[Tuple[Callable[[str], Blueprint], str]],
        extra_jinja_vars: Optional[Dict[str, Any]] = None
):
    """
    Entry point that sets up the whole flask application.
    Steps, in order: root path and configuration, logging, the app itself
    (blueprints, jinja variables, error handlers), the database and finally
    the login manager.
    :param module_name: The name of the module
    :param sentry_dsn: The sentry DSN used for error logging
    :param root_path: The root path of the flask application
    :param config: The Config class to use for configuration
    :param models: The database models to create, in addition to the
                   built-in User, ApiKey and TelegramChatId models
    :param blueprint_generators: Tuples that contain a function that generates
                                 a blueprint and the name of the blueprint;
                                 appended after the default blueprints
    :param extra_jinja_vars: Any extra variables to pass to jinja
                             (treated as an empty dict when None)
    :return: None
    """
    app.root_path = root_path
    config.load_config(root_path, module_name, sentry_dsn)
    __init_logging(config)

    builtin_models = [User, ApiKey, TelegramChatId]
    jinja_vars = {} if extra_jinja_vars is None else extra_jinja_vars

    all_blueprint_generators = \
        default_blueprint_generators + blueprint_generators
    __init_app(config, all_blueprint_generators, jinja_vars)

    __init_db(config, builtin_models + models)
    __init_login_manager(config)

92 __init_login_manager(config) 

93 

94 

def __init_logging(config: Type[Config]):
    """
    Sets up sentry error reporting and logging for the flask app logger.
    Handlers installed on app.logger:
      - a file handler on config.LOGGING_PATH at INFO level
      - a file handler on config.DEBUG_LOGGING_PATH at DEBUG level
      - a stdout handler at config.VERBOSITY level
    Both file handlers rotate at midnight and keep 7 backups.
    Calling this function again replaces (and closes) the handlers that a
    previous call installed instead of stacking new ones on top.
    :param config: The configuration to use
    :return: None
    """
    sentry_logging = LoggingIntegration(
        level=logging.INFO,
        event_level=None
    )
    sentry_sdk.init(
        dsn=config.SENTRY_DSN,
        integrations=[FlaskIntegration(), sentry_logging]
    )

    app.logger.removeHandler(default_handler)

    # The app may be initialized several times in one process (e.g. during
    # unit tests). Drop the handlers an earlier call installed so log lines
    # are not duplicated and the old log files are released.
    for old_handler in list(app.logger.handlers):
        if getattr(old_handler, "_jerrycan_handler", False):
            app.logger.removeHandler(old_handler)
            old_handler.close()

    log_format = \
        "[%(asctime)s, %(levelname)s] %(module)s[%(lineno)d]: %(message)s"
    formatter = logging.Formatter(log_format)

    def rotating_handler(path, level):
        """
        Creates a file handler that rotates at midnight, keeping 7 backups
        :param path: The path of the log file
        :param level: The minimum level the handler emits
        :return: The handler
        """
        file_handler = TimedRotatingFileHandler(
            path,
            when="midnight",
            interval=1,
            backupCount=7
        )
        file_handler.setLevel(level)
        return file_handler

    info_handler = rotating_handler(config.LOGGING_PATH, logging.INFO)
    debug_handler = rotating_handler(config.DEBUG_LOGGING_PATH, logging.DEBUG)

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setLevel(config.VERBOSITY)

    for handler in (info_handler, debug_handler, stream_handler):
        handler.setFormatter(formatter)
        # Marker used above to recognize handlers installed by this function
        setattr(handler, "_jerrycan_handler", True)
        app.logger.addHandler(handler)

    # The logger itself passes everything; each handler filters by its level
    app.logger.setLevel(logging.DEBUG)

143 

144 

def __init_app(
        config: Type[Config],
        blueprint_generators: List[Tuple[Callable[[str], Blueprint], str]],
        extra_jinja_vars: Dict[str, Any]
):
    """
    Initializes the flask app: core settings, blueprints, the jinja context
    processor and the error handlers
    :param config: The configuration to use
    :param blueprint_generators: Tuples that contain a function that generates
                                 a blueprint and the name of the blueprint
    :param extra_jinja_vars: Any extra variables to pass to jinja
    :return: None
    """
    app.testing = config.TESTING
    # HTTP exceptions are routed through the error handlers registered below
    app.config["TRAP_HTTP_EXCEPTIONS"] = True
    # Read from the passed config class (like every other setting here) so
    # subclass overrides are respected
    app.config["SERVER_NAME"] = config.base_url().split("://", 1)[1]
    if config.BEHIND_PROXY:
        app.config["PREFERRED_URL_SCHEME"] = "https"
    app.secret_key = config.FLASK_SECRET

    for blueprint_generator, blueprint_name in blueprint_generators:
        # Registering the same blueprint name twice causes errors
        if blueprint_name in CREATED_BLUEPRINTS:
            app.logger.debug(f"Blueprint {blueprint_name} already created")
            continue
        app.logger.info(f"Creating blueprint {blueprint_name}")
        CREATED_BLUEPRINTS.append(blueprint_name)
        blueprint = blueprint_generator(blueprint_name)
        app.register_blueprint(blueprint)

    @app.context_processor
    def inject_template_variables():
        """
        Injects the project's version string, the environment, the config
        class and any extra jinja variables so that they will be available
        in templates. Extra variables override the defaults on name clashes.
        :return: The dictionary to inject
        """
        defaults = {
            "version": config.VERSION,
            "env": app.env,
            "config": config
        }
        defaults.update(extra_jinja_vars)
        return defaults

    @app.errorhandler(Exception)
    def exception_handling(e: Exception):
        """
        Handles any uncaught exceptions and shows an applicable error page.
        401 errors redirect to the login page; other HTTP errors render the
        error page with their own status code; any other exception is
        logged, reported to sentry and rendered as a 500 error page.
        :param e: The caught exception
        :return: The response to the exception
        """
        if isinstance(e, HTTPException):
            # With TRAP_HTTP_EXCEPTIONS, non-error HTTP exceptions such as
            # routing redirects (e.g. a missing trailing slash) also end up
            # here; return them unchanged so they still act as redirects
            if e.code is None or e.code < 400:
                return e
            if e.code == 401:
                flash(
                    config.STRINGS["401_message"],
                    AlertSeverity.DANGER.value
                )
                return redirect(url_for("user_management.login"))
            error = e
            app.logger.warning("Caught HTTP exception: {}".format(e))
        else:
            error = HTTPException(config.STRINGS["500_message"])
            error.code = 500
            # Build the trace from the exception itself instead of relying
            # on sys.exc_info() still being populated
            trace = "".join(
                traceback.format_exception(type(e), e, e.__traceback__)
            )
            app.logger.error("Caught exception: {}\n{}".format(e, trace))
            sentry_sdk.capture_exception(e)
        page = render_template(
            config.REQUIRED_TEMPLATES["error_page"],
            error=error
        )
        # Send the real status code instead of the implicit 200 OK
        return page, error.code

    @app.errorhandler(HTTPException)  # type: ignore
    def unauthorized_handling(e: HTTPException):
        """
        Forwards all HTTP exceptions (not only 401s) to the error handler
        :param e: The HTTPException
        :return: The response to the exception
        """
        return exception_handling(e)

224 

225 

def __init_db(config: Type[Config], models: List[Type[db.Model]]):
    """
    Initializes the database connection and creates the tables.
    Exits the process with status 1 if the database cannot be reached.
    :param config: The configuration to use
    :param models: The model classes to create in the database. The list is
                   only logged here; the tables themselves are created by
                   db.create_all()
    :return: None
    """
    app.config["SQLALCHEMY_DATABASE_URI"] = config.DB_URI
    app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False

    # Makes sure that we don't get errors because
    # of an idle database connection
    app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"pool_pre_ping": True}

    db.init_app(app)

    for model in models:
        app.logger.debug(f"Loading model {model.__name__}")

    with app.app_context():
        try:
            db.create_all()
        except OperationalError as e:
            # Record the cause before exiting so connection problems
            # (wrong URI, credentials, host down) can be diagnosed
            app.logger.error(f"Failed to connect to the database: {e}")
            print("Failed to connect to the database")
            sys.exit(1)

251 

252 

def __init_login_manager(config: Type[Config]):
    """
    Initializes the login manager: session protection, plus the loaders that
    resolve the current user from the session or from an API key
    :param config: The configuration to use
    :return: None
    """
    login_manager.session_protection = config.SESSION_PROTECTION

    # Set up login manager
    @login_manager.user_loader
    def load_user(user_id: str) -> Optional[User]:
        """
        Loads a user from an ID
        :param user_id: The ID
        :return: The User, or None if the ID is malformed or unknown
        """
        # Flask-Login expects None (not an exception) for invalid IDs
        try:
            numeric_id = int(user_id)
        except (TypeError, ValueError):
            return None
        return User.query.get(numeric_id)

    @login_manager.request_loader
    def load_user_from_request(request) -> Optional[User]:
        """
        Loads a user based on a provided API key.
        The Authorization header is expected to hold base64-encoded text
        (optionally prefixed with "Basic ") whose part before the first ":"
        is the ID of the API key.
        :param request: The request containing the API key in the headers
        :return: The user or None if no valid API key was provided
        """
        if "Authorization" not in request.headers:
            return None

        api_key = request.headers["Authorization"]
        # Only strip the scheme marker when it actually is the prefix
        prefix = "Basic "
        if api_key.startswith(prefix):
            api_key = api_key[len(prefix):]

        try:
            api_key = base64.b64decode(
                api_key.encode("utf-8")
            ).decode("utf-8")
        except (TypeError, binascii.Error, UnicodeError):
            # Malformed base64 or a payload that is not valid UTF-8 must
            # be treated as "no valid key", not crash the request
            return None

        db_api_key = ApiKey.query.get(api_key.split(":", 1)[0])

        # Check for validity of API key
        if db_api_key is None or not db_api_key.verify_key(api_key):
            return None

        if db_api_key.has_expired():
            # Expired keys are removed from the database on first use
            db.session.delete(db_api_key)
            db.session.commit()
            return None

        return User.query.get(db_api_key.user_id)