# -------------- Modification 1 --------------
# Modify sam2_video_predictor.py using the following changes:
# 1. Insert "import numpy as np" after the "import torch.nn.functional as F"
#    line (line 11), so that the import section looks like this:
import warnings
from collections import OrderedDict

import torch
import torch.nn.functional as F
import numpy as np

# -------------- Modification 2 --------------
# Insert the following function after the existing init_state method,
# whose definition begins like this:
    @torch.inference_mode()
    def init_state(
        ....
        ....
    )

# Function starts here:
    @torch.inference_mode()
    def init_state_from_array(
        self,
        image_array,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,  # Included for compatibility, though not used here
    ):
        """Initialize an inference state with a NumPy array of images.

        Args:
            image_array (np.ndarray): A NumPy array of shape (num_frames, height, width, 3) containing RGB images with pixel values in [0, 255].
            offload_video_to_cpu (bool): Whether to offload video frames to CPU memory.
            offload_state_to_cpu (bool): Whether to offload the inference state to CPU memory.
            async_loading_frames (bool): Ignored in this version, included for compatibility.

        Returns:
            dict: The initialized inference state.
        """
        compute_device = self.device  # device of the model, e.g., cuda:0

        # Validate input
        if not isinstance(image_array, np.ndarray):
            raise ValueError("image_array must be a NumPy array")
        if image_array.ndim != 4 or image_array.shape[-1] != 3:
            raise ValueError("image_array must have shape (num_frames, height, width, 3)")

        num_frames, video_height, video_width = image_array.shape[:3]

        # Convert NumPy array to torch tensor and preprocess
        images = torch.from_numpy(image_array).float()  # Shape: (N, H, W, C)
        images = images.permute(0, 3, 1, 2)  # Shape: (N, C, H, W)

        # Move frames to the compute device so resizing and normalization run there
        images = images.to(compute_device, non_blocking=True)

        # Resize frames to the model's square input size (self.image_size)
        if video_height != self.image_size or video_width != self.image_size:
            images = F.interpolate(
                images,
                size=(self.image_size, self.image_size),
                mode="bilinear",
                align_corners=False,
            )

        # Normalize with the standard ImageNet mean/std (on the same device as the frames)
        img_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32, device=compute_device)[:, None, None]
        img_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32, device=compute_device)[:, None, None]
        images = images / 255.0  # Scale pixel values from [0, 255] to [0, 1]
        images -= img_mean
        images /= img_std

        # Optionally offload the preprocessed frames to CPU memory
        if offload_video_to_cpu:
            images = images.to(torch.device("cpu"), non_blocking=True)

        # Initialize inference state
        inference_state = {}
        inference_state["images"] = images
        inference_state["num_frames"] = num_frames
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        inference_state["device"] = compute_device
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = compute_device

        # Initialize the remaining state variables, mirroring the original init_state
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        inference_state["cached_features"] = {}
        inference_state["constants"] = {}
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        inference_state["output_dict_per_obj"] = {}
        inference_state["temp_output_dict_per_obj"] = {}
        inference_state["frames_tracked_per_obj"] = {}

        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state