Flask: Creating Custom Middleware

Middleware in Flask allows developers to intercept and process requests and responses globally, enabling functionality such as logging, authentication, or performance monitoring. Flask, a lightweight Python web framework, is a WSGI (Web Server Gateway Interface) application, so it can be extended by wrapping it in WSGI middleware. This tutorial explores creating custom middleware in Flask, covering setup, implementation, and best practices.


01. Why Create Custom Middleware in Flask?

Middleware provides a way to execute code before or after request processing, centralizing logic for tasks like request modification, response formatting, or metrics collection. Because middleware wraps the whole application, it can add cross-cutting concerns without modifying any route handler, and the logic lives in one place instead of being repeated in every view.

Example: Basic Custom Middleware

from flask import Flask

app = Flask(__name__)

class LoggingMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app

    def __call__(self, environ, start_response):
        print(f"Request: {environ['PATH_INFO']}")
        return self.wsgi_app(environ, start_response)

app.wsgi_app = LoggingMiddleware(app.wsgi_app)

@app.route('/')
def index():
    return "Hello, Flask!"

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output:

* Running on http://127.0.0.1:5000
Request: /

Explanation:

  • LoggingMiddleware - Wraps the Flask WSGI app to log request paths.
  • app.wsgi_app - Replaces the default WSGI app with the middleware.
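Middleware can be exercised without starting a server: Flask's test client calls the app object, which delegates to app.wsgi_app, so a wrapper installed there runs on every test request. A minimal sketch, assuming a hypothetical PathRecorder middleware that stores paths in a list instead of printing them:

```python
from flask import Flask


class PathRecorder:
    """Hypothetical middleware: records each request path instead of printing it."""

    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app
        self.paths = []

    def __call__(self, environ, start_response):
        self.paths.append(environ["PATH_INFO"])
        # Hand the request to the wrapped WSGI app unchanged.
        return self.wsgi_app(environ, start_response)


app = Flask(__name__)
recorder = PathRecorder(app.wsgi_app)
app.wsgi_app = recorder


@app.route("/")
def index():
    return "Hello, Flask!"


# The test client calls app(environ, start_response), which delegates to
# app.wsgi_app, so the middleware runs without a server.
response = app.test_client().get("/")
print(response.status_code, response.get_data(as_text=True))  # 200 Hello, Flask!
print(recorder.paths)  # ['/']
```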

02. Key Middleware Techniques

Custom middleware in Flask can manipulate requests and responses, integrate with external tools, and enforce application-wide policies. The table below summarizes key techniques and their typical uses:

| Technique | Description | Use Case |
| --- | --- | --- |
| Request Logging | Log request details | Debugging, auditing |
| Authentication | Verify user credentials | Secure endpoints |
| Performance Metrics | Measure request latency | Monitor performance |
| Response Modification | Alter response headers/content | Add CORS, custom headers |
| Error Handling | Catch and process exceptions | Custom error responses |


2.1 Request Logging Middleware

Example: Detailed Request Logging

from flask import Flask
from werkzeug.wrappers import Request
import logging

app = Flask(__name__)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class RequestLoggingMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app

    def __call__(self, environ, start_response):
        request = Request(environ)
        logger.info(f"Request: {request.method} {request.path} from {request.remote_addr}")
        return self.wsgi_app(environ, start_response)

app.wsgi_app = RequestLoggingMiddleware(app.wsgi_app)

@app.route('/')
def index():
    return "Logged Request"

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output:

* Running on http://127.0.0.1:5000
INFO:__main__:Request: GET / from 127.0.0.1

Explanation:

  • Logs request method, path, and client IP using werkzeug.wrappers.Request.
  • Centralizes logging for all routes.
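The middleware above logs only the incoming request. Because a WSGI app reports its status through start_response, middleware can also log the response status by passing the app a wrapped start_response. A minimal sketch; the class name StatusLoggingMiddleware and its last_status attribute are illustrative:

```python
import logging

from flask import Flask

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class StatusLoggingMiddleware:
    """Illustrative middleware: logs method, path, and response status."""

    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app
        self.last_status = None  # kept only so the example can be inspected

    def __call__(self, environ, start_response):
        method = environ.get("REQUEST_METHOD", "")
        path = environ.get("PATH_INFO", "")

        def logging_start_response(status, headers, exc_info=None):
            # status is a string such as "200 OK"
            self.last_status = status
            logger.info("%s %s -> %s", method, path, status)
            return start_response(status, headers, exc_info)

        return self.wsgi_app(environ, logging_start_response)


app = Flask(__name__)
middleware = StatusLoggingMiddleware(app.wsgi_app)
app.wsgi_app = middleware


@app.route("/")
def index():
    return "Logged Request"


client = app.test_client()
client.get("/missing")
print(middleware.last_status)  # 404 NOT FOUND
```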

2.2 Authentication Middleware

Example: Token-Based Authentication

from flask import Flask
from werkzeug.wrappers import Request, Response
import hmac

app = Flask(__name__)

class AuthMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app
        # Demo only: load real secrets from configuration, not source code
        self.valid_token = "secret-token"

    def __call__(self, environ, start_response):
        request = Request(environ)
        token = request.headers.get('Authorization', '')
        # compare_digest avoids leaking information through comparison timing
        if not hmac.compare_digest(token.encode(), self.valid_token.encode()):
            res = Response('Unauthorized', status=401)
            return res(environ, start_response)
        return self.wsgi_app(environ, start_response)

app.wsgi_app = AuthMiddleware(app.wsgi_app)

@app.route('/')
def index():
    return "Authenticated Access"

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output:

* Running on http://127.0.0.1:5000
(Request with header "Authorization: secret-token": returns "Authenticated Access")
(Request without a valid token: returns 401 Unauthorized)

Explanation:

  • Checks for a valid token in the Authorization header.
  • Blocks unauthorized requests before reaching route handlers.
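Many APIs send tokens as "Authorization: Bearer <token>" rather than the bare value the example above expects. A minimal sketch of a variant that accepts that scheme, compares tokens with hmac.compare_digest, and adds a WWW-Authenticate header to the 401; the class name and token value are placeholders:

```python
import hmac

from flask import Flask
from werkzeug.wrappers import Request, Response


class BearerAuthMiddleware:
    """Illustrative middleware requiring 'Authorization: Bearer <token>'."""

    def __init__(self, wsgi_app, token):
        self.wsgi_app = wsgi_app
        self.token = token.encode()

    def __call__(self, environ, start_response):
        request = Request(environ)
        scheme, _, supplied = request.headers.get("Authorization", "").partition(" ")
        # Reject other schemes; compare tokens in constant time as bytes.
        if scheme.lower() != "bearer" or not hmac.compare_digest(
            supplied.strip().encode(), self.token
        ):
            res = Response(
                "Unauthorized", status=401, headers={"WWW-Authenticate": "Bearer"}
            )
            return res(environ, start_response)
        return self.wsgi_app(environ, start_response)


app = Flask(__name__)
app.wsgi_app = BearerAuthMiddleware(app.wsgi_app, token="secret-token")  # placeholder secret


@app.route("/")
def index():
    return "Authenticated Access"


client = app.test_client()
denied = client.get("/")
allowed = client.get("/", headers={"Authorization": "Bearer secret-token"})
print(denied.status_code, denied.headers.get("WWW-Authenticate"))  # 401 Bearer
print(allowed.status_code, allowed.get_data(as_text=True))  # 200 Authenticated Access
```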

2.3 Performance Metrics Middleware

Example: Measuring Request Latency

from flask import Flask
from werkzeug.wrappers import Request
import time
import logging

app = Flask(__name__)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LatencyMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app

    def __call__(self, environ, start_response):
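        start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
        response = self.wsgi_app(environ, start_response)
        # Stops when the app returns its response iterable, before a streamed body is sent
        latency = time.perf_counter() - start_time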
        request = Request(environ)
        logger.info(f"Latency: {latency:.3f}s for {request.method} {request.path}")
        return response

app.wsgi_app = LatencyMiddleware(app.wsgi_app)

@app.route('/')
def index():
    return "Performance Monitored"

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output:

* Running on http://127.0.0.1:5000
INFO:__main__:Latency: 0.002s for GET /

Explanation:

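  • Measures and logs the time until the Flask app returns its response; for ordinary (non-streaming) views this approximates the full processing time.
  • Useful for identifying slow endpoints.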
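LatencyMiddleware stops its timer as soon as the wrapped app returns the response iterable. For a streamed response most of the work happens afterwards, while the server iterates over the body. PEP 3333 requires servers to call close() on the iterable when they finish, so one option is to stop the timer there. A sketch using Werkzeug's ClosingIterator helper; the on_complete callback is a hypothetical hook, and the test client's buffered=True flag makes it read the whole body and call close() right away:

```python
import time

from flask import Flask
from werkzeug.wsgi import ClosingIterator


class FullLatencyMiddleware:
    """Times a request until the server has consumed and closed the body."""

    def __init__(self, wsgi_app, on_complete):
        self.wsgi_app = wsgi_app
        self.on_complete = on_complete  # hypothetical hook: (path, seconds) -> None

    def __call__(self, environ, start_response):
        start = time.perf_counter()
        path = environ.get("PATH_INFO", "")
        body = self.wsgi_app(environ, start_response)
        # ClosingIterator calls the body's own close() first, then our callback.
        return ClosingIterator(
            body, lambda: self.on_complete(path, time.perf_counter() - start)
        )


timings = []
app = Flask(__name__)
app.wsgi_app = FullLatencyMiddleware(
    app.wsgi_app, lambda path, seconds: timings.append((path, seconds))
)


@app.route("/stream")
def stream():
    def generate():
        for chunk in ("a", "b", "c"):
            time.sleep(0.05)  # simulate slow chunk production
            yield chunk

    return app.response_class(generate())


client = app.test_client()
# buffered=True: the test client reads the whole body, then calls close().
response = client.get("/stream", buffered=True)
print(response.get_data(as_text=True))  # abc
print(timings[0][0])  # /stream
```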

2.4 Response Modification Middleware

Example: Adding CORS Headers

from flask import Flask

app = Flask(__name__)

class CORSMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app

    def __call__(self, environ, start_response):
        def custom_start_response(status, headers, *args):
            headers.append(('Access-Control-Allow-Origin', '*'))
            return start_response(status, headers, *args)
        return self.wsgi_app(environ, custom_start_response)

app.wsgi_app = CORSMiddleware(app.wsgi_app)

@app.route('/')
def index():
    return "CORS Enabled"

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output:

* Running on http://127.0.0.1:5000
(Response includes header: Access-Control-Allow-Origin: *)

Explanation:

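  • Modifies every response to include the Access-Control-Allow-Origin header.
  • Sufficient for simple cross-origin requests; preflighted requests (for example, those sending an Authorization header or using PUT) also need Access-Control-Allow-Methods and Access-Control-Allow-Headers, which this example does not send.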
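Browsers send a preflight OPTIONS request, carrying an Access-Control-Request-Method header, before non-simple cross-origin calls. A minimal sketch that answers preflights directly and avoids duplicating the origin header on normal responses; the allowed methods, headers, and max-age are placeholder values, and in practice a maintained extension such as Flask-CORS is often simpler:

```python
from flask import Flask
from werkzeug.wrappers import Response


class PreflightCORSMiddleware:
    """Illustrative CORS middleware that also answers preflight requests."""

    def __init__(self, wsgi_app, origin="*"):
        self.wsgi_app = wsgi_app
        self.origin = origin

    def __call__(self, environ, start_response):
        # A preflight is an OPTIONS request carrying Access-Control-Request-Method.
        if (environ.get("REQUEST_METHOD") == "OPTIONS"
                and "HTTP_ACCESS_CONTROL_REQUEST_METHOD" in environ):
            res = Response(status=204, headers={
                "Access-Control-Allow-Origin": self.origin,
                "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE",  # placeholder list
                "Access-Control-Allow-Headers": "Authorization, Content-Type",  # placeholder list
                "Access-Control-Max-Age": "600",  # seconds the browser may cache this answer
            })
            return res(environ, start_response)

        def cors_start_response(status, headers, exc_info=None):
            # Drop any value the app already set so the header is not sent twice.
            headers = [(k, v) for k, v in headers
                       if k.lower() != "access-control-allow-origin"]
            headers.append(("Access-Control-Allow-Origin", self.origin))
            return start_response(status, headers, exc_info)

        return self.wsgi_app(environ, cors_start_response)


app = Flask(__name__)
app.wsgi_app = PreflightCORSMiddleware(app.wsgi_app)


@app.route("/")
def index():
    return "CORS Enabled"


client = app.test_client()
preflight = client.options("/", headers={"Access-Control-Request-Method": "POST"})
normal = client.get("/")
print(preflight.status_code)  # 204
print(normal.headers["Access-Control-Allow-Origin"])  # *
```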

2.5 Error Handling Middleware

Example: Custom Error Handling

from flask import Flask
from werkzeug.wrappers import Response
import logging

app = Flask(__name__)

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

class ErrorHandlingMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app

    def __call__(self, environ, start_response):
        try:
            return self.wsgi_app(environ, start_response)
        except Exception as e:
            logger.error(f"Error: {str(e)}", exc_info=True)
            res = Response('Internal Server Error', status=500)
            return res(environ, start_response)

app.wsgi_app = ErrorHandlingMiddleware(app.wsgi_app)

@app.route('/error')
def error():
    raise ValueError("Test error")

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output:

* Running on http://127.0.0.1:5000
ERROR:__main__:Error: Test error
(followed by the traceback, because exc_info=True)
(Request to /error: Returns 500 Internal Server Error)

Explanation:

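  • Catches exceptions that propagate out of the Flask app, logs them with stack traces, and returns a plain 500 response.
  • Flask normally handles view exceptions itself; they reach this middleware only when Flask propagates them, as it does here because debug=True (also in testing mode or with PROPAGATE_EXCEPTIONS = True). Errors raised while a streamed body is being sent are not caught either.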
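A quick way to see when the except branch actually runs is to toggle Flask's PROPAGATE_EXCEPTIONS setting and count how often the middleware handles the error. A sketch using the test client; the caught counter and the build_app helper are additions for the demo:

```python
import logging

from flask import Flask
from werkzeug.wrappers import Response

logger = logging.getLogger(__name__)


class ErrorHandlingMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app
        self.caught = 0  # demo counter: how often the except branch ran

    def __call__(self, environ, start_response):
        try:
            return self.wsgi_app(environ, start_response)
        except Exception:
            self.caught += 1
            logger.exception("Unhandled error")
            return Response("Internal Server Error", status=500)(environ, start_response)


def build_app(propagate):
    app = Flask(__name__)
    app.config["PROPAGATE_EXCEPTIONS"] = propagate
    middleware = ErrorHandlingMiddleware(app.wsgi_app)
    app.wsgi_app = middleware

    @app.route("/error")
    def error():
        raise ValueError("Test error")

    return app, middleware


results = {}
for propagate in (False, True):
    app, middleware = build_app(propagate)
    status = app.test_client().get("/error").status_code
    results[propagate] = (status, middleware.caught)

# False: Flask produced the 500 itself; True: the middleware produced it.
print(results)  # {False: (500, 0), True: (500, 1)}
```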

2.6 Incorrect Middleware Setup

Example: Broken Middleware

from flask import Flask

app = Flask(__name__)

class BrokenMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app
    # Missing __call__ method

app.wsgi_app = BrokenMiddleware(app.wsgi_app)

@app.route('/')
def index():
    return "This will fail"

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output:

* Running on http://127.0.0.1:5000
(On each request:) TypeError: 'BrokenMiddleware' object is not callable

Explanation:

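  • The missing __call__ method makes the middleware non-callable. The server still starts; the error is raised on each request, when Flask calls app.wsgi_app(environ, start_response).
  • Solution: Implement __call__(self, environ, start_response) to handle WSGI requests.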
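WSGI only requires the middleware to be callable with (environ, start_response), so a class with __call__ is one option; a function returning a closure works equally well. A minimal sketch; add_header_middleware and the X-Example header are hypothetical:

```python
from flask import Flask


def add_header_middleware(wsgi_app, name, value):
    """Hypothetical function-based middleware that adds one response header."""

    def middleware(environ, start_response):
        def patched_start_response(status, headers, exc_info=None):
            headers.append((name, value))
            return start_response(status, headers, exc_info)

        return wsgi_app(environ, patched_start_response)

    return middleware  # a plain function is a valid WSGI callable


app = Flask(__name__)
app.wsgi_app = add_header_middleware(app.wsgi_app, "X-Example", "1")


@app.route("/")
def index():
    return "Now callable"


response = app.test_client().get("/")
print(response.status_code, response.headers["X-Example"])  # 200 1
```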

03. Effective Usage

3.1 Recommended Practices

  • Combine multiple middleware for modular functionality.

Example: Comprehensive Middleware Setup

from flask import Flask
from werkzeug.wrappers import Request
import logging
import time

app = Flask(__name__)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LoggingMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app

    def __call__(self, environ, start_response):
        request = Request(environ)
        logger.info(f"Request: {request.method} {request.path} from {request.remote_addr}")
        return self.wsgi_app(environ, start_response)

class LatencyMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app

    def __call__(self, environ, start_response):
        start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
        response = self.wsgi_app(environ, start_response)
        latency = time.perf_counter() - start_time
        request = Request(environ)
        logger.info(f"Latency: {latency:.3f}s for {request.path}")
        return response

class CORSMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app

    def __call__(self, environ, start_response):
        def custom_start_response(status, headers, *args):
            headers.append(('Access-Control-Allow-Origin', '*'))
            return start_response(status, headers, *args)
        return self.wsgi_app(environ, custom_start_response)

# Chain middleware
app.wsgi_app = LoggingMiddleware(LatencyMiddleware(CORSMiddleware(app.wsgi_app)))

@app.route('/')
def index():
    return "Comprehensive Middleware"

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output:

* Running on http://127.0.0.1:5000
INFO:__main__:Request: GET / from 127.0.0.1
INFO:__main__:Latency: 0.002s for /
(Response includes Access-Control-Allow-Origin: *)
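
Explanation:

  • Chains logging, latency tracking, and CORS middleware; the outermost wrapper (LoggingMiddleware) sees each request first.
  • Modular design: each concern lives in its own class and can be added, removed, or reordered independently.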
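The wrapping order determines the execution order: the last wrapper applied is outermost, so it is entered first and returns last. A quick way to confirm this, using a hypothetical Tag middleware that records when its __call__ is entered and when it returns (the "exit" entries mark the return of __call__, not the moment the body is sent):

```python
from flask import Flask


class Tag:
    """Hypothetical middleware that records its name in a shared list."""

    def __init__(self, wsgi_app, name, log):
        self.wsgi_app = wsgi_app
        self.name = name
        self.log = log

    def __call__(self, environ, start_response):
        self.log.append(f"enter {self.name}")
        result = self.wsgi_app(environ, start_response)
        self.log.append(f"exit {self.name}")
        return result


calls = []
app = Flask(__name__)
# Same nesting as above: Logging(Latency(CORS(app.wsgi_app))).
app.wsgi_app = Tag(Tag(Tag(app.wsgi_app, "CORS", calls), "Latency", calls), "Logging", calls)


@app.route("/")
def index():
    calls.append("view")
    return "ok"


app.test_client().get("/")
print(calls)
# ['enter Logging', 'enter Latency', 'enter CORS', 'view',
#  'exit CORS', 'exit Latency', 'exit Logging']
```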

3.2 Practices to Avoid

  • Avoid heavy processing in middleware to prevent performance degradation.

Example: Heavy Middleware

from flask import Flask
import time

app = Flask(__name__)

class HeavyMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app

    def __call__(self, environ, start_response):
        time.sleep(1)  # Simulate heavy processing
        return self.wsgi_app(environ, start_response)

app.wsgi_app = HeavyMiddleware(app.wsgi_app)

@app.route('/')
def index():
    return "Slow due to middleware"

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output:

* Running on http://127.0.0.1:5000
(Requests delayed by 1 second due to middleware)

Explanation:

  • Heavy processing in middleware slows down every request, because it runs on each one.
  • Solution: Offload intensive tasks to asynchronous workers or optimize the logic.
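One way to follow the offloading advice above without new infrastructure is to hand slow side work (audit logging, analytics calls) to a background thread pool so the response is not delayed. A minimal sketch using concurrent.futures.ThreadPoolExecutor; record_audit is a hypothetical slow task, and for durable or very heavy jobs a dedicated task queue is the more common choice:

```python
import time
from concurrent.futures import ThreadPoolExecutor

from flask import Flask

executor = ThreadPoolExecutor(max_workers=2)
audit_log = []


def record_audit(method, path):
    """Hypothetical slow side task, e.g. a call to an external audit service."""
    time.sleep(0.5)
    audit_log.append((method, path))


class OffloadingMiddleware:
    def __init__(self, wsgi_app, pool):
        self.wsgi_app = wsgi_app
        self.pool = pool

    def __call__(self, environ, start_response):
        # Pass plain values, not the environ dict itself, to the worker thread.
        method = environ.get("REQUEST_METHOD", "")
        path = environ.get("PATH_INFO", "")
        self.pool.submit(record_audit, method, path)  # returns immediately
        return self.wsgi_app(environ, start_response)


app = Flask(__name__)
app.wsgi_app = OffloadingMiddleware(app.wsgi_app, executor)


@app.route("/")
def index():
    return "Fast response"


start = time.perf_counter()
response = app.test_client().get("/")
elapsed = time.perf_counter() - start
print(response.status_code)  # 200
executor.shutdown(wait=True)  # demo only: wait for the background task
print(audit_log)  # [('GET', '/')]
```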

04. Common Use Cases

4.1 API Authentication

Enforce token-based authentication for API endpoints.

Example: API Token Middleware

from flask import Flask, jsonify
from werkzeug.wrappers import Request, Response
import hmac

app = Flask(__name__)

class APITokenMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app
        # Demo only: load real secrets from configuration, not source code
        self.valid_token = "api-secret"

    def __call__(self, environ, start_response):
        request = Request(environ)
        if not request.path.startswith('/api/'):
            return self.wsgi_app(environ, start_response)
        token = request.headers.get('Authorization', '')
        # compare_digest avoids leaking information through comparison timing
        if not hmac.compare_digest(token.encode(), self.valid_token.encode()):
            res = Response('Unauthorized API access', status=401)
            return res(environ, start_response)
        return self.wsgi_app(environ, start_response)

app.wsgi_app = APITokenMiddleware(app.wsgi_app)

@app.route('/api/data')
def data():
    return jsonify({'data': 'secure'})

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output:

* Running on http://127.0.0.1:5000
(Request to /api/data with header "Authorization: api-secret": returns {"data": "secure"})
(Request without a valid token: returns 401 Unauthorized)

Explanation:

  • Secures API routes with token authentication.
  • Applies only to /api/ paths, leaving other routes unaffected.
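The example hard-codes the token for brevity. A minimal sketch of the same middleware taking the token as a constructor argument, read from a hypothetical API_TOKEN environment variable with a demo fallback, plus test-client checks that only /api/ paths are guarded:

```python
import hmac
import os

from flask import Flask, jsonify
from werkzeug.wrappers import Request, Response


class APITokenMiddleware:
    """Guards paths under `prefix`; the token is passed in, not hard-coded."""

    def __init__(self, wsgi_app, token, prefix="/api/"):
        self.wsgi_app = wsgi_app
        self.token = token.encode()
        self.prefix = prefix

    def __call__(self, environ, start_response):
        request = Request(environ)
        if not request.path.startswith(self.prefix):
            return self.wsgi_app(environ, start_response)  # route is not guarded
        supplied = request.headers.get("Authorization", "").encode()
        if not hmac.compare_digest(supplied, self.token):
            res = Response("Unauthorized API access", status=401)
            return res(environ, start_response)
        return self.wsgi_app(environ, start_response)


# Hypothetical variable name; the fallback exists only so the demo runs.
token = os.environ.get("API_TOKEN", "api-secret")
app = Flask(__name__)
app.wsgi_app = APITokenMiddleware(app.wsgi_app, token)


@app.route("/api/data")
def data():
    return jsonify({"data": "secure"})


@app.route("/health")
def health():
    return "ok"


client = app.test_client()
print(client.get("/health").status_code)  # 200
print(client.get("/api/data").status_code)  # 401
```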

4.2 Performance Monitoring

Track request latency and expose it as Prometheus metrics (requires the prometheus-client package).

Example: Performance Monitoring Middleware

from flask import Flask
from werkzeug.wrappers import Request
from prometheus_client import Histogram, make_wsgi_app
from werkzeug.middleware.dispatcher import DispatcherMiddleware
import time

app = Flask(__name__)

request_latency = Histogram('flask_request_latency_seconds', 'Request latency', ['endpoint'])

class PerformanceMiddleware:
    def __init__(self, wsgi_app):
        self.wsgi_app = wsgi_app

    def __call__(self, environ, start_response):
        request = Request(environ)
        start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
        response = self.wsgi_app(environ, start_response)
        latency = time.perf_counter() - start_time
        endpoint = request.path
        request_latency.labels(endpoint=endpoint).observe(latency)
        return response

app.wsgi_app = DispatcherMiddleware(PerformanceMiddleware(app.wsgi_app), {'/metrics': make_wsgi_app()})

@app.route('/')
def index():
    return "Performance Monitored"

if __name__ == '__main__':
    app.run(debug=True, port=5000)

Output (excerpt from http://127.0.0.1:5000/metrics):

flask_request_latency_seconds_sum{endpoint="/"} 0.002

Explanation:

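  • Integrates with Prometheus to record a latency histogram per path, exposed at /metrics.
  • Centralizes performance monitoring for all Flask routes; requests to /metrics itself are routed by DispatcherMiddleware straight to the Prometheus app and are not timed.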
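Labelling by request.path creates one time series per distinct URL, so routes like /user/<id>, or scanners probing random paths, can make the metric's cardinality grow without bound. A sketch of a helper that maps a request to its Flask endpoint name instead, using Werkzeug URL-map matching; endpoint_label is a hypothetical name, and its result could replace request.path in the labels() call:

```python
from flask import Flask
from werkzeug.exceptions import HTTPException
from werkzeug.test import EnvironBuilder


def endpoint_label(app, environ):
    """Return the matched endpoint name, or 'unmatched' for unknown URLs."""
    adapter = app.url_map.bind_to_environ(environ)
    try:
        endpoint, _args = adapter.match()
    except HTTPException:
        # NotFound, MethodNotAllowed and RequestRedirect all land here.
        return "unmatched"
    return endpoint


app = Flask(__name__)


@app.route("/user/<int:user_id>")
def user(user_id):
    return f"user {user_id}"


for path in ("/user/1", "/user/42", "/no-such-page"):
    env = EnvironBuilder(path=path).get_environ()
    print(path, endpoint_label(app, env))
# /user/1 user
# /user/42 user
# /no-such-page unmatched
```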

Conclusion

Creating custom middleware in Flask enhances application functionality and maintainability. Key takeaways:

  • Implement middleware to handle logging, authentication, performance, and more.
  • Use werkzeug.wrappers for request/response manipulation.
  • Chain middleware for modular, reusable logic.
  • Avoid heavy processing to maintain performance.

With these practices, you can build robust Flask applications with powerful, centralized functionality!
