Coverage for apps / ai / api.py: 94%
215 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-12 10:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-12 10:49 +0000
1"""
2AI settings and prompts API endpoints.
3"""
5import logging
6from functools import wraps
7from typing import Callable, List, Optional
9from django_ratelimit.decorators import ratelimit
10from ninja import Router, Schema, Status
12from apps.core.models import AppSettings
13from apps.recipes.models import Recipe
15from .models import AIPrompt
16from .services.openrouter import OpenRouterService, AIUnavailableError, AIResponseError
17from .services.tips import generate_tips, clear_tips
18from .services.timer import generate_timer_name
19from .services.selector import repair_selector, get_sources_needing_attention
20from .services.validator import ValidationError
21from .services.cache import is_ai_cache_hit
22from .services.quota import release_quota, reserve_quota
23from apps.core.auth import AdminAuth, SessionAuth
# Logger for security-relevant events; the tips, timer-name and repair-selector
# endpoints in this module log rate-limit hits to it.
security_logger = logging.getLogger("security")
# Router for every endpoint in this module; tags=["ai"] is applied to each operation registered on it.
router = Router(tags=["ai"])
# Decorators


def handle_ai_errors(func: Callable) -> Callable:
    """Turn AI-service exceptions raised by ``func`` into error responses.

    The wrapped endpoint behaves normally unless it raises:

    * ``AIUnavailableError`` -> 503 with ``error="ai_unavailable"`` and
      ``action="configure_key"``; the message is the exception text, or a
      generic "configure your API key" hint when that text is empty.
    * ``AIResponseError`` / ``ValidationError`` -> 400 with
      ``error="ai_error"`` and the exception text as the message.

    Any other exception propagates unchanged.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except AIUnavailableError as exc:
            fallback = "AI features are not available. Please configure your API key in Settings."
            body = {
                "error": "ai_unavailable",
                "message": str(exc) or fallback,
                "action": "configure_key",
            }
            return Status(503, body)
        except (AIResponseError, ValidationError) as exc:
            body = {"error": "ai_error", "message": str(exc)}
            return Status(400, body)

    return wrapper
# Schemas


class AIStatusOut(Schema):
    # Response body of GET /status (built by get_ai_status).
    available: bool  # True only when a key is configured AND it validated
    configured: bool  # True when AppSettings holds a non-empty OpenRouter API key
    valid: bool  # result of OpenRouterService.validate_key_cached(); False when no key
    default_model: str  # AppSettings.default_ai_model
    error: Optional[str] = None  # human-readable description of the problem, if any
    error_code: Optional[str] = None  # machine-readable code: "no_api_key" or "invalid_api_key"
class TestApiKeyIn(Schema):
    # Request body for POST /test-api-key.
    api_key: str  # candidate OpenRouter key; only tested, never saved by test_api_key
class TestApiKeyOut(Schema):
    # Response body for POST /test-api-key; both values come from OpenRouterService.test_connection().
    success: bool
    message: str
class SaveApiKeyIn(Schema):
    # Request body for POST /save-api-key.
    api_key: str  # value written to AppSettings.openrouter_api_key
class SaveApiKeyOut(Schema):
    # Response body for POST /save-api-key.
    success: bool
    message: str
class PromptOut(Schema):
    # Serialized AIPrompt, returned by the /prompts endpoints.
    prompt_type: str  # lookup key used in /prompts/{prompt_type}
    name: str
    description: str
    system_prompt: str
    user_prompt_template: str
    model: str  # model id; update_prompt checks new values against the available model ids
    is_active: bool
class PromptUpdateIn(Schema):
    # Request body for PUT /prompts/{prompt_type}.
    # Partial update: update_prompt only changes fields whose value is not None.
    system_prompt: Optional[str] = None
    user_prompt_template: Optional[str] = None
    model: Optional[str] = None  # validated against OpenRouter's model list before saving
    is_active: Optional[bool] = None
class ModelOut(Schema):
    # One entry of GET /models, taken from OpenRouterService.get_available_models().
    id: str  # model identifier, the value stored in AIPrompt.model
    name: str  # display name (presumably human-readable; confirm against the service)
class ErrorOut(Schema):
    # Common error body used for the 400/404/422/503 responses in this module.
    error: str  # machine-readable code, e.g. "not_found", "validation_error", "ai_error"
    message: str  # human-readable explanation
    action: Optional[str] = None  # User-facing action to resolve the error (e.g. "configure_key")
# Endpoints


@router.get("/status", response=AIStatusOut, auth=SessionAuth())
def get_ai_status(request):
    """Report whether AI features can currently be used.

    The response always carries:
    - configured: an OpenRouter API key is stored in AppSettings
    - valid: the (cached) validation of that key succeeded
    - available: same as ``valid``; it can only be true when a key is configured
    - default_model: the default AI model from AppSettings
    - error / error_code: the problem found, if any ("no_api_key" or
      "invalid_api_key"), otherwise None
    """
    app_settings = AppSettings.get()

    if not app_settings.openrouter_api_key:
        # Nothing to validate without a key.
        return {
            "available": False,
            "configured": False,
            "valid": False,
            "default_model": app_settings.default_ai_model,
            "error": "No API key configured",
            "error_code": "no_api_key",
        }

    key_ok, validation_error = OpenRouterService.validate_key_cached()
    if key_ok:
        problem, problem_code = None, None
    else:
        problem = validation_error or "API key is invalid or expired"
        problem_code = "invalid_api_key"

    return {
        "available": key_ok,
        "configured": True,
        "valid": key_ok,
        "default_model": app_settings.default_ai_model,
        "error": problem,
        "error_code": problem_code,
    }
@router.post("/test-api-key", response={200: TestApiKeyOut, 400: ErrorOut, 429: dict}, auth=AdminAuth())
@ratelimit(key="ip", rate="20/h", method="POST", block=False)
def test_api_key(request, data: TestApiKeyIn):
    """Check whether an OpenRouter API key works, without saving it.

    Returns:
        - 200 with ``success``/``message`` from ``OpenRouterService.test_connection``
        - 400 ``validation_error`` when the key is empty or whitespace-only
        - 429 when the per-IP rate limit (20/h) is exceeded
    """
    if getattr(request, "limited", False):
        security_logger.warning("Rate limit hit: /ai/test-api-key from %s", request.META.get("REMOTE_ADDR"))
        # "detail" is kept for existing clients; "error"/"message" match the other AI endpoints.
        return Status(
            429,
            {
                "detail": "Rate limit exceeded. Try again later.",
                "error": "rate_limited",
                "message": "Too many requests. Please try again later.",
            },
        )

    # Pasted keys often carry stray spaces/newlines; whitespace-only input counts as missing.
    api_key = data.api_key.strip()
    if not api_key:
        return Status(
            400,
            {
                "error": "validation_error",
                "message": "API key is required",
            },
        )

    success, message = OpenRouterService.test_connection(api_key)
    return {
        "success": success,
        "message": message,
    }
@router.post("/save-api-key", response={200: SaveApiKeyOut, 400: ErrorOut, 429: dict}, auth=AdminAuth())
@ratelimit(key="ip", rate="20/h", method="POST", block=False)
def save_api_key(request, data: SaveApiKeyIn):
    """Persist the OpenRouter API key in AppSettings.

    Surrounding whitespace is stripped before saving, so whitespace-only input
    is stored as an empty key (which get_ai_status reports as "no_api_key")
    instead of a key that can never validate. The cached key validation is
    invalidated so the next status check re-validates the new key.

    Returns 200 on success, 429 when the per-IP rate limit (20/h) is exceeded.
    """
    if getattr(request, "limited", False):
        security_logger.warning("Rate limit hit: /ai/save-api-key from %s", request.META.get("REMOTE_ADDR"))
        # "detail" is kept for existing clients; "error"/"message" match the other AI endpoints.
        return Status(
            429,
            {
                "detail": "Rate limit exceeded. Try again later.",
                "error": "rate_limited",
                "message": "Too many requests. Please try again later.",
            },
        )

    settings = AppSettings.get()
    settings.openrouter_api_key = data.api_key.strip()
    settings.save()

    # Invalidate the validation cache since key was updated
    OpenRouterService.invalidate_key_cache()

    return {
        "success": True,
        "message": "API key saved successfully",
    }
@router.get("/prompts", response=List[PromptOut], auth=AdminAuth())
def list_prompts(request):
    """Return every stored AI prompt as a list (admin only)."""
    return list(AIPrompt.objects.all())
def _get_prompt_or_error(prompt_type: str):
    """Look up an AIPrompt by its ``prompt_type``.

    Returns:
        The matching ``AIPrompt`` instance, or a ``Status(404, ...)`` response
        with a ``not_found`` error body when no prompt has that type. Callers
        that need the model itself (e.g. update_prompt) tell the two apart with
        ``isinstance(result, AIPrompt)``.
    """
    try:
        return AIPrompt.objects.get(prompt_type=prompt_type)
    except AIPrompt.DoesNotExist:
        return Status(404, {"error": "not_found", "message": f'Prompt type "{prompt_type}" not found'})
def _validate_model(model_id: str):
    """Check ``model_id`` against the models OpenRouter currently offers.

    Returns:
        ``None`` when the model is acceptable, or a ``Status(422, ...)`` response
        with an ``invalid_model`` error body when it is not in the catalog.

    Validation is best-effort. If the catalog cannot be fetched
    (``AIUnavailableError`` / ``AIResponseError``) or comes back empty, the
    model is accepted; it may fail later when the prompt is used.
    """
    try:
        available_models = OpenRouterService().get_available_models()
    except (AIUnavailableError, AIResponseError):
        # Can't validate — allow the change; it may fail later
        return None

    valid_ids = {m["id"] for m in available_models}
    if not valid_ids:
        # An empty catalog can't distinguish valid from invalid ids; don't block the update.
        return None
    if model_id in valid_ids:
        return None

    return Status(
        422,
        {
            "error": "invalid_model",
            "message": f'Model "{model_id}" is not available. Please select a valid model.',
        },
    )
@router.get("/prompts/{prompt_type}", response={200: PromptOut, 404: ErrorOut}, auth=AdminAuth())
def get_prompt(request, prompt_type: str):
    """Fetch one AI prompt by type; unknown types produce a 404 error body."""
    prompt_or_404 = _get_prompt_or_error(prompt_type)
    return prompt_or_404
@router.put("/prompts/{prompt_type}", response={200: PromptOut, 404: ErrorOut, 422: ErrorOut}, auth=AdminAuth())
def update_prompt(request, prompt_type: str, data: PromptUpdateIn):
    """Partially update an AI prompt.

    Only fields sent with a non-None value are changed. A new ``model`` is
    checked with ``_validate_model`` first; if that rejects it, nothing is
    changed and its 422 response is returned. Unknown prompt types yield 404.
    """
    prompt = _get_prompt_or_error(prompt_type)
    if not isinstance(prompt, AIPrompt):
        return prompt  # the 404 response

    if data.model is not None:
        model_error = _validate_model(data.model)
        if model_error:
            return model_error

    # Apply the partial update in a fixed field order.
    for field_name in ("system_prompt", "user_prompt_template", "model", "is_active"):
        new_value = getattr(data, field_name)
        if new_value is not None:
            setattr(prompt, field_name, new_value)

    prompt.save()
    return prompt
@router.get("/models", response=List[ModelOut], auth=SessionAuth())
def list_models(request):
    """Return the models OpenRouter offers, or an empty list if that fails.

    Both a missing API key (``AIUnavailableError``) and an API failure
    (``AIResponseError``) degrade to ``[]`` rather than an error response.
    """
    try:
        return OpenRouterService().get_available_models()
    except (AIUnavailableError, AIResponseError):
        return []
# Tips Schemas


class TipsIn(Schema):
    # Request body for POST /tips.
    recipe_id: int  # must belong to the requesting profile, otherwise the endpoint answers 404
    regenerate: bool = False  # True clears existing tips before generating new ones
class TipsOut(Schema):
    # Response body for POST /tips (the dict returned by generate_tips()).
    tips: List[str]
    cached: bool  # True when generate_tips reported a cached result; no quota is charged then
# Tips Endpoints


@router.post(
    "/tips", response={200: TipsOut, 400: ErrorOut, 404: ErrorOut, 429: dict, 503: ErrorOut}, auth=SessionAuth()
)
@ratelimit(key="ip", rate="20/h", method="POST", block=False)
@handle_ai_errors
def tips_endpoint(request, data: TipsIn):
    """Generate cooking tips for a recipe owned by the requesting profile.

    Pass regenerate=True to clear existing tips and generate fresh ones.

    One unit of the daily "tips" quota is reserved up front. It is released
    again whenever the request does not end in a fresh generation:
    - recipe missing or not owned (404)
    - the result is cached
    - clearing or generating tips raises; the exception is re-raised so
      ``handle_ai_errors`` can map AI errors to 503/400
    """
    if getattr(request, "limited", False):
        security_logger.warning("Rate limit hit: /ai/tips from %s", request.META.get("REMOTE_ADDR"))
        return Status(429, {"error": "rate_limited", "message": "Too many requests. Please try again later."})

    allowed, info = reserve_quota(request.auth, "tips")
    if not allowed:
        return Status(429, {"error": "quota_exceeded", "message": "Daily limit reached for tips", **info})

    from apps.profiles.utils import get_current_profile_or_none

    profile = get_current_profile_or_none(request)

    try:
        recipe = Recipe.objects.get(id=data.recipe_id)
    except Recipe.DoesNotExist:
        recipe = None

    # A missing recipe and a recipe owned by someone else get the identical 404 body,
    # so the response does not reveal whether the recipe exists.
    if recipe is None or not profile or recipe.profile_id != profile.id:
        release_quota(request.auth, "tips")
        return Status(
            404,
            {
                "error": "not_found",
                "message": f"Recipe {data.recipe_id} not found",
            },
        )

    try:
        # clear_tips is inside the try so a failure while clearing also returns the reserved quota.
        if data.regenerate:
            clear_tips(data.recipe_id)
        result = generate_tips(data.recipe_id)
    except Exception:
        release_quota(request.auth, "tips")
        raise
    if result.get("cached"):
        release_quota(request.auth, "tips")
    return result
# Timer Naming Schemas


class TimerNameIn(Schema):
    # Request body for POST /timer-name.
    step_text: str  # cooking instruction the timer belongs to; must not be empty
    duration_minutes: int  # timer length; the endpoint rejects values <= 0
class TimerNameOut(Schema):
    # Response body for POST /timer-name (the result of generate_timer_name()).
    label: str  # short descriptive name for the timer
# Timer Naming Endpoints


@router.post("/timer-name", response={200: TimerNameOut, 400: ErrorOut, 429: dict, 503: ErrorOut}, auth=SessionAuth())
@ratelimit(key="ip", rate="60/h", method="POST", block=False)
@handle_ai_errors
def timer_name_endpoint(request, data: TimerNameIn):
    """Generate a short descriptive label for a cooking timer.

    Order of checks:
    1. per-IP rate limit (60/h)
    2. daily "timer" quota reservation
    3. input validation
    4. AI call

    The reserved quota unit is released again whenever the request does not
    end in a fresh generation:
    - invalid input (400)
    - a cache hit
    - the cache check or generation raises; the exception is re-raised so
      ``handle_ai_errors`` can map AI errors to 503/400
    """
    if getattr(request, "limited", False):
        security_logger.warning("Rate limit hit: /ai/timer-name from %s", request.META.get("REMOTE_ADDR"))
        return Status(429, {"error": "rate_limited", "message": "Too many requests. Please try again later."})

    allowed, info = reserve_quota(request.auth, "timer")
    if not allowed:
        return Status(429, {"error": "quota_exceeded", "message": "Daily limit reached for timer", **info})

    # Whitespace-only text gives the AI nothing to name; treat it like empty text.
    if not data.step_text.strip():
        release_quota(request.auth, "timer")
        return Status(
            400,
            {
                "error": "validation_error",
                "message": "Step text is required",
            },
        )

    if data.duration_minutes <= 0:
        release_quota(request.auth, "timer")
        return Status(
            400,
            {
                "error": "validation_error",
                "message": "Duration must be positive",
            },
        )

    try:
        # Checked before generating, which presumably populates the cache for these inputs.
        was_cached = is_ai_cache_hit("timer_name", step_text=data.step_text, duration_minutes=data.duration_minutes)
        result = generate_timer_name(
            step_text=data.step_text,
            duration_minutes=data.duration_minutes,
        )
    except Exception:
        release_quota(request.auth, "timer")
        raise
    if was_cached:
        release_quota(request.auth, "timer")
    return result
# Selector Repair Schemas


class SelectorRepairIn(Schema):
    # Request body for POST /repair-selector.
    source_id: int  # id of the SearchSource whose selector should be repaired
    html_sample: str  # HTML from the source's search page, analyzed by the AI; must not be empty
    target: str = "recipe search result"  # presumably a description of what the selector should match
    confidence_threshold: float = 0.8  # confidence needed before the source is auto-updated
    auto_update: bool = True  # allow repair_selector to update the source when confidence is high enough
class SelectorRepairOut(Schema):
    # Response body for POST /repair-selector, copied from repair_selector()'s result.
    suggestions: List[str]  # candidate selectors proposed by the AI
    confidence: float
    original_selector: str  # "" when the source had no selector
    updated: bool  # whether the source was updated
    new_selector: Optional[str] = None
class SourceNeedingAttentionOut(Schema):
    # One row of GET /sources-needing-attention.
    id: int
    host: str
    name: str
    result_selector: str  # "" when the source has no selector
    consecutive_failures: int
# Selector Repair Endpoints


@router.post(
    "/repair-selector",
    response={200: SelectorRepairOut, 400: ErrorOut, 404: ErrorOut, 429: dict, 503: ErrorOut},
    auth=AdminAuth(),
)
@ratelimit(key="ip", rate="5/h", method="POST", block=False)
@handle_ai_errors
def repair_selector_endpoint(request, data: SelectorRepairIn):
    """Attempt to repair a broken CSS selector using AI.

    Analyzes HTML from the search page and suggests new selectors.
    If confidence is high enough and auto_update=True, the source is updated.
    This endpoint is intended for admin/maintenance use.

    Returns:
        - 200 with the repair result
        - 404 when the SearchSource does not exist
        - 400 ``validation_error`` for an empty/whitespace ``html_sample`` or a
          ``confidence_threshold`` outside [0, 1]
        - 429 when the per-IP rate limit (5/h) is exceeded
    """
    if getattr(request, "limited", False):
        security_logger.warning("Rate limit hit: /ai/repair-selector from %s", request.META.get("REMOTE_ADDR"))
        return Status(429, {"error": "rate_limited", "message": "Too many requests. Please try again later."})
    from apps.recipes.models import SearchSource

    try:
        source = SearchSource.objects.get(id=data.source_id)
    except SearchSource.DoesNotExist:
        return Status(
            404,
            {
                "error": "not_found",
                "message": f"SearchSource {data.source_id} not found",
            },
        )

    if not data.html_sample.strip():
        return Status(
            400,
            {
                "error": "validation_error",
                "message": "HTML sample is required",
            },
        )

    # Assumes confidence is a score in [0, 1] (default threshold 0.8). Out-of-range values would
    # either auto-apply any suggestion (< 0) or never apply one (> 1). NaN also fails this check.
    if not 0.0 <= data.confidence_threshold <= 1.0:
        return Status(
            400,
            {
                "error": "validation_error",
                "message": "Confidence threshold must be between 0 and 1",
            },
        )

    result = repair_selector(
        source=source,
        html_sample=data.html_sample,
        target=data.target,
        confidence_threshold=data.confidence_threshold,
        auto_update=data.auto_update,
    )
    return {
        "suggestions": result["suggestions"],
        "confidence": result["confidence"],
        "original_selector": result["original_selector"] or "",
        "updated": result["updated"],
        "new_selector": result.get("new_selector"),
    }
@router.get("/sources-needing-attention", response=List[SourceNeedingAttentionOut], auth=AdminAuth())
def sources_needing_attention_endpoint(request):
    """List SearchSources whose selectors appear broken.

    Which sources qualify is decided by ``get_sources_needing_attention``
    (reportedly consecutive_failures >= 3 or the needs_attention flag; see
    that helper). A missing result selector is reported as an empty string.
    """

    def as_row(source):
        # Flatten one SearchSource into the SourceNeedingAttentionOut shape.
        return {
            "id": source.id,
            "host": source.host,
            "name": source.name,
            "result_selector": source.result_selector or "",
            "consecutive_failures": source.consecutive_failures,
        }

    return [as_row(source) for source in get_sources_needing_attention()]