From 98982c585d1ac63fca6249010c75e4ef3a8eed38 Mon Sep 17 00:00:00 2001
From: Aditya8369
Date: Tue, 9 Jun 2026 10:37:01 +0530
Subject: [PATCH] [ENHANCEMENT] Secure chatbot backend

---
Review notes (commentary only; this section is not part of the commit):

* This text was re-flowed from a whitespace-collapsed copy. The line
  structure was rebuilt to agree with the hunk headers and the diffstat
  (138 insertions, 23 deletions). Indentation and intra-line spacing are
  reconstructed, so run `git apply --check` before relying on it.
* The "<br>" literal in both versions of format_response() had been replaced
  by a raw line break in that copy and is restored here.
* chat() keys the rate limiter on the X-Forwarded-For request header when it
  is present. Unless a trusted proxy overwrites that header, a caller can
  vary it to get a fresh window. Every distinct value also adds a key to
  `_requests`, and rate_limited() never removes keys.
* looks_like_abuse() is a plain substring test, so entries such as "poison",
  "steal" or "act as" (contained in "react as") also match ordinary
  chemistry/biology questions.
* The new format_response() no longer strips "*" / "**", so Markdown emphasis
  in model replies reaches the client as literal asterisks.

 chatbot.py | 161 +++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 138 insertions(+), 23 deletions(-)

diff --git a/chatbot.py b/chatbot.py
index 75b0659..d2d160d 100644
--- a/chatbot.py
+++ b/chatbot.py
@@ -1,32 +1,104 @@
 import os
+import time
+import html
+from collections import defaultdict, deque
+
 from flask import Flask, request, jsonify, render_template
 from flask_cors import CORS
 import google.generativeai as genai
 
-# Load API Key securely from environment variable
-# Never hardcode secrets in source code.
-# Set your key with: export GEMINI_API_KEY="your_key_here" (Linux/Mac)
-# set GEMINI_API_KEY=your_key_here (Windows)
+# -----------------------------
+# Security: API key via env var
+# -----------------------------
 API_KEY = os.environ.get("GEMINI_API_KEY")
 
 if not API_KEY:
     raise EnvironmentError(
         "GEMINI_API_KEY environment variable is not set. "
-        "Please set it before running the application. "
-        "See .env.example for guidance."
+        "Set it before running the server."
     )
 
 genai.configure(api_key=API_KEY)
 
+# -----------------------------
+# App setup
+# -----------------------------
 app = Flask(__name__, template_folder="templates", static_folder="static")
 CORS(app)
 
-
-def format_response(text):
-    """Formats chatbot response for better readability."""
-    # Safely convert markdown-like bold to HTML
-    formatted_text = text.replace("**", "").replace("*", "").replace("\n", "<br>")
-    return formatted_text.strip()
+# Create the model once at startup (cheaper than per request)
+MODEL_NAME = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
+model = genai.GenerativeModel(MODEL_NAME)
+
+
+# -----------------------------
+# Input/output hardening
+# -----------------------------
+MAX_MESSAGE_CHARS = int(os.environ.get("CHAT_MAX_MESSAGE_CHARS", "2000"))
+
+# Simple refusal/guardrail keywords (heuristic)
+ABUSE_KEYWORDS = [
+    # Prompt injection patterns
+    "ignore previous",
+    "disregard",
+    "system prompt",
+    "developer prompt",
+    "jailbreak",
+    "act as",
+    "bypass",
+    "reveal your instructions",
+    # Harmful/illegal categories (basic)
+    "how to build a bomb",
+    "make a bomb",
+    "poison",
+    "suicide",
+    "self-harm",
+    "kill yourself",
+    "kill someone",
+    "murder",
+    "steal",
+    "credit card",
+    "creditcard",
+]
+
+
+def format_response(text: str) -> str:
+    """Return safe text for the frontend (no unsafe HTML)."""
+    if text is None:
+        return ""
+    # Escape HTML to prevent XSS; preserve line breaks.
+    return html.escape(text).replace("\n", "<br>").strip()
+
+
+def looks_like_abuse(user_text: str) -> bool:
+    t = (user_text or "").lower()
+    return any(k in t for k in ABUSE_KEYWORDS)
+
+
+# -----------------------------
+# Rate limiting (per IP)
+# In-memory sliding window.
+# -----------------------------
+RATE_LIMIT_WINDOW_SECONDS = int(os.environ.get("CHAT_RATE_LIMIT_WINDOW_SECONDS", "60"))
+RATE_LIMIT_COUNT = int(os.environ.get("CHAT_RATE_LIMIT_COUNT", "10"))
+
+# ip -> deque[timestamps]
+_requests = defaultdict(deque)
+
+
+def rate_limited(ip: str) -> bool:
+    now = time.time()
+    q = _requests[ip]
+
+    # Drop old timestamps
+    while q and (now - q[0]) > RATE_LIMIT_WINDOW_SECONDS:
+        q.popleft()
+
+    if len(q) >= RATE_LIMIT_COUNT:
+        return True
+
+    q.append(now)
+    return False
 
 
 @app.route("/")
