diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 625fb8b47..bad92d383 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -27,6 +27,8 @@ from ami.jobs.tasks import process_nats_pipeline_result from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet +from ami.ml.auth import HasProcessingServiceAPIKey +from ami.ml.models.processing_service import ProcessingService from ami.utils.fields import url_boolean_param from .models import Job, JobDispatchMode, JobState @@ -146,6 +148,13 @@ class JobViewSet(DefaultViewSet, ProjectMixin): permission_classes = [ObjectPermission] + def _update_processing_service_heartbeat(self, request): + """Update heartbeat for the specific PS identified by API key auth.""" + from ami.ml.schemas import get_client_info + + if isinstance(request.auth, ProcessingService): + request.auth.mark_seen(client_info=get_client_info(request)) + def get_serializer_class(self): """ Return different serializers for list and detail views. @@ -247,7 +256,12 @@ def list(self, request, *args, **kwargs): responses={200: MLJobTasksResponseSerializer}, parameters=[project_id_doc_param], ) - @action(detail=True, methods=["post"], name="tasks") + @action( + detail=True, + methods=["post"], + name="tasks", + permission_classes=[ObjectPermission | HasProcessingServiceAPIKey], + ) def tasks(self, request, pk=None): """ Fetch tasks from the job queue (POST). @@ -275,8 +289,13 @@ def tasks(self, request, pk=None): if not job.pipeline: raise ValidationError("This job does not have a pipeline configured") - # Record heartbeat for async processing services on this pipeline - _mark_pipeline_pull_services_seen(job) + # Record heartbeat. When the request is API-key-authenticated we know the + # exact PS, so use the precise per-PS heartbeat. Fall back to the bulk + # pipeline-level heartbeat for token-authenticated requests (transition period). + if isinstance(request.auth, ProcessingService): + self._update_processing_service_heartbeat(request) + else: + _mark_pipeline_pull_services_seen(job) # Get tasks from NATS JetStream from ami.ml.orchestration.nats_queue import TaskQueueManager @@ -298,7 +317,12 @@ async def get_tasks(): responses={200: MLJobResultsResponseSerializer}, parameters=[project_id_doc_param], ) - @action(detail=True, methods=["post"], name="result") + @action( + detail=True, + methods=["post"], + name="result", + permission_classes=[ObjectPermission | HasProcessingServiceAPIKey], + ) def result(self, request, pk=None): """ Submit pipeline results. @@ -310,8 +334,11 @@ def result(self, request, pk=None): job = self.get_object() - # Record heartbeat for async processing services on this pipeline - _mark_pipeline_pull_services_seen(job) + # Record heartbeat (see comment in tasks() for rationale) + if isinstance(request.auth, ProcessingService): + self._update_processing_service_heartbeat(request) + else: + _mark_pipeline_pull_services_seen(job) serializer = MLJobResultsRequestSerializer(data=request.data) serializer.is_valid(raise_exception=True) diff --git a/ami/main/tests.py b/ami/main/tests.py index 4bfbdc4de..12c9a6991 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -3648,9 +3648,16 @@ def test_nonexistent_taxa_list_returns_404(self): class TestProjectPipelinesAPI(APITestCase): - """Test the project pipelines API endpoint.""" + """Test the project pipelines API endpoint. + + Pipeline registration requires API key authentication (since PR #1194). + The processing service is identified by its API key, not by name. + """ def setUp(self): + from unittest.mock import patch + + from ami.ml.models.processing_service import ProcessingServiceAPIKey from ami.users.roles import ProjectManager, create_roles_for_project self.user = User.objects.create_user(email="test@example.com") # type: ignore @@ -3665,76 +3672,68 @@ def setUp(self): create_roles_for_project(self.other_project) ProjectManager.assign_user(self.user, self.project) + # Create a processing service with API key for registration tests + with patch.object(ProcessingService, "get_status"): + self.service = ProcessingService.objects.create(name="TestService", endpoint_url=None) + self.service.projects.add(self.project) + _, self.api_key = ProcessingServiceAPIKey.objects.create_key(name="test-key", processing_service=self.service) + def _get_pipelines_url(self, project_id): """Get the pipelines API URL for a project.""" return f"/api/v2/projects/{project_id}/pipelines/" - def _get_test_payload(self, service_name: str): - """Get a minimal test payload for pipeline registration.""" - return { - "processing_service_name": service_name, - "pipelines": [], - } - - def test_create_new_service_success(self): - """Test creating a new processing service if it doesn't exist.""" + def test_registration_with_api_key_succeeds(self): + """Test that API-key-authenticated registration succeeds.""" url = self._get_pipelines_url(self.project.pk) - payload = self._get_test_payload("NewService") + payload = {"pipelines": []} - self.client.force_authenticate(user=self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {self.api_key}") response = self.client.post(url, payload, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) - # Verify service was created and associated - service = ProcessingService.objects.get(name="NewService") - self.assertIn(self.project, service.projects.all()) - def test_reregistration_is_idempotent(self): - """Test that re-registering a service already associated with the project succeeds.""" - # Create and associate service - service = ProcessingService.objects.create(name="ExistingService") - service.projects.add(self.project) - + """Test that re-registering the same service succeeds.""" url = self._get_pipelines_url(self.project.pk) - payload = self._get_test_payload("ExistingService") + payload = {"pipelines": []} - self.client.force_authenticate(user=self.user) - response = self.client.post(url, payload, format="json") + self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {self.api_key}") - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + response1 = self.client.post(url, payload, format="json") + self.assertEqual(response1.status_code, status.HTTP_201_CREATED) - def test_associate_existing_service_success(self): - """Test associating existing service with project when not yet associated.""" - # Create service but don't associate with project - service = ProcessingService.objects.create(name="UnassociatedService") + response2 = self.client.post(url, payload, format="json") + self.assertEqual(response2.status_code, status.HTTP_201_CREATED) + def test_registration_updates_heartbeat(self): + """Test that registration marks the service as seen.""" url = self._get_pipelines_url(self.project.pk) - payload = self._get_test_payload("UnassociatedService") + payload = {"pipelines": []} - self.client.force_authenticate(user=self.user) - response = self.client.post(url, payload, format="json") + self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {self.api_key}") + self.client.post(url, payload, format="json") - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertIn(self.project, service.projects.all()) + self.service.refresh_from_db() + self.assertIsNotNone(self.service.last_seen) + self.assertTrue(self.service.last_seen_live) - def test_unauthorized_project_access_returns_403(self): - """Test 403 when user doesn't have write access to project.""" + def test_wrong_project_denied(self): + """Test that API key for a PS not linked to the target project is denied.""" url = self._get_pipelines_url(self.other_project.pk) - payload = self._get_test_payload("UnauthorizedService") + payload = {"pipelines": []} - self.client.force_authenticate(user=self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {self.api_key}") response = self.client.post(url, payload, format="json") - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertIn(response.status_code, [status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND]) - def test_invalid_payload_returns_400(self): - """Test 400 when payload is invalid.""" + def test_user_token_auth_rejected_for_registration(self): + """Test that user-token auth is rejected for pipeline registration.""" url = self._get_pipelines_url(self.project.pk) - invalid_payload = {"invalid": "data"} + payload = {"pipelines": []} self.client.force_authenticate(user=self.user) - response = self.client.post(url, invalid_payload, format="json") + response = self.client.post(url, payload, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -3771,7 +3770,7 @@ def test_list_pipelines_draft_project_non_member(self): def test_unauthenticated_write_returns_401(self): """Unauthenticated users cannot register pipelines.""" url = self._get_pipelines_url(self.project.pk) - payload = self._get_test_payload("AnonService") + payload = {"pipelines": []} response = self.client.post(url, payload, format="json") self.assertIn(response.status_code, [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]) diff --git a/ami/ml/admin.py b/ami/ml/admin.py index 008b20e84..8ed2f8ae5 100644 --- a/ami/ml/admin.py +++ b/ami/ml/admin.py @@ -1,10 +1,11 @@ from django.contrib import admin +from rest_framework_api_key.admin import APIKeyModelAdmin from ami.main.admin import AdminBase, ProjectPipelineConfigInline from .models.algorithm import Algorithm, AlgorithmCategoryMap from .models.pipeline import Pipeline -from .models.processing_service import ProcessingService +from .models.processing_service import ProcessingService, ProcessingServiceAPIKey @admin.register(Algorithm) @@ -70,8 +71,32 @@ class ProcessingServiceAdmin(AdminBase): "id", "name", "endpoint_url", + "last_seen_live", "created_at", ] + readonly_fields = ["last_seen_client_info"] + + @admin.action(description="Generate API key for selected processing services (revokes existing)") + def generate_api_key(self, request, queryset): + for ps in queryset: + ps.api_keys.filter(revoked=False).update(revoked=True) + _, plaintext_key = ProcessingServiceAPIKey.objects.create_key( + name=f"{ps.name} key", + processing_service=ps, + ) + self.message_user( + request, + f"{ps.name}: {plaintext_key} (copy now — it won't be shown again)", + ) + + actions = [generate_api_key] + + +@admin.register(ProcessingServiceAPIKey) +class ProcessingServiceAPIKeyAdmin(APIKeyModelAdmin): + list_display = [*APIKeyModelAdmin.list_display, "processing_service"] + list_filter = ["processing_service"] + search_fields = [*APIKeyModelAdmin.search_fields, "processing_service__name"] @admin.register(AlgorithmCategoryMap) diff --git a/ami/ml/auth.py b/ami/ml/auth.py new file mode 100644 index 000000000..e8660e996 --- /dev/null +++ b/ami/ml/auth.py @@ -0,0 +1,106 @@ +""" +API key authentication for processing services. + +Uses djangorestframework-api-key to provide key-based auth. Each ProcessingService +can have one or more API keys. When a request arrives with `Authorization: Api-Key `, +the authentication class identifies the ProcessingService and sets request.auth to it. + +Contains: + - ProcessingServiceAPIKeyAuthentication: DRF auth backend + - HasProcessingServiceAPIKey: DRF permission class + +The ProcessingServiceAPIKey model lives in ami.ml.models.processing_service. +""" + +import logging + +from rest_framework import authentication, exceptions, permissions +from rest_framework_api_key.permissions import KeyParser + +from ami.ml.models.processing_service import ProcessingServiceAPIKey + +logger = logging.getLogger(__name__) + + +class ProcessingServiceAPIKeyAuthentication(authentication.BaseAuthentication): + """ + DRF authentication class that identifies a ProcessingService from an API key. + + Sets: + request.user = AnonymousUser (required by django-guardian/ObjectPermission) + request.auth = ProcessingService instance + + This allows views to check `request.auth` to get the calling service, + and permission classes to verify project access. + """ + + key_parser = KeyParser() + + def authenticate(self, request): + key = self.key_parser.get(request) + if not key: + return None # No Api-Key header; fall through to next auth class + + try: + api_key = ProcessingServiceAPIKey.objects.get_from_key(key) + except ProcessingServiceAPIKey.DoesNotExist: + raise exceptions.AuthenticationFailed("Invalid API key.") + + if not api_key.is_valid: + raise exceptions.AuthenticationFailed("API key has been revoked or expired.") + + from django.contrib.auth.models import AnonymousUser + + return (AnonymousUser(), api_key.processing_service) + + def authenticate_header(self, request): + return "Api-Key" + + +class HasProcessingServiceAPIKey(permissions.BasePermission): + """ + Allow access for requests authenticated with a ProcessingService API key. + + The auth backend places the ProcessingService on request.auth. + This permission verifies project membership. + + Compose with ObjectPermission for endpoints used by both users and services: + permission_classes = [ObjectPermission | HasProcessingServiceAPIKey] + """ + + def has_permission(self, request, view): + from ami.ml.models.processing_service import ProcessingService + + if not isinstance(request.auth, ProcessingService): + return False + + # For detail views (e.g. /jobs/{pk}/tasks/), defer project scoping + # to has_object_permission where we can derive it from the object. + # CONTRACT: all detail-level actions using this permission MUST call + # self.get_object() so that DRF invokes has_object_permission(). + # Actions that fetch objects manually without get_object() will bypass + # project-scoping checks. + if view.kwargs.get("pk"): + return True + + get_active_project = getattr(view, "get_active_project", None) + if not callable(get_active_project): + return False + + project = get_active_project() + if not project: + return False + + return request.auth.projects.filter(pk=project.pk).exists() + + def has_object_permission(self, request, view, obj): + from ami.ml.models.processing_service import ProcessingService + + if not isinstance(request.auth, ProcessingService): + return False + + ps = request.auth + project = obj.get_project() if hasattr(obj, "get_project") else None + if not project: + return False + return ps.projects.filter(pk=project.pk).exists() diff --git a/ami/ml/migrations/0029_api_key_and_client_info.py b/ami/ml/migrations/0029_api_key_and_client_info.py new file mode 100644 index 000000000..30f5a4688 --- /dev/null +++ b/ami/ml/migrations/0029_api_key_and_client_info.py @@ -0,0 +1,67 @@ +# Generated by Django 4.2.10 on 2026-03-29 05:36 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0028_normalize_empty_endpoint_url_to_null"), + ] + + operations = [ + migrations.AddField( + model_name="processingservice", + name="last_seen_client_info", + field=models.JSONField(blank=True, null=True), + ), + migrations.CreateModel( + name="ProcessingServiceAPIKey", + fields=[ + ( + "id", + models.CharField(editable=False, max_length=150, primary_key=True, serialize=False, unique=True), + ), + ("prefix", models.CharField(editable=False, max_length=8, unique=True)), + ("hashed_key", models.CharField(editable=False, max_length=150)), + ("created", models.DateTimeField(auto_now_add=True, db_index=True)), + ( + "name", + models.CharField( + default="", + help_text="A free-form name for the API key. Need not be unique. 50 characters max.", + max_length=50, + ), + ), + ( + "revoked", + models.BooleanField( + blank=True, + default=False, + help_text="If the API key is revoked, clients cannot use it anymore. (This cannot be undone.)", + ), + ), + ( + "expiry_date", + models.DateTimeField( + blank=True, + help_text="Once API key expires, clients cannot use it anymore.", + null=True, + verbose_name="Expires", + ), + ), + ( + "processing_service", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, related_name="api_keys", to="ml.processingservice" + ), + ), + ], + options={ + "verbose_name": "Processing Service API Key", + "verbose_name_plural": "Processing Service API Keys", + "ordering": ("-created",), + "abstract": False, + }, + ), + ] diff --git a/ami/ml/models/__init__.py b/ami/ml/models/__init__.py index 5000c7f53..758124547 100644 --- a/ami/ml/models/__init__.py +++ b/ami/ml/models/__init__.py @@ -1,6 +1,6 @@ from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap from ami.ml.models.pipeline import Pipeline -from ami.ml.models.processing_service import ProcessingService +from ami.ml.models.processing_service import ProcessingService, ProcessingServiceAPIKey from ami.ml.models.project_pipeline_config import ProjectPipelineConfig __all__ = [ @@ -8,5 +8,6 @@ "AlgorithmCategoryMap", "Pipeline", "ProcessingService", + "ProcessingServiceAPIKey", "ProjectPipelineConfig", ] diff --git a/ami/ml/models/processing_service.py b/ami/ml/models/processing_service.py index fce1aefc5..dba3e61b3 100644 --- a/ami/ml/models/processing_service.py +++ b/ami/ml/models/processing_service.py @@ -7,6 +7,7 @@ import requests from django.conf import settings from django.db import models +from rest_framework_api_key.models import AbstractAPIKey from ami.base.models import BaseQuerySet from ami.main.models import BaseModel, Project @@ -36,7 +37,7 @@ def async_services(self) -> "ProcessingServiceQuerySet": out to them, they poll Antenna for tasks and push results back. Their liveness is tracked via heartbeats from mark_seen() rather than active health checks. """ - return self.filter(endpoint_url__isnull=True) + return self.filter(models.Q(endpoint_url__isnull=True) | models.Q(endpoint_url__exact="")) def sync_services(self) -> "ProcessingServiceQuerySet": """ @@ -46,7 +47,7 @@ def sync_services(self) -> "ProcessingServiceQuerySet": /readyz and /process endpoints. Their liveness is tracked by the periodic check_processing_services_online Celery task. """ - return self.filter(endpoint_url__isnull=False) + return self.exclude(models.Q(endpoint_url__isnull=True) | models.Q(endpoint_url__exact="")) class ProcessingServiceManager(models.Manager.from_queryset(ProcessingServiceQuerySet)): @@ -71,6 +72,9 @@ class ProcessingService(BaseModel): last_seen_live = models.BooleanField(null=True) last_seen_latency = models.FloatField(null=True) + # Last known client info from the most recent request + last_seen_client_info = models.JSONField(null=True, blank=True) + objects = ProcessingServiceManager() @property @@ -174,14 +178,32 @@ def create_pipelines( algorithms_created=algorithms_created, ) - def mark_seen(self, live: bool = True) -> None: + # Fields in client_info that are set by the server (get_client_info) and + # should always be overwritten, even when the new value is empty. + _SERVER_OBSERVED_CLIENT_FIELDS = frozenset({"ip", "user_agent"}) + + def mark_seen(self, live: bool = True, client_info: dict | None = None) -> None: """ Record that we heard from this processing service. Used by async/pull-mode services that don't have an endpoint to check. + Optionally persists client_info (ip, user_agent, hostname, etc.) from the request. + + Client-reported fields (hostname, software, version, etc.) are merged + non-destructively: a heartbeat with empty values won't overwrite rich + data saved during registration. Server-observed fields (ip, user_agent) + are always updated since they reflect the current request. """ self.last_seen = datetime.datetime.now() self.last_seen_live = live - self.save(update_fields=["last_seen", "last_seen_live"]) + update_fields = ["last_seen", "last_seen_live"] + if client_info is not None: + merged = dict(self.last_seen_client_info or {}) + for key, value in client_info.items(): + if key in self._SERVER_OBSERVED_CLIENT_FIELDS or value: + merged[key] = value + self.last_seen_client_info = merged + update_fields.append("last_seen_client_info") + self.save(update_fields=update_fields) def get_status(self, timeout=90) -> ProcessingServiceStatusResponse: """ @@ -306,13 +328,19 @@ def get_or_create_default_processing_service( register_pipelines: bool = True, ) -> "ProcessingService | None": """ - Create a default processing service for a project. + Create a default push-mode processing service for a project. If configured, will use the global default processing service for the current environment. Otherwise, it return None. Set the "DEFAULT_PROCESSING_SERVICE_ENDPOINT" and "DEFAULT_PROCESSING_SERVICE_NAME" environment variables to configure & enable the default processing service. + + .. deprecated:: + For async/pull-mode services, use the self-registration flow instead: + the processing service authenticates with user credentials, creates itself + via the API, generates its own API key, and registers pipelines. + See processing_services/minimal/register.py for an example. """ name = settings.DEFAULT_PROCESSING_SERVICE_NAME or "Default Processing Service" @@ -335,4 +363,25 @@ def get_or_create_default_processing_service( enable_only=settings.DEFAULT_PIPELINES_ENABLED, projects=Project.objects.filter(pk=project.pk), ) + return service + + +class ProcessingServiceAPIKey(AbstractAPIKey): + """ + An API key tied to a specific ProcessingService. + + The plaintext key is only available at creation time (returned by + ProcessingServiceAPIKey.objects.create_key()). The database stores + only the hashed version. The 8-character prefix is stored for display. + """ + + processing_service = models.ForeignKey( + "ml.ProcessingService", + on_delete=models.CASCADE, + related_name="api_keys", + ) + + class Meta(AbstractAPIKey.Meta): + verbose_name = "Processing Service API Key" + verbose_name_plural = "Processing Service API Keys" diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index 9322e4116..83d80235d 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -8,6 +8,81 @@ logger.setLevel(logging.DEBUG) +class ProcessingServiceClientInfo(pydantic.BaseModel): + """Identity metadata sent by a processing service worker. + + Client-reported fields (all optional, sent by the worker): + hostname, software, version, platform, pod_name + + Server-observed fields (set by get_client_info(), never from client): + ip, user_agent + """ + + # Client-reported + hostname: str = "" + software: str = "" + version: str = "" + platform: str = "" + pod_name: str = "" + + # Server-observed (overwritten on the server, cannot be spoofed) + ip: str = "" + user_agent: str = "" + + # Max number of extra fields and max length per extra value to prevent + # unbounded storage from arbitrary client-supplied keys. + _MAX_EXTRA_FIELDS = 20 + _MAX_EXTRA_VALUE_LEN = 500 + + class Config: + extra = "allow" + + @pydantic.validator("*", pre=True, each_item=False) + def _truncate_strings(cls, v): + if isinstance(v, str) and len(v) > 500: + return v[:500] + return v + + @pydantic.root_validator(pre=False) + def _limit_extra_fields(cls, values): + known = set(cls.__fields__) + extra_keys = [k for k in values if k not in known] + if len(extra_keys) > cls._MAX_EXTRA_FIELDS: + for key in extra_keys[cls._MAX_EXTRA_FIELDS :]: + del values[key] + for key in extra_keys: + if key in values and isinstance(values[key], str) and len(values[key]) > cls._MAX_EXTRA_VALUE_LEN: + values[key] = values[key][: cls._MAX_EXTRA_VALUE_LEN] + return values + + +def get_client_info(request) -> dict: + """ + Extract client_info from request body, merged with server-observed values. + + Server-observed fields (ip, user_agent) always come from the server and + cannot be spoofed by the client. + Client-reported fields come from request.data["client_info"] when provided. + Handles bare-list payloads (legacy /result format) gracefully. + """ + data = request.data if isinstance(request.data, dict) else {} + raw = data.get("client_info") or {} + + try: + info = ProcessingServiceClientInfo(**raw) + except Exception: + info = ProcessingServiceClientInfo() + + # Always overwrite server-observed fields to prevent client spoofing. + # Note: X-Forwarded-For can be spoofed unless a trusted proxy strips it. + # This IP is informational (debugging/audit) and not used for access control. + forwarded = request.headers.get("x-forwarded-for") + info.ip = forwarded.split(",")[0].strip() if forwarded else request.META.get("REMOTE_ADDR", "unknown") + info.user_agent = request.headers.get("user-agent", "") + + return info.dict() + + class BoundingBox(pydantic.BaseModel): x1: float y1: float @@ -62,7 +137,7 @@ class AlgorithmReference(pydantic.BaseModel): class AlgorithmCategoryMapResponse(pydantic.BaseModel): data: list[dict] = pydantic.Field( - default_factory=dict, + default_factory=list, description="Complete data for each label, such as id, gbif_key, explicit index, source, etc.", examples=[ [ @@ -262,24 +337,6 @@ class PipelineProcessingTask(pydantic.BaseModel): # config: PipelineRequestConfigParameters | dict | None = None -class ProcessingServiceClientInfo(pydantic.BaseModel): - """Identity metadata sent by a processing service worker. - - A single ProcessingService record in the database may have multiple - physical workers, pods, or machines running simultaneously. This model - lets the server distinguish between them for logging, debugging, and - eventually for per-worker health tracking. - - Fields are intentionally left open for now. Processing services can - send any key-value pairs they find useful (e.g. hostname, pod_name, - software version). The schema will be tightened once real-world usage - patterns emerge. - """ - - class Config: - extra = "allow" - - class PipelineTaskResult(pydantic.BaseModel): """ The result from processing a single PipelineProcessingTask. @@ -360,12 +417,3 @@ class PipelineRegistrationResponse(pydantic.BaseModel): pipelines: list[PipelineConfigResponse] = [] pipelines_created: list[str] = [] algorithms_created: list[str] = [] - - -class AsyncPipelineRegistrationRequest(pydantic.BaseModel): - """ - Request to register pipelines from an async processing service - """ - - processing_service_name: str - pipelines: list[PipelineConfigResponse] = [] diff --git a/ami/ml/serializers.py b/ami/ml/serializers.py index e7e9e6aaf..688a5379c 100644 --- a/ami/ml/serializers.py +++ b/ami/ml/serializers.py @@ -7,7 +7,7 @@ from .models.pipeline import Pipeline, PipelineStage from .models.processing_service import ProcessingService from .models.project_pipeline_config import ProjectPipelineConfig -from .schemas import PipelineConfigResponse +from .schemas import PipelineConfigResponse, ProcessingServiceClientInfo class AlgorithmCategoryMapSerializer(DefaultSerializer): @@ -138,6 +138,7 @@ class ProcessingServiceSerializer(DefaultSerializer): pipelines = PipelineNestedSerializer(many=True, read_only=True) projects = serializers.SerializerMethodField() is_async = serializers.BooleanField(read_only=True) + api_key_prefix = serializers.SerializerMethodField() endpoint_url = serializers.CharField(required=False, allow_null=True, allow_blank=False, max_length=1024) class Meta: @@ -151,6 +152,8 @@ class Meta: "endpoint_url", "is_async", "pipelines", + "api_key_prefix", + "last_seen_client_info", "created_at", "updated_at", "last_seen", @@ -164,7 +167,15 @@ def get_projects(self, obj): """ return list(obj.projects.values_list("id", flat=True)) + def get_api_key_prefix(self, obj): + # Use prefetched api_keys to avoid N+1 queries + active_keys = [k for k in obj.api_keys.all() if not k.revoked] + if not active_keys: + return None + latest = max(active_keys, key=lambda k: k.created) + return latest.prefix + class PipelineRegistrationSerializer(serializers.Serializer): - processing_service_name = serializers.CharField() + client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) pipelines = SchemaField(schema=list[PipelineConfigResponse], default=[]) diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 36ba5b5f7..9c1a468af 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -192,6 +192,62 @@ def test_mark_seen_offline(self): self.assertIsNotNone(service.last_seen) self.assertFalse(service.last_seen_live) + def test_mark_seen_merges_client_info(self): + """Test that heartbeat mark_seen() merges client_info instead of overwriting. + + Registration sets rich client_info (hostname, software, version). + Subsequent heartbeats (task polling) should update server-observed fields + (ip, user_agent) but preserve client-reported fields when the new values are empty. + """ + service = ProcessingService.objects.create(name="Merge Test", endpoint_url=None) + + # Simulate registration with rich client_info + registration_info = { + "hostname": "gpu-worker-01", + "software": "ami-data-companion", + "version": "1.2.3", + "platform": "linux", + "ip": "10.0.0.1", + "user_agent": "python-requests/2.31", + } + service.mark_seen(client_info=registration_info) + service.refresh_from_db() + self.assertEqual(service.last_seen_client_info["hostname"], "gpu-worker-01") + self.assertEqual(service.last_seen_client_info["software"], "ami-data-companion") + + # Simulate heartbeat with only server-observed fields (empty client-reported) + heartbeat_info = { + "hostname": "", + "software": "", + "version": "", + "platform": "", + "ip": "10.0.0.2", # IP may change + "user_agent": "python-requests/2.31", + } + service.mark_seen(client_info=heartbeat_info) + service.refresh_from_db() + + # Server-observed fields updated + self.assertEqual(service.last_seen_client_info["ip"], "10.0.0.2") + # Client-reported fields preserved from registration + self.assertEqual(service.last_seen_client_info["hostname"], "gpu-worker-01") + self.assertEqual(service.last_seen_client_info["software"], "ami-data-companion") + self.assertEqual(service.last_seen_client_info["version"], "1.2.3") + + def test_mark_seen_client_info_overwrite_with_new_values(self): + """Test that non-empty client-reported fields DO overwrite existing values.""" + service = ProcessingService.objects.create(name="Overwrite Test", endpoint_url=None) + + service.mark_seen(client_info={"hostname": "old-host", "software": "v1", "ip": "1.1.1.1"}) + service.refresh_from_db() + + service.mark_seen(client_info={"hostname": "new-host", "software": "v2", "ip": "2.2.2.2"}) + service.refresh_from_db() + + self.assertEqual(service.last_seen_client_info["hostname"], "new-host") + self.assertEqual(service.last_seen_client_info["software"], "v2") + self.assertEqual(service.last_seen_client_info["ip"], "2.2.2.2") + def test_get_status_updates_last_seen_for_sync_service(self): """Test that get_status() updates last_seen fields for sync services (even if endpoint is unreachable).""" service = ProcessingService.objects.create(name="Sync Service", endpoint_url="http://nonexistent-host:9999") @@ -223,51 +279,50 @@ def test_model_has_last_seen_fields(self): class TestProjectPipelineRegistrationUpdatesLastSeen(APITestCase): - """Test that async pipeline registration updates last_seen on the processing service.""" + """Test that API-key-authenticated pipeline registration updates last_seen.""" def setUp(self): - from ami.users.roles import ProjectManager, create_roles_for_project + from unittest.mock import patch + + from ami.ml.models.processing_service import ProcessingServiceAPIKey self.user = User.objects.create_user(email="lastseen@example.com") # type: ignore self.project = Project.objects.create(name="Last Seen Project", owner=self.user, create_defaults=False) - create_roles_for_project(self.project) - ProjectManager.assign_user(self.user, self.project) + + with patch.object(ProcessingService, "get_status"): + self.service = ProcessingService.objects.create(name="AsyncTestWorker", endpoint_url=None) + self.service.projects.add(self.project) + _, self.api_key = ProcessingServiceAPIKey.objects.create_key(name="test-key", processing_service=self.service) def test_pipeline_registration_marks_service_as_seen(self): """Test that POSTing to the pipeline registration endpoint marks the service as last_seen_live.""" url = f"/api/v2/projects/{self.project.pk}/pipelines/" - payload = { - "processing_service_name": "AsyncTestWorker", - "pipelines": [], - } + payload = {"pipelines": []} - self.client.force_authenticate(user=self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {self.api_key}") response = self.client.post(url, payload, format="json") self.assertEqual(response.status_code, 201) - service = ProcessingService.objects.get(name="AsyncTestWorker") - self.assertIsNotNone(service.last_seen) - self.assertTrue(service.last_seen_live) + self.service.refresh_from_db() + self.assertIsNotNone(self.service.last_seen) + self.assertTrue(self.service.last_seen_live) def test_repeated_registration_updates_last_seen(self): """Test that re-registering updates the last_seen timestamp.""" url = f"/api/v2/projects/{self.project.pk}/pipelines/" - payload = { - "processing_service_name": "AsyncTestWorkerRepeat", - "pipelines": [], - } + payload = {"pipelines": []} - self.client.force_authenticate(user=self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {self.api_key}") # First registration self.client.post(url, payload, format="json") - service = ProcessingService.objects.get(name="AsyncTestWorkerRepeat") - first_seen = service.last_seen + self.service.refresh_from_db() + first_seen = self.service.last_seen # Second registration self.client.post(url, payload, format="json") - service.refresh_from_db() - second_seen = service.last_seen + self.service.refresh_from_db() + second_seen = self.service.last_seen self.assertIsNotNone(first_seen) self.assertIsNotNone(second_seen) @@ -1366,3 +1421,426 @@ def test_cleanup_removes_failed_set(self): # Verify all state is gone (get_progress returns None when total_key is deleted) progress = self.manager.get_progress("process") self.assertIsNone(progress) + + +class TestProcessingServiceAuth(APITestCase): + def setUp(self): + from unittest.mock import patch + + from ami.ml.models.processing_service import ProcessingServiceAPIKey + from ami.users.tests.factories import UserFactory + + self.user = UserFactory(is_staff=True) + self.project = Project.objects.create(name="Auth Test Project", owner=self.user) + with patch.object(ProcessingService, "get_status"): + self.ps = ProcessingService.objects.create( + name="Test Service", + endpoint_url=None, + ) + self.ps.projects.add(self.project) + self.api_key_obj, self.api_key = ProcessingServiceAPIKey.objects.create_key( + name="test-key", + processing_service=self.ps, + ) + + def test_authenticate_valid_key(self): + from django.contrib.auth.models import AnonymousUser + + from ami.ml.auth import ProcessingServiceAPIKeyAuthentication + + factory = APIRequestFactory() + request = factory.get("/", HTTP_AUTHORIZATION=f"Api-Key {self.api_key}") + + auth = ProcessingServiceAPIKeyAuthentication() + result = auth.authenticate(request) + + self.assertIsNotNone(result) + user, ps = result + self.assertEqual(ps.pk, self.ps.pk) + self.assertIsInstance(user, AnonymousUser) + + def test_authenticate_invalid_key_raises(self): + from rest_framework.exceptions import AuthenticationFailed + + from ami.ml.auth import ProcessingServiceAPIKeyAuthentication + + factory = APIRequestFactory() + request = factory.get("/", HTTP_AUTHORIZATION="Api-Key invalid.key") + + auth = ProcessingServiceAPIKeyAuthentication() + with self.assertRaises(AuthenticationFailed): + auth.authenticate(request) + + def test_authenticate_non_api_key_passes_through(self): + """Non Api-Key tokens should return None (fall through to next backend).""" + from ami.ml.auth import ProcessingServiceAPIKeyAuthentication + + factory = APIRequestFactory() + request = factory.get("/", HTTP_AUTHORIZATION="Token some_djoser_token") + auth = ProcessingServiceAPIKeyAuthentication() + result = auth.authenticate(request) + self.assertIsNone(result) + + def test_authenticate_no_header(self): + from ami.ml.auth import ProcessingServiceAPIKeyAuthentication + + factory = APIRequestFactory() + request = factory.get("/") + auth = ProcessingServiceAPIKeyAuthentication() + result = auth.authenticate(request) + self.assertIsNone(result) + + def test_revoked_key_raises(self): + from rest_framework.exceptions import AuthenticationFailed + + from ami.ml.auth import ProcessingServiceAPIKeyAuthentication + + self.api_key_obj.revoked = True + self.api_key_obj.save() + + factory = APIRequestFactory() + request = factory.get("/", HTTP_AUTHORIZATION=f"Api-Key {self.api_key}") + + auth = ProcessingServiceAPIKeyAuthentication() + with self.assertRaises(AuthenticationFailed): + auth.authenticate(request) + + +class TestProcessingServiceAPIKey(TestCase): + def _create_ps(self, **kwargs): + """Create a ProcessingService, mocking get_status to avoid HTTP calls.""" + from unittest.mock import patch + + with patch.object(ProcessingService, "get_status"): + return ProcessingService.objects.create(**kwargs) + + def test_create_ps_without_api_key(self): + """Sync services don't need an API key.""" + ps = self._create_ps(name="Sync Service", endpoint_url="http://example.com:2000") + self.assertEqual(ps.api_keys.count(), 0) + + def test_create_and_assign_api_key(self): + from ami.ml.models.processing_service import ProcessingServiceAPIKey + + ps = self._create_ps(name="Async Service", endpoint_url=None) + api_key_obj, plaintext_key = ProcessingServiceAPIKey.objects.create_key( + name="test-key", + processing_service=ps, + ) + self.assertIsNotNone(plaintext_key) + self.assertEqual(len(api_key_obj.prefix), 8) + self.assertIn(".", plaintext_key) + self.assertEqual(ps.api_keys.count(), 1) + + def test_revoke_and_create_new_key(self): + from ami.ml.models.processing_service import ProcessingServiceAPIKey + + ps = self._create_ps(name="Service", endpoint_url=None) + old_obj, old_key = ProcessingServiceAPIKey.objects.create_key( + name="key-1", + processing_service=ps, + ) + old_obj.revoked = True + old_obj.save() + + new_obj, new_key = ProcessingServiceAPIKey.objects.create_key( + name="key-2", + processing_service=ps, + ) + self.assertNotEqual(old_key, new_key) + self.assertEqual(ps.api_keys.filter(revoked=False).count(), 1) + self.assertEqual(ps.api_keys.filter(revoked=True).count(), 1) + + def test_last_seen_client_info_stored(self): + ps = self._create_ps(name="Service 2", endpoint_url=None) + ps.last_seen_client_info = {"hostname": "node-01", "software": "adc", "version": "2.0"} + ps.save() + ps.refresh_from_db() + self.assertEqual(ps.last_seen_client_info["hostname"], "node-01") + + +class TestProcessingServiceClientInfo(TestCase): + def test_valid_client_info(self): + from ami.ml.schemas import ProcessingServiceClientInfo + + info = ProcessingServiceClientInfo( + hostname="cedar-node-01", + software="ami-data-companion", + version="2.1.0", + platform="Linux x86_64", + ) + self.assertEqual(info.hostname, "cedar-node-01") + self.assertEqual(info.software, "ami-data-companion") + + def test_empty_client_info_is_valid(self): + from ami.ml.schemas import ProcessingServiceClientInfo + + info = ProcessingServiceClientInfo() + self.assertEqual(info.hostname, "") + self.assertEqual(info.ip, "") + + def test_extra_fields_allowed(self): + from ami.ml.schemas import ProcessingServiceClientInfo + + info = ProcessingServiceClientInfo( + hostname="node-01", + gpu="A100", + cuda="12.0", + ) + self.assertEqual(info.hostname, "node-01") + d = info.dict() + self.assertEqual(d["gpu"], "A100") + + def test_dict_roundtrip(self): + from ami.ml.schemas import ProcessingServiceClientInfo + + info = ProcessingServiceClientInfo( + hostname="node-01", + software="adc", + version="1.0", + ) + d = info.dict() + restored = ProcessingServiceClientInfo(**d) + self.assertEqual(restored.hostname, "node-01") + + +class TestProcessingServiceSerializerFields(APITestCase): + def setUp(self): + from ami.users.tests.factories import UserFactory + + self.user = UserFactory(is_staff=True) + self.project = Project.objects.create(name="Serializer Test", owner=self.user) + + def test_serializer_includes_api_key_prefix(self): + from unittest.mock import patch + + from ami.base.serializers import reverse_with_params + from ami.ml.models.processing_service import ProcessingServiceAPIKey + + with patch.object(ProcessingService, "get_status"): + ps = ProcessingService.objects.create(name="Test PS Serializer", endpoint_url=None) + ps.projects.add(self.project) + api_key_obj, _ = ProcessingServiceAPIKey.objects.create_key( + name="test-key", + processing_service=ps, + ) + + self.client.force_authenticate(user=self.user) + url = reverse_with_params( + "api:processingservice-detail", + kwargs={"pk": ps.pk}, + params={"project_id": self.project.pk}, + ) + resp = self.client.get(url) + self.assertEqual(resp.status_code, 200) + self.assertIn("api_key_prefix", resp.data) + self.assertEqual(resp.data["api_key_prefix"], api_key_obj.prefix) + + def test_serializer_includes_last_seen_client_info(self): + from unittest.mock import patch + + from ami.base.serializers import reverse_with_params + + with patch.object(ProcessingService, "get_status"): + ps = ProcessingService.objects.create(name="Test PS ClientInfo", endpoint_url=None) + ps.projects.add(self.project) + ps.last_seen_client_info = {"hostname": "node-01", "software": "adc"} + ps.save() + + self.client.force_authenticate(user=self.user) + url = reverse_with_params( + "api:processingservice-detail", + kwargs={"pk": ps.pk}, + params={"project_id": self.project.pk}, + ) + resp = self.client.get(url) + self.assertEqual(resp.status_code, 200) + self.assertIn("last_seen_client_info", resp.data) + self.assertEqual(resp.data["last_seen_client_info"]["hostname"], "node-01") + + def test_serializer_no_key_shows_null_prefix(self): + from unittest.mock import patch + + from ami.base.serializers import reverse_with_params + + with patch.object(ProcessingService, "get_status"): + ps = ProcessingService.objects.create(name="No Key PS", endpoint_url=None) + ps.projects.add(self.project) + + self.client.force_authenticate(user=self.user) + url = reverse_with_params( + "api:processingservice-detail", + kwargs={"pk": ps.pk}, + params={"project_id": self.project.pk}, + ) + resp = self.client.get(url) + self.assertEqual(resp.status_code, 200) + self.assertIsNone(resp.data["api_key_prefix"]) + + +class TestGenerateKeyAction(APITestCase): + def setUp(self): + from unittest.mock import patch + + from ami.users.tests.factories import UserFactory + + self.user = UserFactory(is_staff=True) + self.project = Project.objects.create(name="Key Gen Test", owner=self.user) + with patch.object(ProcessingService, "get_status"): + self.ps = ProcessingService.objects.create(name="Key Gen PS", endpoint_url=None) + self.ps.projects.add(self.project) + + def test_generate_key_returns_full_key(self): + from ami.base.serializers import reverse_with_params + + self.client.force_authenticate(user=self.user) + url = reverse_with_params( + "api:processingservice-generate-key", + args=[self.ps.pk], + params={"project_id": self.project.pk}, + ) + resp = self.client.post(url) + self.assertEqual(resp.status_code, 200) + self.assertIn("api_key", resp.data) + self.assertIn(".", resp.data["api_key"]) + self.assertIn("prefix", resp.data) + + def test_regenerate_key_revokes_old_key(self): + from ami.base.serializers import reverse_with_params + from ami.ml.models.processing_service import ProcessingServiceAPIKey + + ProcessingServiceAPIKey.objects.create_key( + name=f"{self.ps.name} key", + processing_service=self.ps, + ) + self.assertEqual(self.ps.api_keys.filter(revoked=False).count(), 1) + + self.client.force_authenticate(user=self.user) + url = reverse_with_params( + "api:processingservice-generate-key", + args=[self.ps.pk], + params={"project_id": self.project.pk}, + ) + resp = self.client.post(url) + self.assertEqual(resp.status_code, 200) + + self.assertEqual(self.ps.api_keys.filter(revoked=False).count(), 1) + self.assertEqual(self.ps.api_keys.filter(revoked=True).count(), 1) + + def test_full_key_not_in_get_response(self): + from ami.base.serializers import reverse_with_params + from ami.ml.models.processing_service import ProcessingServiceAPIKey + + ProcessingServiceAPIKey.objects.create_key( + name="test-key", + processing_service=self.ps, + ) + self.client.force_authenticate(user=self.user) + url = reverse_with_params( + "api:processingservice-detail", + kwargs={"pk": self.ps.pk}, + params={"project_id": self.project.pk}, + ) + resp = self.client.get(url) + self.assertEqual(resp.status_code, 200) + self.assertIn("api_key_prefix", resp.data) + self.assertNotIn("api_key", resp.data) + + +class TestProcessingServiceE2EFlow(APITestCase): + """End-to-end: create PS -> generate key -> register pipelines -> verify heartbeat.""" + + def setUp(self): + from ami.users.tests.factories import UserFactory + + self.user = UserFactory(is_staff=True) + self.project = Project.objects.create(name="E2E Project", owner=self.user) + from rest_framework.authtoken.models import Token + + self.admin_token = Token.objects.create(user=self.user) + + def test_full_lifecycle(self): + from rest_framework.test import APIClient + + from ami.base.serializers import reverse_with_params + + admin_client = APIClient() + admin_client.credentials(HTTP_AUTHORIZATION=f"Token {self.admin_token.key}") + + # Step 1: Admin creates PS via API + create_url = reverse_with_params("api:processingservice-list", params={"project_id": self.project.pk}) + resp = admin_client.post(create_url, {"name": "E2E Worker", "description": "Test"}, format="json") + self.assertEqual(resp.status_code, 201) + ps_id = resp.data["instance"]["id"] + + # Step 2: Generate API key + gen_url = reverse_with_params( + "api:processingservice-generate-key", + args=[ps_id], + params={"project_id": self.project.pk}, + ) + resp = admin_client.post(gen_url) + self.assertEqual(resp.status_code, 200) + api_key = resp.data["api_key"] + self.assertIn(".", api_key) # Library format: prefix.secret + + # Step 3: Worker authenticates with API key and registers pipelines + worker_client = APIClient() + worker_client.credentials(HTTP_AUTHORIZATION=f"Api-Key {api_key}") + + register_url = f"/api/v2/projects/{self.project.pk}/pipelines/" + resp = worker_client.post( + register_url, + { + "client_info": { + "hostname": "e2e-test-node", + "software": "test-worker", + "version": "0.1.0", + }, + "pipelines": [], + }, + format="json", + ) + self.assertEqual(resp.status_code, 201) + + # Step 4: Verify PS was updated with heartbeat + client_info + ps = ProcessingService.objects.get(pk=ps_id) + self.assertTrue(ps.last_seen_live) + self.assertEqual(ps.last_seen_client_info["hostname"], "e2e-test-node") + + # Step 5: Verify GET shows prefix but not full key + detail_url = reverse_with_params( + "api:processingservice-detail", + kwargs={"pk": ps_id}, + params={"project_id": self.project.pk}, + ) + resp = admin_client.get(detail_url) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.data["last_seen_client_info"]["hostname"], "e2e-test-node") + self.assertIn("api_key_prefix", resp.data) + self.assertNotIn("api_key", resp.data) + + def test_api_key_denied_for_wrong_project(self): + """API key for PS not linked to the target project should be denied.""" + from unittest.mock import patch + + from rest_framework.test import APIClient + + from ami.ml.models.processing_service import ProcessingServiceAPIKey + + with patch.object(ProcessingService, "get_status"): + ps = ProcessingService.objects.create(name="Wrong Project PS", endpoint_url=None) + ps.projects.add(self.project) + _, api_key = ProcessingServiceAPIKey.objects.create_key( + name="test-key", + processing_service=ps, + ) + + other_project = Project.objects.create(name="Other Project", owner=self.user) + + worker_client = APIClient() + worker_client.credentials(HTTP_AUTHORIZATION=f"Api-Key {api_key}") + + url = f"/api/v2/projects/{other_project.pk}/pipelines/" + resp = worker_client.post(url, {"pipelines": []}, format="json") + self.assertIn(resp.status_code, [403, 404]) diff --git a/ami/ml/views.py b/ami/ml/views.py index 58832a10b..308d1118b 100644 --- a/ami/ml/views.py +++ b/ami/ml/views.py @@ -16,6 +16,8 @@ from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet from ami.main.models import Project, SourceImage +from ami.ml.auth import HasProcessingServiceAPIKey +from ami.ml.models.processing_service import ProcessingServiceAPIKey from ami.ml.schemas import PipelineRegistrationResponse from .models.algorithm import Algorithm, AlgorithmCategoryMap @@ -147,7 +149,7 @@ class ProcessingServiceViewSet(DefaultViewSet, ProjectMixin): API endpoint that allows processing services to be viewed or edited. """ - queryset = ProcessingService.objects.all() + queryset = ProcessingService.objects.all().prefetch_related("api_keys") serializer_class = ProcessingServiceSerializer filterset_fields = ["projects"] ordering_fields = ["id", "created_at", "updated_at"] @@ -208,13 +210,43 @@ def register_pipelines(self, request: Request, pk=None) -> Response: processing_service.save() return Response(response.dict()) + @extend_schema( + operation_id="processing_services_generate_key", + summary="Generate or regenerate API key", + description="Generates a new API key, revoking any existing one. " + "The full key is only shown in this response.", + responses={200: dict}, + tags=["ml"], + ) + @action(detail=True, methods=["post"], url_path="generate_key") + def generate_key(self, request: Request, pk=None) -> Response: + instance = self.get_object() + + with transaction.atomic(): + # Revoke existing keys + instance.api_keys.filter(revoked=False).update(revoked=True) + + # Create new key via library + api_key_obj, plaintext_key = ProcessingServiceAPIKey.objects.create_key( + name=f"{instance.name} key", + processing_service=instance, + ) + + return Response( + { + "api_key": plaintext_key, + "prefix": api_key_obj.prefix, + "message": "API key generated. This is the only time the full key will be shown.", + } + ) + class ProjectPipelineViewSet(ProjectMixin, mixins.ListModelMixin, mixins.CreateModelMixin, viewsets.GenericViewSet): """Pipelines for a specific project. GET lists, POST registers.""" queryset = Pipeline.objects.none() serializer_class = PipelineSerializer - permission_classes = [ProjectPipelineConfigPermission] + permission_classes = [ProjectPipelineConfigPermission | HasProcessingServiceAPIKey] require_project = True def get_queryset(self) -> QuerySet: @@ -265,19 +297,19 @@ def create(self, request, *args, **kwargs): serializer.is_valid(raise_exception=True) project = self.get_active_project() - with transaction.atomic(): - processing_service, _ = ProcessingService.objects.get_or_create( - name=serializer.validated_data["processing_service_name"], - defaults={"endpoint_url": None}, - ) - processing_service.projects.add(project) + if not isinstance(request.auth, ProcessingService): + raise api_exceptions.ValidationError("Pipeline registration requires API key authentication.") + processing_service = request.auth + with transaction.atomic(): response = processing_service.create_pipelines( pipeline_configs=serializer.validated_data["pipelines"], projects=Project.objects.filter(pk=project.pk), ) - # Record that we heard from this async processing service - processing_service.mark_seen(live=True) + # Update heartbeat and client info + from ami.ml.schemas import get_client_info + + processing_service.mark_seen(client_info=get_client_info(request)) return Response(response.dict(), status=status.HTTP_201_CREATED) diff --git a/config/settings/base.py b/config/settings/base.py index c3a8750dc..792d48931 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -87,6 +87,7 @@ "django_celery_beat", "rest_framework", "rest_framework.authtoken", + "rest_framework_api_key", "djoser", "corsheaders", "drf_spectacular", @@ -436,6 +437,7 @@ def _celery_result_backend_url(redis_url): "DEFAULT_AUTHENTICATION_CLASSES": ( # "rest_framework.authentication.SessionAuthentication", "rest_framework.authentication.TokenAuthentication", + "ami.ml.auth.ProcessingServiceAPIKeyAuthentication", ), "DEFAULT_PERMISSION_CLASSES": ("ami.base.permissions.IsActiveStaffOrReadOnly",), # "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.AllowAny",), diff --git a/processing_services/minimal/Dockerfile b/processing_services/minimal/Dockerfile index 0686b4471..2583ef130 100644 --- a/processing_services/minimal/Dockerfile +++ b/processing_services/minimal/Dockerfile @@ -5,5 +5,6 @@ WORKDIR /app COPY . /app RUN pip install -r ./requirements.txt +RUN chmod +x /app/start.sh -CMD ["python", "/app/main.py"] +CMD ["/app/start.sh"] diff --git a/processing_services/minimal/register.py b/processing_services/minimal/register.py new file mode 100644 index 000000000..425b8894e --- /dev/null +++ b/processing_services/minimal/register.py @@ -0,0 +1,267 @@ +""" +Register this processing service's pipelines with Antenna. + +Supports two modes: + +1. **API key mode** (recommended for production): + Set ANTENNA_API_KEY to an existing key. The service authenticates + directly and registers its pipelines. + +2. **Self-provisioning mode** (for local development / docker compose): + Set ANTENNA_USER and ANTENNA_PASSWORD (defaults to the local dev + superuser). The service logs in, creates itself via the REST API, + generates its own API key, and registers pipelines. The generated + API key is written to /tmp/antenna_api_key for subsequent requests. + +Environment variables: + ANTENNA_API_URL: Base URL of the Antenna API (e.g., http://django:8000) + ANTENNA_PROJECT_ID: Project ID to register pipelines for + ANTENNA_API_KEY: API key for authentication (mode 1) + ANTENNA_USER: Username for self-provisioning (mode 2) + ANTENNA_PASSWORD: Password for self-provisioning (mode 2) + ANTENNA_SERVICE_NAME: Name for the processing service (default: hostname) +""" + +import logging +import os +import platform +import socket +import sys +import time + +import requests + +logger = logging.getLogger(__name__) + +MAX_RETRIES = 10 +RETRY_DELAY = 5 # seconds + +# Local dev defaults (matches .envs/.local/.django) +DEFAULT_USER = "antenna@insectai.org" +DEFAULT_PASSWORD = "localadmin" + + +def get_client_info() -> dict: + return { + "hostname": socket.gethostname(), + "software": "antenna-minimal-ps", + "version": "0.1.0", + "platform": platform.platform(), + } + + +def get_own_pipeline_configs(port: int = 2000) -> list[dict]: + """Fetch pipeline configs from our own /info endpoint.""" + resp = requests.get(f"http://localhost:{port}/info", timeout=5) + resp.raise_for_status() + info = resp.json() + return info.get("pipelines", []) + + +def login(api_url: str, email: str, password: str) -> str: + """Log in with user credentials and return an auth token.""" + resp = requests.post( + f"{api_url}/api/v2/auth/token/login/", + json={"email": email, "password": password}, + timeout=10, + ) + resp.raise_for_status() + token = resp.json().get("auth_token") + if not token: + raise ValueError(f"Login succeeded but no auth_token in response: {resp.json()}") + return token + + +def create_processing_service(api_url: str, token: str, project_id: str, name: str) -> dict: + """Create a processing service via the REST API, or return an existing one.""" + headers = {"Authorization": f"Token {token}"} + + # Check if a service with this name already exists in the project + resp = requests.get( + f"{api_url}/api/v2/processing-services/", + headers=headers, + params={"project_id": project_id}, + timeout=10, + ) + resp.raise_for_status() + results = resp.json().get("results", []) + for svc in results: + if svc.get("name") == name: + logger.info(f"Found existing processing service: {name} (id={svc['id']})") + return svc + + # Create new service (no endpoint_url = async/pull-mode) + resp = requests.post( + f"{api_url}/api/v2/processing-services/", + headers=headers, + params={"project_id": project_id}, + json={"name": name}, + timeout=10, + ) + resp.raise_for_status() + svc = resp.json().get("instance", resp.json()) + logger.info(f"Created processing service: {name} (id={svc['id']})") + return svc + + +def generate_api_key(api_url: str, token: str, service_id: int) -> str: + """Generate an API key for the processing service and return the plaintext key.""" + headers = {"Authorization": f"Token {token}"} + resp = requests.post( + f"{api_url}/api/v2/processing-services/{service_id}/generate_key/", + headers=headers, + timeout=10, + ) + resp.raise_for_status() + data = resp.json() + return data["api_key"] + + +def register_with_antenna( + api_url: str, + api_key: str, + project_id: str, + pipelines: list[dict], + client_info: dict, +) -> bool: + """Register pipelines with Antenna's pipeline registration endpoint.""" + url = f"{api_url}/api/v2/projects/{project_id}/pipelines/" + headers = {"Authorization": f"Api-Key {api_key}"} + payload = { + "pipelines": pipelines, + "client_info": client_info, + } + + resp = requests.post(url, json=payload, headers=headers, timeout=30) + if resp.status_code == 201: + logger.info(f"Registered {len(pipelines)} pipelines with Antenna") + return True + else: + logger.error(f"Registration failed: {resp.status_code} {resp.text}") + return False + + +def self_provision(api_url: str, project_id: str, email: str, password: str) -> str: + """ + Self-provision a processing service and return a usable API key. + + Logs in with user credentials, creates the processing service (or finds + an existing one), and generates an API key. The key is also written to + /tmp/antenna_api_key for use by subsequent requests. + """ + service_name = os.environ.get("ANTENNA_SERVICE_NAME", socket.gethostname()) + + logger.info(f"Self-provisioning as '{service_name}' (user: {email})") + token = login(api_url, email, password) + + svc = create_processing_service(api_url, token, project_id, service_name) + service_id = svc["id"] + + # Check if the service already has a key (from a previous run) + existing_prefix = svc.get("api_key_prefix") + api_key_file = "/tmp/antenna_api_key" + + # If we have a cached key from a previous self-provision, try to use it + if existing_prefix and os.path.exists(api_key_file): + cached_key = open(api_key_file).read().strip() + if cached_key.startswith(existing_prefix.split(".")[0]): + logger.info(f"Reusing cached API key (prefix: {existing_prefix})") + return cached_key + + # Generate a new key (revokes any previous ones) + api_key = generate_api_key(api_url, token, service_id) + logger.info(f"Generated new API key for {service_name}") + + # Cache for subsequent requests + with open(api_key_file, "w") as f: + f.write(api_key) + + return api_key + + +def main(): + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + api_url = os.environ.get("ANTENNA_API_URL") + api_key = os.environ.get("ANTENNA_API_KEY") + project_id = os.environ.get("ANTENNA_PROJECT_ID") + + if not api_key and not os.environ.get("ANTENNA_USER"): + logger.info("Neither ANTENNA_API_KEY nor ANTENNA_USER set, skipping registration") + return + + if not api_url: + logger.error("ANTENNA_API_URL is required for registration") + sys.exit(1) + + if not project_id: + logger.error("ANTENNA_PROJECT_ID is required for registration") + sys.exit(1) + + # Wait for our own FastAPI server to be ready + for attempt in range(MAX_RETRIES): + try: + resp = requests.get("http://localhost:2000/livez", timeout=2) + resp.raise_for_status() + break + except (requests.ConnectionError, requests.Timeout, requests.HTTPError): + if attempt < MAX_RETRIES - 1: + logger.info(f"Waiting for local server to start (attempt {attempt + 1}/{MAX_RETRIES})...") + time.sleep(RETRY_DELAY) + else: + logger.error("Local server did not start in time") + sys.exit(1) + + # Fetch our own pipeline configs + try: + pipelines = get_own_pipeline_configs() + except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e: + logger.error(f"Failed to fetch pipeline configs from local server: {e}") + sys.exit(1) + + # Self-provision if no API key provided + if not api_key: + email = os.environ.get("ANTENNA_USER", DEFAULT_USER) + password = os.environ.get("ANTENNA_PASSWORD", DEFAULT_PASSWORD) + + for attempt in range(MAX_RETRIES): + try: + api_key = self_provision(api_url, project_id, email, password) + break + except (requests.ConnectionError, requests.Timeout): + pass + except requests.HTTPError as e: + logger.error(f"Self-provisioning failed: {e}") + if attempt == MAX_RETRIES - 1: + sys.exit(1) + + if attempt < MAX_RETRIES - 1: + logger.info(f"Retrying self-provisioning (attempt {attempt + 1}/{MAX_RETRIES})...") + time.sleep(RETRY_DELAY) + else: + logger.error("Failed to self-provision after all retries") + sys.exit(1) + + client_info = get_client_info() + + # Register pipelines with the API key + for attempt in range(MAX_RETRIES): + try: + if register_with_antenna(api_url, api_key, project_id, pipelines, client_info): + return + except (requests.ConnectionError, requests.Timeout): + pass + + if attempt < MAX_RETRIES - 1: + logger.info(f"Retrying registration (attempt {attempt + 1}/{MAX_RETRIES})...") + time.sleep(RETRY_DELAY) + + logger.error("Failed to register with Antenna after all retries") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/processing_services/minimal/start.sh b/processing_services/minimal/start.sh new file mode 100755 index 000000000..72087a10b --- /dev/null +++ b/processing_services/minimal/start.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e + +# Forward signals to child process for graceful shutdown +trap 'kill -TERM $SERVER_PID 2>/dev/null' TERM INT + +# Start FastAPI server in background +python /app/main.py & +SERVER_PID=$! + +# Run registration if API key is configured (non-fatal) +if [ -n "$ANTENNA_API_KEY" ] || [ -n "$ANTENNA_USER" ]; then + python /app/register.py || echo "Registration failed, continuing in push-mode" +fi + +# Wait for the server process +wait $SERVER_PID diff --git a/requirements/base.txt b/requirements/base.txt index 037eeea17..739d29381 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -31,6 +31,7 @@ crispy-bootstrap5==0.7 # https://github.com/django-crispy-forms/crispy-bootstra django-redis==5.3.0 # https://github.com/jazzband/django-redis # Django REST Framework djangorestframework==3.14.0 # https://github.com/encode/django-rest-framework +djangorestframework-api-key==3.0.0 # https://github.com/florimondmanca/djangorestframework-api-key django-cors-headers==4.1.0 # https://github.com/adamchainz/django-cors-headers # DRF-spectacular for api documentation drf-spectacular==0.26.3 # https://github.com/tfranzel/drf-spectacular