```python
import hashlib
import os
import pickle
from typing import Self

import numpy as np
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"


class ImageTensor(torch.Tensor):
    @classmethod
    def from_pixels(cls, pixels):
        instance = (
            cls._square_prune(pixels).to(device).as_subclass(cls)
            if isinstance(pixels, torch.Tensor)
            else cls._square_prune(
                torch.tensor(pixels, device=device, dtype=torch.float)
            ).as_subclass(cls)
        )
        instance.__init__()
        return instance

    def __init__(self) -> None:
        super().__init__()
        assert self.dim() in {2, 3}

        self.img_shape = torch.tensor(self.shape[:2])

        self.radius = torch.min(self.img_shape).divide(2)
        self.center_y, self.center_x = self.img_shape / 2.0

        self._circle_mask()

    def show(self) -> Self:
        numpy_array = torch.clamp(self, min=0, max=255).cpu().numpy().astype("uint8")
        Image.fromarray(
            numpy_array, mode="CMYK"
        ).show() if self.dim() == 3 else Image.fromarray(numpy_array).show()
        return self

    def _circle_mask(self) -> None:
        file = hashlib.sha256(f"circle_mask{self.shape}".encode()).hexdigest()
        if os.path.isfile(f"tmp/{file}.mask"):
            with open(f"tmp/{file}.mask", "rb") as f:
                mask = pickle.load(f)
        else:
            y_indices, x_indices = meshgrid(*self.img_shape)
            distances = torch.hypot(
                (y_indices - self.center_y).float(),
                (x_indices - self.center_x).float(),
            )
            mask = distances > self.radius
            with open(f"tmp/{file}.mask", "wb") as f:
                pickle.dump(mask, f)

        self[mask] = 255

    @staticmethod
    def _square_prune(tensor: torch.Tensor) -> torch.Tensor:
        assert tensor.dim() in {2, 3}

        center_y, center_x = torch.tensor(tensor.shape[:2]) / torch.tensor(2.0)
        radius = min(tensor.shape[:2]) / torch.tensor(2.0)

        if center_y == center_x:
            return tensor

        y_min = int(center_y - radius)
        y_max = int(center_y + radius)
        x_min = int(center_x - radius)
        x_max = int(center_x + radius)

        return (
            tensor[y_min:y_max, x_min:x_max]
            if tensor.dim() == 2
            else tensor[y_min:y_max, x_min:x_max, :]
        )

    def __add__(self, other) -> Self:
        return ImageTensor.from_pixels(torch.add(self, other))

    def __sub__(self, other) -> Self:
        return ImageTensor.from_pixels(torch.sub(self, other))

    def add_to_tensor(self, other: Self) -> torch.Tensor:
        return torch.add(self, other)

    def channel(self, idx):
        assert self.dim() == 3
        return self.from_pixels(self[:, :, idx])


def meshgrid(
    y_size: int | torch.Tensor, x_size: int | torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    y_indices, x_indices = torch.meshgrid(
        torch.arange(int(y_size), device=device),
        torch.arange(int(x_size), device=device),
        indexing="ij",
    )
    return y_indices, x_indices


def calculate_lines(
    parent_img: ImageTensor,
    current_point: torch.Tensor,
    points: torch.Tensor,
    decay_factor: float,
    brightness_factor: float,
) -> torch.Tensor:
    assert parent_img.dim() == 2
    assert current_point.dim() == 1
    assert points.dim() == 2

    y_indices, x_indices = meshgrid(*parent_img.img_shape)

    y_indices = y_indices.unsqueeze(0).tile((points.shape[0], 1, 1)).to(device)
    x_indices = x_indices.unsqueeze(0).tile((points.shape[0], 1, 1)).to(device)

    y0 = points[:, 0].unsqueeze(1).unsqueeze(2).expand_as(y_indices).to(device)
    x0 = points[:, 1].unsqueeze(1).unsqueeze(2).expand_as(x_indices).to(device)
    y1 = current_point[0]
    x1 = current_point[1]
    delta_y = y1 - y0
    delta_x = x1 - x0

    numerator = torch.abs(
        (delta_y * x_indices) - (delta_x * y_indices) + (x1 * y0) - (y1 * x0)
    )
    denominator = torch.hypot(delta_y.float(), delta_x.float())

    distance = numerator / denominator
    brightness = brightness_factor * brightness_decay(distance, decay_factor)
    brightness[(points == current_point).all(dim=1).nonzero(as_tuple=True)[0], :, :] = (
        torch.zeros_like(parent_img)
    )
    return brightness.permute((1, 2, 0))


def brightness_decay(distance: torch.Tensor, decay_factor: float):
    return torch.exp((-((distance / decay_factor) ** 2)))


def pickle_object(obj, filename):
    with open(filename, "wb") as f:
        pickle.dump(obj, f)

    return obj


def unpickle_object(filename):
    with open(filename, "rb") as f:
        return pickle.load(f)


def create_string_art(
    image_file,
    nail_count: int,
    brightness_factor: float,
    decay_factor: float,
    print_status: bool,
) -> ImageTensor:
    img, cyan_channel, magenta_channel, yellow_channel, key_channel = cmyk_split(
        image_file
    )

    cyan_string_art = ImageTensor.from_pixels(torch.zeros_like(cyan_channel))
    magenta_string_art = ImageTensor.from_pixels(torch.zeros_like(magenta_channel))
    yellow_string_art = ImageTensor.from_pixels(torch.zeros_like(yellow_channel))

    nail_angles = torch.linspace(0, 2 * torch.pi, nail_count)
    nail_locations = torch.stack(
        [
            torch.tensor(
                [
                    img.center_y + img.radius * torch.sin(angle),
                    img.center_x + img.radius * torch.cos(angle),
                ]
            )
            for angle in nail_angles
        ]
    )

    for channel, string_art in [
        (cyan_channel, cyan_string_art),
        (magenta_channel, magenta_string_art),
        (yellow_channel, yellow_string_art),
    ]:
        is_done = False
        nail_order = [nail_count // 2]
        channel_mse = np.inf

        while not is_done:
            current_nail = nail_order[-1]

            all_lines = calculate_lines(
                channel,
                nail_locations[current_nail],
                nail_locations,
                decay_factor,
                brightness_factor,
            )
            line_error = bulk_square_error(
                torch.add(
                    all_lines, torch.unsqueeze(string_art, 2).expand_as(all_lines)
                ),
                channel,
            )

            best_mse = torch.min(line_error)
            line_idx = int(line_error.argmin())
            if torch.min(line_error) < channel_mse:
                channel_mse = best_mse
                nail_order.append(line_idx)
                string_art += ImageTensor.from_pixels(all_lines[:, :, line_idx])

                print(f"Iteration: {len(nail_order) - 1}") if print_status else ...
            else:
                is_done = True

    return pickle_object(
        ImageTensor.from_pixels(
            torch.stack(
                [cyan_string_art, magenta_string_art, yellow_string_art, key_channel],
                dim=-1,
            )
        ),
        f"{image_file}.stringart",
    )


def bulk_square_error(inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    target = target.unsqueeze(2).expand_as(inputs)
    return torch.sum((inputs - target) ** 2, dim=(0, 1))


def cmyk_split(image_file: str):
    img = ImageTensor.from_pixels(np.array(Image.open(image_file).convert("CMYK")))

    cyan_channel = img.channel(0)
    magenta_channel = img.channel(1)
    yellow_channel = img.channel(2)
    key_channel = img.channel(3)

    return img, cyan_channel, magenta_channel, yellow_channel, key_channel


def main() -> None:
    create_string_art("img/unnamed.jpg", 600, 30, 0.9, True).show()


if __name__ == "__main__":
    main()

```