@@ -43,27 +115,70 @@ def chat():
 
     user_input = data["message"]
 
-    if not user_input or not user_input.strip():
+    if not isinstance(user_input, str):
+        return jsonify({"error": "Message must be a string"}), 400
+
+    user_input = user_input.strip()
+
+    if not user_input:
         return jsonify({"error": "Message cannot be empty"}), 400
 
+    if len(user_input) > MAX_MESSAGE_CHARS:
+        return (
+            jsonify({"error": f"Message too large (max {MAX_MESSAGE_CHARS} characters)."}),
+            413,
+        )
+
+    # Rate limit per IP
+    ip = request.headers.get("X-Forwarded-For", request.remote_addr) or "unknown"
+    ip = ip.split(",")[0].strip()  # handle proxies lists
+
+    if rate_limited(ip):
+        return jsonify({"error": "Too many requests. Please try again later."}), 429
+
+    # Basic abuse prevention / guardrails
+    if looks_like_abuse(user_input):
+        reply = (
+            "I can’t help with that request. "
+            "If you need help with a physics, maths, chemistry, or biology topic, tell me what concept you’re working on."
+        )
+        return jsonify({"reply": format_response(reply)})
+
     try:
-        model = genai.GenerativeModel("gemini-1.5-pro")
-        response = model.generate_content(user_input)
+        # Keep the model constrained to educational tutoring.
+        system_guardrails = (
+            "You are an educational tutor for Physics, Maths, Chemistry, and Biology. "
+            "Be concise, step-by-step when useful, and avoid unsafe or illegal content. "
+            "If the user asks for something unrelated, politely guide them back to learning topics."
+        )
+
+        # Structured prompt (reduces injection impact)
+        prompt = f"{system_guardrails}\n\nUser message: {user_input}"
+
+        response = model.generate_content(prompt)
+
+        reply_text = ""
+        if response and getattr(response, "candidates", None):
+            cand0 = response.candidates[0] if response.candidates else None
+            parts = getattr(getattr(cand0, "content", None), "parts", None)
+            if parts and len(parts) > 0:
+                reply_text = getattr(parts[0], "text", None) or ""
 
-        if response and response.candidates:
-            reply = format_response(response.candidates[0].content.parts[0].text)
-        else:
-            reply = "I'm not sure how to respond to that. Can you try rephrasing?"
+        if not reply_text:
+            reply_text = "I’m not sure how to respond to that. Can you try rephrasing?"
+
+        return jsonify({"reply": format_response(reply_text)})
 
     except Exception as e:
         # Log internally but don't expose raw exception to the client
         print(f"[ERROR] Chatbot error: {e}")
-        return jsonify({"reply": "An error occurred while processing your request. Please try again later."}), 500
-
-    return jsonify({"reply": reply})
+        return (
+            jsonify({"reply": "An error occurred while processing your request. Please try again later."}),
+            500,
+        )
 
 
 if __name__ == "__main__":
-    # Disable debug mode in production
     debug_mode = os.environ.get("FLASK_DEBUG", "false").lower() == "true"
     app.run(debug=debug_mode)
+