From adfdc8c9e52471b02c7fe68a7a5317ceaa5c777b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Thuret?= <contact@sebastien-thuret.fr>
Date: Fri, 5 Nov 2021 14:55:45 +0100
Subject: [PATCH] decrease counter if the user wait after the slowdown notice

---
 app/app.py   | 12 ++++++++++--
 app/flood.py | 13 +++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/app/app.py b/app/app.py
index edc45dd..ecc7807 100644
--- a/app/app.py
+++ b/app/app.py
@@ -25,6 +25,7 @@ def get_version():
     except:
         return "?"
 
+
 def get_upload_dir():
     upload_dir = os.path.join(tempfile.gettempdir(), "libretranslate-files-translate")
 
@@ -33,6 +34,7 @@ def get_upload_dir():
 
     return upload_dir
 
+
 def get_req_api_key():
     if request.is_json:
         json = get_json_dict(request)
@@ -42,6 +44,7 @@ def get_req_api_key():
 
     return ak
 
+
 def get_json_dict(request):
     d = request.get_json()
     if not isinstance(d, dict):
@@ -162,8 +165,13 @@ def create_app(args):
     def access_check(f):
         @wraps(f)
         def func(*a, **kw):
-            if flood.is_banned(get_remote_address()):
+            ip = get_remote_address()
+
+            if flood.is_banned(ip):
                 abort(403, description="Too many request limits violations")
+            else:
+                if flood.has_violation(ip):
+                    flood.decrease(ip)
 
             if args.api_keys and args.require_api_key_origin:
                 ak = get_req_api_key()
@@ -621,7 +629,7 @@ def create_app(args):
         """
         if args.disable_files_translation:
             abort(400, description="Files translation are disabled on this server.")
-        
+
         filepath = os.path.join(get_upload_dir(), filename)
         try:
             checked_filepath = security.path_traversal_check(filepath, get_upload_dir())
diff --git a/app/flood.py b/app/flood.py
index ef0ed6b..a76e200 100644
--- a/app/flood.py
+++ b/app/flood.py
@@ -19,6 +19,8 @@ def setup(violations_threshold=100):
     active = True
     threshold = violations_threshold
 
+    print(violations_threshold)
+
     scheduler = BackgroundScheduler()
     scheduler.add_job(func=clear_banned, trigger="interval", weeks=4)
     scheduler.start()
@@ -31,6 +33,17 @@ def report(request_ip):
     if active:
         banned[request_ip] = banned.get(request_ip, 0)
         banned[request_ip] += 1
+        print(banned[request_ip])
+
+
+def decrease(request_ip):
+    if banned[request_ip] > 0:
+        banned[request_ip] -= 1
+        print('decrease',  request_ip)
+
+
+def has_violation(request_ip):
+    return request_ip in banned and banned[request_ip] > 0
 
 
 def is_banned(request_ip):
-- 
GitLab