I am trying to make a program that approximates an image by wrapping "string" around "nails" that are around the edge of the circle.

The basic flow of my code is as follows:
- Load the image and split it into CMYK
- Generate a 3d tensor where each 2d slice is a possible line that can be drawn between the current "nail" and all other nails (including itself)
- Find the slice with the lowese summed square error when added to the current string art and compared with the real image
- Repeat until all possible lines you can draw have worse/the same summed square error as the current string art
- Repeat for each color channel

This works really well for very small images (100x100 to 250x250) and I can use up to like almost 2000 "nails" and have it run really quick. It does start to get really slow really fast, and I get a CUDA timeout error on around 800x800 and anything around 1000x1000 I just get a GPU memory error. The timeout error had some advice on solving it, but I don't know what that actually does/means and I would rather fix any issues with allocating more memory than necessary in my code first. (I am using a GTX 980 which is almost 10 years old and only has 4GB of VRAM, so I would assume on a more modern GPU with more VRAM, it would be able to handle larger images before running into the CUDA launch error and GPU memory error)

```stdout
RuntimeError: CUDA error: the launch timed out and was terminated
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
```

I am fairly certain the `calculate_lines()` and `bulk_square_error()` methods can be improved a lot.

`calculate_lines()` takes in a parent image of type `ImageTensor` which is just my subclass of `torch.Tensor`,  the current points of type `torch.Tensor` and of the form `[y, x]`,  and then the list of all points to draw to of type `torch.Tensor` of the form `[[y_0, x_0], [y_1, x_1], ...]` and returns a `torch.Tensor` of shape `[y, x, nail_index]`. I am pretty sure that the biggest bottleneck of this function is the creation of `y0, x0, y1, x1, delta_y, and delta_x`. I am sure there is a better way to be able to handle them than what I am doing, which is just expanding them to be the desired output shape. but I am sure that allocates a lot more memory than is necessary and that there is a better way to handle that, but I am very new to using pytorch and am not very familiar with all of its methods. I am using [this](https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Line_defined_by_two_points) formula to draw the lines.

`bulk_square_error()` takes in a 3d tensor of each of the lines added to the current string art and  the  2d real target image and returns a 1d tensor of the sumemd square error of each 2d slice. This function has the same problem as `calculate_lines()` where I am expanding the target image to be the same shape as the 3d tensor. I am sure there is a way to compare each slice to the same 2d tensor without having to expand it to take up more memory.

I am looking for any any ways to improve the speed/GPU memory usage of my program, and any criticisms/feedback regarding code style or my aproach to the problem.

Any help is appreciated. Thanks!

-----

Edits:

There is no good reason as to why I am using pytorch rather than numpy. Originally, I was calculating the mean squared error of using `torch.nn.functional.mse_loss()` method to compare each line+string art to the real image individually, but it was extremely slow to iterate over each possible line with a for loop. Not that I am not using that function any more, I could switch to numpy, but pytorch is fine.

Additionally, I feel like I should share my inspiration for this project. [This](https://www.youtube.com/watch?v=WGccIFf6MF8&ab_channel=VirtuallyPassed) youtube video was my inspiration. The creator recently made a follow-up to a video using the radon transform. Maybe I will attempt that implementation eventually, but for now I am content with this implementation.