diff --git a/openwisp_controller/connection/api/serializers.py b/openwisp_controller/connection/api/serializers.py index 142c1c30a..5d2a6c681 100644 --- a/openwisp_controller/connection/api/serializers.py +++ b/openwisp_controller/connection/api/serializers.py @@ -12,6 +12,7 @@ DeviceConnection = load_model("connection", "DeviceConnection") Credentials = load_model("connection", "Credentials") Device = load_model("config", "Device") +BatchCommand = load_model("connection", "BatchCommand") class ValidatedDeviceFieldSerializer(ValidatedModelSerializer): @@ -43,6 +44,10 @@ class CommandSerializer(ValidatedDeviceFieldSerializer): required=False, pk_field=serializers.UUIDField(format="hex_verbose"), ) + batch_command = serializers.PrimaryKeyRelatedField( + read_only=True, + pk_field=serializers.UUIDField(format="hex_verbose"), + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -115,3 +120,102 @@ class Meta: "is_working": {"read_only": True}, } read_only_fields = ("created", "modified") + + +class BatchCommandExecuteSerializer( + FilterSerializerByOrgManaged, serializers.ModelSerializer +): + type = serializers.CharField() + input = serializers.JSONField(allow_null=True, required=False) + devices = serializers.PrimaryKeyRelatedField( + many=True, + queryset=Device.objects.all(), + required=False, + allow_empty=True, + pk_field=serializers.UUIDField(format="hex_verbose"), + ) + execute_all = serializers.BooleanField(required=False, default=True) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + request = self.context.get("request") + if request and request.method == "GET": + self.fields["type"].required = False + + class Meta: + model = BatchCommand + fields = ( + "organization", + "type", + "input", + "devices", + "group", + "location", + "execute_all", + ) + extra_kwargs = { + "organization": {"required": False, "allow_null": True}, + } + + def validate(self, data): + org = data.get("organization") + execute_all = data.get("execute_all", False) + devices = data.get("devices") + group = data.get("group") + location = data.get("location") + if not org and not self.context["request"].user.is_superuser: + raise serializers.ValidationError( + _("Only superusers can execute batch commands without an organization.") + ) + if not execute_all and not org and not devices and not group and not location: + raise serializers.ValidationError( + _( + "Specify at least one targeting option " + "or set execute_all to true." + ) + ) + if devices: + for device in devices: + if org and device.organization_id != org.id: + raise serializers.ValidationError( + { + "devices": _( + "All devices must belong to the same organization." + ) + } + ) + return data + + +class BatchCommandSerializer(BaseSerializer): + device_count = serializers.IntegerField(source="devices.count", read_only=True) + + class Meta: + model = BatchCommand + fields = ( + "id", + "organization", + "status", + "type", + "input", + "group", + "location", + "device_count", + "created", + "modified", + ) + read_only_fields = ( + "created", + "modified", + ) + + +class BatchCommandDetailSerializer(BatchCommandSerializer): + devices = serializers.PrimaryKeyRelatedField( + many=True, + read_only=True, + pk_field=serializers.UUIDField(format="hex_verbose"), + ) + + class Meta(BatchCommandSerializer.Meta): + fields = BatchCommandSerializer.Meta.fields + ("devices",) diff --git a/openwisp_controller/connection/api/urls.py b/openwisp_controller/connection/api/urls.py index 4ec3e70ab..4a94d171d 100644 --- a/openwisp_controller/connection/api/urls.py +++ b/openwisp_controller/connection/api/urls.py @@ -40,6 +40,21 @@ def get_api_urls(api_views): api_views.deviceconnection_detail_view, name="deviceconnection_detail", ), + path( + "api/v1/controller/batch-command/", + api_views.batch_command_list_view, + name="batch_command_list", + ), + path( + "api/v1/controller/batch-command//", + api_views.batch_command_detail_view, + name="batch_command_detail", + ), + path( + "api/v1/controller/batch-command/execute/", + api_views.batch_command_execute_view, + name="batch_command_execute", + ), ] diff --git a/openwisp_controller/connection/api/views.py b/openwisp_controller/connection/api/views.py index 6af1270c7..0a4912d76 100644 --- a/openwisp_controller/connection/api/views.py +++ b/openwisp_controller/connection/api/views.py @@ -1,12 +1,17 @@ +from django.core.exceptions import ValidationError from django.utils.translation import gettext_lazy as _ from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema +from rest_framework import status from rest_framework.generics import ( + GenericAPIView, + ListAPIView, ListCreateAPIView, RetrieveAPIView, RetrieveUpdateDestroyAPIView, get_object_or_404, ) +from rest_framework.response import Response from swapper import load_model from openwisp_utils.api.pagination import OpenWispPagination @@ -17,6 +22,9 @@ RelatedDeviceProtectedAPIMixin, ) from .serializers import ( + BatchCommandDetailSerializer, + BatchCommandExecuteSerializer, + BatchCommandSerializer, CommandSerializer, CredentialSerializer, DeviceConnectionSerializer, @@ -26,6 +34,7 @@ Device = load_model("config", "Device") Credentials = load_model("connection", "Credentials") DeviceConnection = load_model("connection", "DeviceConnection") +BatchCommand = load_model("connection", "BatchCommand") class BaseCommandView(RelatedDeviceProtectedAPIMixin): @@ -138,6 +147,50 @@ class DeviceConnectionListCreateView(BaseDeviceConnection, ListCreateAPIView): DeviceConnenctionListCreateView = DeviceConnectionListCreateView +class BatchCommandExecuteView(ProtectedAPIMixin, GenericAPIView): + model = BatchCommand + queryset = BatchCommand.objects.all() + serializer_class = BatchCommandExecuteSerializer + + def post(self, request): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.validated_data.pop("execute_all", None) + try: + batch = BatchCommand.execute(**serializer.validated_data) + except ValidationError as e: + return Response( + getattr(e, "message_dict", e.messages), + status=status.HTTP_400_BAD_REQUEST, + ) + return Response({"batch": str(batch.pk)}, status=201) + + def get(self, request): + serializer = self.get_serializer(data=request.query_params) + serializer.is_valid(raise_exception=True) + serializer.validated_data.pop("execute_all", None) + try: + data = BatchCommand.dry_run(**serializer.validated_data) + except ValidationError as e: + return Response( + getattr(e, "message_dict", e.messages), + status=status.HTTP_400_BAD_REQUEST, + ) + data["devices"] = [str(d.pk) for d in data["devices"]] + return Response(data) + + +class BatchCommandListView(ProtectedAPIMixin, ListAPIView): + queryset = BatchCommand.objects.all().order_by("-created") + serializer_class = BatchCommandSerializer + pagination_class = OpenWispPagination + + +class BatchCommandDetailView(ProtectedAPIMixin, RetrieveAPIView): + queryset = BatchCommand.objects.all() + serializer_class = BatchCommandDetailSerializer + + class DeviceConnectionDetailView(BaseDeviceConnection, RetrieveUpdateDestroyAPIView): def get_object(self): queryset = self.filter_queryset(self.get_queryset()) @@ -158,3 +211,6 @@ def get_object(self): # TODO: remove in version 1.4 deviceconnection_details_view = deviceconnection_detail_view +batch_command_execute_view = BatchCommandExecuteView.as_view() +batch_command_list_view = BatchCommandListView.as_view() +batch_command_detail_view = BatchCommandDetailView.as_view() diff --git a/openwisp_controller/connection/base/models.py b/openwisp_controller/connection/base/models.py index 17a06f7bb..e92888474 100644 --- a/openwisp_controller/connection/base/models.py +++ b/openwisp_controller/connection/base/models.py @@ -30,7 +30,11 @@ ) from ..exceptions import NoWorkingDeviceConnectionError from ..signals import is_working_changed -from ..tasks import auto_add_credentials_to_devices, launch_command +from ..tasks import ( + auto_add_credentials_to_devices, + launch_batch_command, + launch_command, +) logger = logging.getLogger(__name__) @@ -467,6 +471,13 @@ class AbstractCommand(TimeStampedEditableModel): encoder=DjangoJSONEncoder, ) output = models.TextField(blank=True) + batch_command = models.ForeignKey( + get_model_name("connection", "BatchCommand"), + on_delete=models.SET_NULL, + blank=True, + null=True, + related_name="batch_commands", + ) class Meta: verbose_name = _("Command") @@ -557,6 +568,8 @@ def save(self, *args, **kwargs): output = super().save(*args, **kwargs) if adding: self._schedule_command() + if self.batch_command_id and self.status != "in-progress": + self.batch_command.calculate_and_update_status() return output def _save_without_resurrecting(self): @@ -719,3 +732,266 @@ def _enforce_not_custom(self): f"arguments property is not applicable in " f'command instance of type "{self.type}"' ) + + +class AbstractBatchCommand(TimeStampedEditableModel): + STATUS_CHOICES = ( + ("idle", _("idle")), + ("in-progress", _("in progress")), + ("success", _("success")), + ("failed", _("failed")), + ) + + organization = models.ForeignKey( + get_model_name("openwisp_users", "Organization"), + on_delete=models.CASCADE, + blank=True, + null=True, + ) + status = models.CharField( + max_length=12, choices=STATUS_CHOICES, default=STATUS_CHOICES[0][0] + ) + type = models.CharField( + max_length=16, + choices=(COMMAND_CHOICES if django.VERSION < (5, 0) else get_command_choices), + ) + input = JSONField(blank=True, null=True, encoder=DjangoJSONEncoder) + group = models.ForeignKey( + get_model_name("config", "DeviceGroup"), + on_delete=models.SET_NULL, + blank=True, + null=True, + verbose_name=_("device group"), + ) + location = models.ForeignKey( + get_model_name("geo", "Location"), + on_delete=models.SET_NULL, + blank=True, + null=True, + verbose_name=_("location"), + ) + devices = models.ManyToManyField( + get_model_name("config", "Device"), + blank=True, + verbose_name=_("devices"), + ) + + class Meta: + abstract = True + verbose_name = _("Batch command") + verbose_name_plural = _("Batch commands") + + @cached_property + def total_devices(self): + return self.batch_commands.count() + + @property + def successful(self): + return self.batch_commands.filter(status="success").count() + + @property + def failed(self): + return self.batch_commands.filter(status="failed").count() + + def _validate_org_relations(self): + if not self.organization_id: + return + if self.group and self.group.organization != self.organization: + raise ValidationError( + { + "group": _( + "The organization of the group doesn't match " + "the organization of the batch command operation" + ) + } + ) + if self.location and self.location.organization != self.organization: + raise ValidationError( + { + "location": _( + "The organization of the location doesn't match " + "the organization of the batch command operation" + ) + } + ) + if self.pk and self.devices.exists(): + org_mismatch = self.devices.exclude(organization=self.organization).exists() + if org_mismatch: + raise ValidationError( + { + "devices": _( + "All devices must belong to the same " + "organization as the batch command." + ) + } + ) + + def clean(self): + super().clean() + self._validate_org_relations() + Command = load_model("connection", "Command") + allowed = dict( + Command.get_org_allowed_commands(organization_id=self.organization_id) + ) + if self.type not in allowed: + raise ValidationError( + { + "type": _( + '"{command}" command is not available for this organization' + ).format(command=self.type) + } + ) + try: + jsonschema.Draft4Validator(get_command_schema(self.type)).validate( + self.input + ) + except SchemaError as e: + raise ValidationError({"input": e.message}) + + def resolve_devices(self): + """ + Returns an iterator of devices targeted by this batch command, + resolved from explicit M2M devices or filtered by organization, + group, and location. Returns an empty iterator if no devices match. + """ + if self.pk and self.devices.exists(): + return self.devices.iterator() + Device = load_model("config", "Device") + qs = Device.objects.all() + if self.organization_id: + qs = qs.filter(organization=self.organization) + if self.group: + qs = qs.filter(group=self.group) + if self.location: + qs = qs.filter(devicelocation__location=self.location) + return qs.iterator() + + @classmethod + def execute(cls, **kwargs): + """ + Creates, validates, and persists the batch command, then schedules + execution via a background task. Raises ValidationError and deletes + the batch if no devices match the criteria. + """ + devices_list = kwargs.pop("devices", None) + batch = cls(**kwargs) + batch.full_clean() + batch.save() + if devices_list: + batch.devices.set(devices_list) + if not batch.devices.exists(): + batch.delete() + raise ValidationError( + _("No devices match the specified criteria."), + ) + elif not any(batch.resolve_devices()): + batch.delete() + raise ValidationError( + _("No devices match the specified criteria."), + ) + batch.status = "in-progress" + batch.save(update_fields=["status"]) + transaction.on_commit(lambda: launch_batch_command.delay(batch.pk)) + return batch + + @classmethod + def dry_run(cls, **kwargs): + """ + Returns the devices that would be targeted by this batch command + without executing it. Skips full validation when command type is + not provided case for GET request. + """ + devices_list = kwargs.pop("devices", None) + cmd_type = kwargs.pop("type", None) + batch = cls(**kwargs) + if cmd_type: + batch.type = cmd_type + batch.full_clean() + else: + batch._validate_org_relations() + if devices_list: + return {"devices": list(devices_list)} + return {"devices": list(batch.resolve_devices())} + + def create_commands(self): + """ + Creates individual Command instances for each device targeted by + this batch command. Returns early if commands already exist + (idempotent guard). Devices that fail validation are silently + skipped. + """ + if self.batch_commands.exists(): + return + Command = load_model("connection", "Command") + self.status = "in-progress" + self.save() + for device in self.resolve_devices(): + command = Command( + device=device, + type=self.type, + input=self.input, + batch_command=self, + ) + try: + command.save() + except ValidationError as e: + logger.warning( + "Skipping device %s for batch %s: %s", + device.pk, + self.pk, + e, + ) + self.calculate_and_update_status() + + def calculate_and_update_status(self): + """ + Calculate batch status based on individual command statuses and update if + changed. + - No commands exist: status set to "idle". + - Commands still running: status set to "in-progress". + - Some commands failed: status set to "failed". + - All commands completed successfully: status set to "success". + - Status unchanged: no database write performed. + """ + stats = self.batch_commands.aggregate( + total_operations=models.Count("id"), + in_progress=models.Count( + models.Case( + models.When(status="in-progress", then=1), + output_field=models.IntegerField(), + ) + ), + completed=models.Count( + models.Case( + models.When(~models.Q(status="in-progress"), then=1), + output_field=models.IntegerField(), + ) + ), + successful=models.Count( + models.Case( + models.When(status="success", then=1), + output_field=models.IntegerField(), + ) + ), + failed=models.Count( + models.Case( + models.When(status="failed", then=1), + output_field=models.IntegerField(), + ) + ), + ) + if stats["total_operations"] == 0: + new_status = "idle" + elif stats["in_progress"] > 0: + new_status = "in-progress" + elif stats["failed"] > 0: + new_status = "failed" + elif ( + stats["successful"] > 0 and stats["completed"] == stats["total_operations"] + ): + new_status = "success" + else: + new_status = self.status + if self.status != new_status: + self.status = new_status + self.save(update_fields=["status"]) diff --git a/openwisp_controller/connection/migrations/0011_batchcommand_command_batch_command.py b/openwisp_controller/connection/migrations/0011_batchcommand_command_batch_command.py new file mode 100644 index 000000000..5f5b7f590 --- /dev/null +++ b/openwisp_controller/connection/migrations/0011_batchcommand_command_batch_command.py @@ -0,0 +1,142 @@ +# Generated by Django 5.2.15 on 2026-06-14 18:00 + +import uuid + +import django +import django.core.serializers.json +import django.db.models.deletion +import django.utils.timezone +import model_utils.fields +import swapper +from django.db import migrations, models + +from openwisp_controller import connection as connection_config + + +class Migration(migrations.Migration): + + dependencies = [ + ("connection", "0010_replace_jsonfield_with_django_builtin"), + ("openwisp_users", "0022_user_expiration_date"), + ("config", "0063_replace_jsonfield_with_django_builtin"), + ("geo", "0006_create_geo_settings_for_existing_orgs"), + ] + + operations = [ + migrations.CreateModel( + name="BatchCommand", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "created", + model_utils.fields.AutoCreatedField( + default=django.utils.timezone.now, + editable=False, + verbose_name="created", + ), + ), + ( + "modified", + model_utils.fields.AutoLastModifiedField( + default=django.utils.timezone.now, + editable=False, + verbose_name="modified", + ), + ), + ( + "status", + models.CharField( + choices=[ + ("idle", "idle"), + ("in-progress", "in progress"), + ("success", "success"), + ("failed", "failed"), + ], + default="idle", + max_length=12, + ), + ), + ( + "type", + models.CharField( + max_length=16, + choices=( + connection_config.commands.COMMAND_CHOICES + if django.VERSION < (5, 0) + else connection_config.commands.get_command_choices + ), + ), + ), + ( + "input", + models.JSONField( + blank=True, + encoder=django.core.serializers.json.DjangoJSONEncoder, + null=True, + ), + ), + ( + "devices", + models.ManyToManyField( + blank=True, + to=swapper.get_model_name("config", "Device"), + verbose_name="devices", + ), + ), + ( + "group", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=swapper.get_model_name("config", "DeviceGroup"), + verbose_name="device group", + ), + ), + ( + "location", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=swapper.get_model_name("geo", "Location"), + verbose_name="location", + ), + ), + ( + "organization", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to=swapper.get_model_name("openwisp_users", "Organization"), + ), + ), + ], + options={ + "verbose_name": "Batch command", + "verbose_name_plural": "Batch commands", + "abstract": False, + "swappable": "CONNECTION_BATCHCOMMAND_MODEL", + }, + ), + migrations.AddField( + model_name="command", + name="batch_command", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="batch_commands", + to=swapper.get_model_name("connection", "BatchCommand"), + ), + ), + ] diff --git a/openwisp_controller/connection/models.py b/openwisp_controller/connection/models.py index 39d485b57..533ad6c4d 100644 --- a/openwisp_controller/connection/models.py +++ b/openwisp_controller/connection/models.py @@ -1,6 +1,11 @@ import swapper -from .base.models import AbstractCommand, AbstractCredentials, AbstractDeviceConnection +from .base.models import ( + AbstractBatchCommand, + AbstractCommand, + AbstractCredentials, + AbstractDeviceConnection, +) class Credentials(AbstractCredentials): @@ -19,3 +24,9 @@ class Command(AbstractCommand): class Meta(AbstractCommand.Meta): abstract = False swappable = swapper.swappable_setting("connection", "Command") + + +class BatchCommand(AbstractBatchCommand): + class Meta(AbstractBatchCommand.Meta): + abstract = False + swappable = swapper.swappable_setting("connection", "BatchCommand") diff --git a/openwisp_controller/connection/tasks.py b/openwisp_controller/connection/tasks.py index d6ea2828b..2ad4b153a 100644 --- a/openwisp_controller/connection/tasks.py +++ b/openwisp_controller/connection/tasks.py @@ -98,6 +98,17 @@ def launch_command(command_id): command._save_without_resurrecting() +@shared_task(bind=True, soft_time_limit=app_settings.SSH_COMMAND_TIMEOUT * 1.2) +def launch_batch_command(self, batch_id): + BatchCommand = load_model("connection", "BatchCommand") + try: + batch = BatchCommand.objects.get(pk=batch_id) + except BatchCommand.DoesNotExist: + logger.warning(f"The BatchCommand object with id {batch_id} has been deleted") + return + batch.create_commands() + + @shared_task(soft_time_limit=3600) def auto_add_credentials_to_devices(credential_id, organization_id): Credentials = load_model("connection", "Credentials") diff --git a/openwisp_controller/connection/tests/pytest.py b/openwisp_controller/connection/tests/pytest.py index 220604cef..d7f7f5659 100644 --- a/openwisp_controller/connection/tests/pytest.py +++ b/openwisp_controller/connection/tests/pytest.py @@ -65,6 +65,7 @@ def _get_expected_response(self, command): "output": command.output, "device": str(command.device_id), "connection": str(command.connection_id), + "batch_command": None, }, } diff --git a/openwisp_controller/connection/tests/test_api.py b/openwisp_controller/connection/tests/test_api.py index 6e3a827ab..8ab0bb5d0 100644 --- a/openwisp_controller/connection/tests/test_api.py +++ b/openwisp_controller/connection/tests/test_api.py @@ -21,6 +21,7 @@ Command = load_model("connection", "Command") DeviceConnection = load_model("connection", "DeviceConnection") +BatchCommand = load_model("connection", "BatchCommand") command_qs = Command.objects.order_by("-created") OrganizationUser = load_model("openwisp_users", "OrganizationUser") Group = load_model("openwisp_users", "Group") @@ -860,3 +861,266 @@ def test_deviceconnection_unauthenticated_user(self): "delete": 401, }, ) + + +class TestBatchCommandsAPI( + TestAdminMixin, AuthenticationMixin, TestCase, CreateConnectionsMixin +): + url_namespace = "connection_api" + + def setUp(self): + super().setUp() + self._login() + + def _create_batch_command(self, organization, **kwargs): + opts = dict( + organization=organization, + type="custom", + input={"command": "echo test"}, + ) + devices = kwargs.pop("devices", None) + opts.update(kwargs) + batch = BatchCommand(**opts) + batch.full_clean() + batch.save() + if devices is not None: + if not isinstance(devices, (list, tuple)): + devices = [devices] + batch.devices.set(devices) + return batch + + def test_batch_command_list(self): + org = self._get_org() + url = reverse("connection_api:batch_command_list") + for _ in range(3): + self._create_batch_command(organization=org) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["count"], 3) + self.assertEqual(len(response.data["results"]), 3) + created_list = [cmd["created"] for cmd in response.data["results"]] + sorted_created = sorted(created_list, reverse=True) + self.assertEqual(created_list, sorted_created) + result = response.data["results"][0] + self.assertIn("id", result) + self.assertIn("status", result) + self.assertIn("type", result) + self.assertIn("input", result) + self.assertIn("device_count", result) + self.assertIn("created", result) + self.assertEqual(result["device_count"], 0) + + def test_batch_command_detail(self): + org = self._get_org() + device = self._create_device(organization=org) + self._create_config(device=device) + batch = self._create_batch_command(organization=org, devices=[device]) + url = reverse("connection_api:batch_command_detail", args=[batch.pk]) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["id"], str(batch.pk)) + self.assertEqual(response.data["status"], batch.status) + self.assertEqual(response.data["type"], batch.type) + self.assertEqual(response.data["input"], batch.input) + self.assertIn("devices", response.data) + self.assertEqual(response.data["devices"], [str(device.pk)]) + self.assertEqual(response.data["device_count"], 1) + + def test_batch_command_execute(self): + org = self._get_org() + device = self._create_device(organization=org) + self._create_config(device=device) + self._create_device_connection(device=device) + payload = { + "organization": str(org.pk), + "type": "custom", + "input": {"command": "echo test"}, + "devices": [str(device.pk)], + } + response = self.client.post( + reverse("connection_api:batch_command_execute"), + data=json.dumps(payload), + content_type="application/json", + ) + self.assertEqual(response.status_code, 201) + self.assertIn("batch", response.data) + batch = BatchCommand.objects.get(pk=response.data["batch"]) + # transaction.on_commit doesn't fire in TestCase; trigger manually + batch.create_commands() + command = Command.objects.get(batch_command=batch) + self.assertEqual(command.device.pk, device.pk) + + def test_batch_command_execute_queries(self): + org = self._get_org() + devices = [] + for i in range(3): + d = self._create_device( + name=f"q-dev-{i}", + mac_address=f"00:11:22:33:44:{i:02x}", + organization=org, + ) + self._create_config(device=d) + devices.append(d) + payload = { + "organization": str(org.pk), + "type": "custom", + "input": {"command": "echo test"}, + "devices": [str(d.pk) for d in devices], + } + with self.assertNumQueries(13): + response = self.client.post( + reverse("connection_api:batch_command_execute"), + data=json.dumps(payload), + content_type="application/json", + ) + self.assertEqual(response.status_code, 201) + self.assertIn("batch", response.data) + batch = BatchCommand.objects.get(pk=response.data["batch"]) + self.assertEqual(batch.devices.count(), 3) + self.assertCountEqual( + batch.devices.values_list("pk", flat=True), + [d.pk for d in devices], + ) + + def test_batch_command_execute_org_has_no_devices(self): + org = self._get_org() + payload = { + "organization": str(org.pk), + "type": "custom", + "input": {"command": "echo test"}, + "execute_all": True, + } + response = self.client.post( + reverse("connection_api:batch_command_execute"), + data=json.dumps(payload), + content_type="application/json", + ) + self.assertEqual(response.status_code, 400) + + def test_batch_command_execute_no_org(self): + org = self._get_org() + self.client.logout() + operator = self._create_operator(organizations=[org]) + add_perm = Permission.objects.get(codename="add_batchcommand") + operator.user_permissions.add(add_perm) + self.client.force_login(operator) + payload = { + "type": "custom", + "input": {"command": "echo test"}, + "execute_all": True, + } + response = self.client.post( + reverse("connection_api:batch_command_execute"), + data=json.dumps(payload), + content_type="application/json", + ) + self.assertEqual(response.status_code, 400) + self.assertIn( + "Only superusers", + str(response.data), + ) + + def test_batch_command_dry_run(self): + org = self._get_org() + device = self._create_device(organization=org) + self._create_config(device=device) + url = "{0}?organization={1}".format( + reverse("connection_api:batch_command_execute"), str(org.pk) + ) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + self.assertIn("devices", response.data) + self.assertIn(str(device.pk), response.data["devices"]) + + def test_batch_command_list_organization_scoped(self): + org = self._get_org() + org2 = self._create_org(name="org2", slug="org2") + self._create_batch_command(organization=org) + self._create_batch_command(organization=org2) + self.client.logout() + operator = self._create_operator(organizations=[org]) + view_perm = Permission.objects.get(codename="view_batchcommand") + operator.user_permissions.add(view_perm) + self.client.force_login(operator) + response = self.client.get(reverse("connection_api:batch_command_list")) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["count"], 1) + + def test_batch_command_unauthorized(self): + self.client.logout() + with self.subTest("List"): + response = self.client.get(reverse("connection_api:batch_command_list")) + self.assertEqual(response.status_code, 401) + + with self.subTest("Detail"): + response = self.client.get( + reverse( + "connection_api:batch_command_detail", + args=[uuid.uuid4()], + ) + ) + self.assertEqual(response.status_code, 401) + + with self.subTest("Dry run"): + response = self.client.get(reverse("connection_api:batch_command_execute")) + self.assertEqual(response.status_code, 401) + + with self.subTest("Execute"): + response = self.client.post( + reverse("connection_api:batch_command_execute"), + data=json.dumps({"type": "custom"}), + content_type="application/json", + ) + self.assertEqual(response.status_code, 401) + + def test_batch_command_cross_org_restrictions(self): + org = self._get_org() + org2 = self._create_org(name="org2", slug="org2") + device_a = self._create_device( + name="device-a", + mac_address="00:11:22:33:44:aa", + organization=org, + ) + self._create_config(device=device_a) + self._create_device_connection(device=device_a) + device_b = self._create_device( + name="device-b", + mac_address="00:11:22:33:44:bb", + organization=org2, + ) + self._create_config(device=device_b) + self._create_device_connection(device=device_b) + + with patch.dict( + ORGANIZATION_ENABLED_COMMANDS, + {str(org2.pk): ("reboot",)}, + ): + payload = { + "type": "custom", + "input": {"command": "echo test"}, + "devices": [str(device_a.pk), str(device_b.pk)], + } + response = self.client.post( + reverse("connection_api:batch_command_execute"), + data=json.dumps(payload), + content_type="application/json", + ) + self.assertEqual(response.status_code, 201) + batch = BatchCommand.objects.get(pk=response.data["batch"]) + # transaction.on_commit doesn't fire in TestCase; + # trigger create_commands() manually + batch.create_commands() + batch.refresh_from_db() + command_qs = Command.objects.filter(batch_command=batch) + self.assertTrue(command_qs.filter(device=device_a).exists()) + self.assertFalse(command_qs.filter(device=device_b).exists()) + # Verify rendering works for created commands + url = reverse( + "connection_api:device_command_list", + args=[device_a.pk], + ) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + cmd_data = response.data["results"][0] + self.assertIn("type", cmd_data) + self.assertIn("input", cmd_data) diff --git a/openwisp_controller/geo/estimated_location/tests/tests.py b/openwisp_controller/geo/estimated_location/tests/tests.py index c037106dd..fee63ae94 100644 --- a/openwisp_controller/geo/estimated_location/tests/tests.py +++ b/openwisp_controller/geo/estimated_location/tests/tests.py @@ -566,7 +566,7 @@ def _verify_location_details(device, mocked_response): device2.save() # 3 queries related to notifications cleanup device2.refresh_from_db() - with self.assertNumQueries(15): + with self.assertNumQueries(16): manage_estimated_locations(device2.pk, device2.last_ip) mock_info.assert_called_once_with( f"Estimated location saved successfully for {device2.pk}" diff --git a/openwisp_controller/geo/tests/test_api.py b/openwisp_controller/geo/tests/test_api.py index 156550692..df46460cf 100644 --- a/openwisp_controller/geo/tests/test_api.py +++ b/openwisp_controller/geo/tests/test_api.py @@ -694,7 +694,7 @@ def test_change_location_type_to_outdoor_api(self): def test_delete_location_detail(self): l1 = self._create_location() path = reverse("geo_api:detail_location", args=[l1.pk]) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.delete(path) self.assertEqual(response.status_code, 204) diff --git a/tests/openwisp2/sample_connection/api/views.py b/tests/openwisp2/sample_connection/api/views.py index fd2207cbc..22d04a924 100644 --- a/tests/openwisp2/sample_connection/api/views.py +++ b/tests/openwisp2/sample_connection/api/views.py @@ -1,3 +1,12 @@ +from openwisp_controller.connection.api.views import ( + BatchCommandDetailView as BaseBatchCommandDetailView, +) +from openwisp_controller.connection.api.views import ( + BatchCommandExecuteView as BaseBatchCommandExecuteView, +) +from openwisp_controller.connection.api.views import ( + BatchCommandListView as BaseBatchCommandListView, +) from openwisp_controller.connection.api.views import ( CommandDetailsView as BaseCommandDetailsView, ) @@ -42,9 +51,24 @@ class DeviceConnectionDetailView(BaseDeviceConnectionDetailView): pass +class BatchCommandExecuteView(BaseBatchCommandExecuteView): + pass + + +class BatchCommandListView(BaseBatchCommandListView): + pass + + +class BatchCommandDetailView(BaseBatchCommandDetailView): + pass + + command_list_create_view = CommandListCreateView.as_view() command_details_view = CommandDetailsView.as_view() credential_list_create_view = CredentialListCreateView.as_view() credential_detail_view = CredentialDetailView.as_view() deviceconnection_list_create_view = DeviceConnectionListCreateView.as_view() deviceconnection_detail_view = DeviceConnectionDetailView.as_view() +batch_command_execute_view = BatchCommandExecuteView.as_view() +batch_command_list_view = BatchCommandListView.as_view() +batch_command_detail_view = BatchCommandDetailView.as_view() diff --git a/tests/openwisp2/sample_connection/migrations/0005_batchcommand_command_batch_command.py b/tests/openwisp2/sample_connection/migrations/0005_batchcommand_command_batch_command.py new file mode 100644 index 000000000..de9613fc0 --- /dev/null +++ b/tests/openwisp2/sample_connection/migrations/0005_batchcommand_command_batch_command.py @@ -0,0 +1,138 @@ +# Generated by Django 5.2.15 on 2026-06-15 18:15 + +import uuid + +import django +import django.core.serializers.json +import django.db.models.deletion +import django.utils.timezone +import model_utils.fields +from django.db import migrations, models + +from openwisp_controller import connection as connection_config + + +class Migration(migrations.Migration): + + dependencies = [ + ("sample_config", "0009_replace_jsonfield_with_django_builtin"), + ("sample_connection", "0004_replace_jsonfield_with_django_builtin"), + ("sample_geo", "0005_organizationgeosettings"), + ("sample_users", "0005_user_expiration_date_user_user_active_expiry_idx"), + ] + + operations = [ + migrations.CreateModel( + name="BatchCommand", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "created", + model_utils.fields.AutoCreatedField( + default=django.utils.timezone.now, + editable=False, + verbose_name="created", + ), + ), + ( + "modified", + model_utils.fields.AutoLastModifiedField( + default=django.utils.timezone.now, + editable=False, + verbose_name="modified", + ), + ), + ( + "status", + models.CharField( + choices=[ + ("idle", "idle"), + ("in-progress", "in progress"), + ("success", "success"), + ("failed", "failed"), + ], + default="idle", + max_length=12, + ), + ), + ( + "type", + models.CharField( + max_length=16, + choices=( + connection_config.commands.COMMAND_CHOICES + if django.VERSION < (5, 0) + else connection_config.commands.get_command_choices + ), + ), + ), + ( + "input", + models.JSONField( + blank=True, + encoder=django.core.serializers.json.DjangoJSONEncoder, + null=True, + ), + ), + ( + "devices", + models.ManyToManyField( + blank=True, to="sample_config.device", verbose_name="devices" + ), + ), + ( + "group", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="sample_config.devicegroup", + verbose_name="device group", + ), + ), + ( + "location", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="sample_geo.location", + verbose_name="location", + ), + ), + ( + "organization", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="sample_users.organization", + ), + ), + ], + options={ + "verbose_name": "Batch command", + "verbose_name_plural": "Batch commands", + "abstract": False, + }, + ), + migrations.AddField( + model_name="command", + name="batch_command", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="batch_commands", + to="sample_connection.batchcommand", + ), + ), + ] diff --git a/tests/openwisp2/sample_connection/models.py b/tests/openwisp2/sample_connection/models.py index 7964c5471..4e96065e8 100644 --- a/tests/openwisp2/sample_connection/models.py +++ b/tests/openwisp2/sample_connection/models.py @@ -1,6 +1,7 @@ from django.db import models from openwisp_controller.connection.base.models import ( + AbstractBatchCommand, AbstractCommand, AbstractCredentials, AbstractDeviceConnection, @@ -27,3 +28,8 @@ class Meta(AbstractDeviceConnection.Meta): class Command(AbstractCommand): class Meta(AbstractCommand.Meta): abstract = False + + +class BatchCommand(AbstractBatchCommand): + class Meta(AbstractBatchCommand.Meta): + abstract = False diff --git a/tests/openwisp2/settings.py b/tests/openwisp2/settings.py index c45eb5537..0e27adfde 100644 --- a/tests/openwisp2/settings.py +++ b/tests/openwisp2/settings.py @@ -293,6 +293,7 @@ CONNECTION_CREDENTIALS_MODEL = "sample_connection.Credentials" CONNECTION_DEVICECONNECTION_MODEL = "sample_connection.DeviceConnection" CONNECTION_COMMAND_MODEL = "sample_connection.Command" + CONNECTION_BATCHCOMMAND_MODEL = "sample_connection.BatchCommand" SUBNET_DIVISION_SUBNETDIVISIONRULE_MODEL = ( "sample_subnet_division.SubnetDivisionRule" )