diff --git a/.gitignore b/.gitignore
index a25e14a..9a79f19 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,8 @@ scripts/*.ps1
 scripts/*.sh
 **/dist
 **/build
-*.log
\ No newline at end of file
+*.log
+benchmark/
+modelTest/
+nc_workspace/
+debug_openai_history.txt
\ No newline at end of file
diff --git a/README.md b/README.md
index 4e4c883..b96bedd 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E
 | Model architectures   | Gemma <br/> Llama \* <br/> Mistral + <br/>Phi <br/> |                   |                |
 | Platform              | Linux <br/> Windows                                 |                   |                |
 | Architecture          | x86 <br/> x64 <br/>                                 | Arm64             |                |
-| Hardware Acceleration | CUDA<br/>DirectML<br/>IpexLLM                       | QNN <br/> ROCm    | OpenVINO       |
+| Hardware Acceleration | CUDA<br/>DirectML<br/>IpexLLM<br/>OpenVINO          | QNN <br/> ROCm    |                |
 
 \* The Llama model architecture supports similar model families such as CodeLlama, Vicuna, Yi, and more.
 
@@ -33,22 +33,12 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E
 - [Acknowledgements](#acknowledgements)
 
 ## Supported Models (Quick Start)
+  * Onnxruntime DirectML Models [Link](./docs/model/onnxruntime_directml_models.md)
+  * Onnxruntime CPU Models [Link](./docs/model/onnxruntime_cpu_models.md)
+  * IPEX-LLM Models [Link](./docs/model/ipex_models.md)
+  * OpenVINO-LLM Models [Link](./docs/model/openvino_models.md)
+  * NPU-LLM Models [Link](./docs/model/npu_models.md)
 
-| Models | Parameters | Context Length | Link |
-| --- | --- | --- | --- |
-| Gemma-2b-Instruct v1 | 2B | 8192 | [EmbeddedLLM/gemma-2b-it-onnx](https://huggingface.co/EmbeddedLLM/gemma-2b-it-onnx) |
-| Llama-2-7b-chat | 7B | 4096 | [EmbeddedLLM/llama-2-7b-chat-int4-onnx-directml](https://huggingface.co/EmbeddedLLM/llama-2-7b-chat-int4-onnx-directml) |
-| Llama-2-13b-chat | 13B | 4096 | [EmbeddedLLM/llama-2-13b-chat-int4-onnx-directml](https://huggingface.co/EmbeddedLLM/llama-2-13b-chat-int4-onnx-directml) |
-| Llama-3-8b-chat | 8B | 8192 | [EmbeddedLLM/mistral-7b-instruct-v0.3-onnx](https://huggingface.co/EmbeddedLLM/mistral-7b-instruct-v0.3-onnx) |
-| Mistral-7b-v0.3-instruct | 7B | 32768 | [EmbeddedLLM/mistral-7b-instruct-v0.3-onnx](https://huggingface.co/EmbeddedLLM/mistral-7b-instruct-v0.3-onnx) |
-| Phi-3-mini-4k-instruct-062024 | 3.8B | 4096 | [EmbeddedLLM/Phi-3-mini-4k-instruct-062024-onnx](https://huggingface.co/EmbeddedLLM/Phi-3-mini-4k-instruct-062024-onnx/tree/main/onnx/directml/Phi-3-mini-4k-instruct-062024-int4) |
-| Phi3-mini-4k-instruct | 3.8B | 4096 | [microsoft/Phi-3-mini-4k-instruct-onnx](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx) |
-| Phi3-mini-128k-instruct | 3.8B | 128k | [microsoft/Phi-3-mini-128k-instruct-onnx](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct-onnx) |
-| Phi3-medium-4k-instruct | 17B | 4096 | [microsoft/Phi-3-medium-4k-instruct-onnx-directml](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct-onnx-directml) |
-| Phi3-medium-128k-instruct | 17B | 128k | [microsoft/Phi-3-medium-128k-instruct-onnx-directml](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct-onnx-directml) |
-| Openchat-3.6-8b | 8B | 8192 | [EmbeddedLLM/openchat-3.6-8b-20240522-onnx](https://huggingface.co/EmbeddedLLM/openchat-3.6-8b-20240522-onnx) |
-| Yi-1.5-6b-chat | 6B | 32k | [EmbeddedLLM/01-ai_Yi-1.5-6B-Chat-onnx](https://huggingface.co/EmbeddedLLM/01-ai_Yi-1.5-6B-Chat-onnx) |
-| Phi-3-vision-128k-instruct |  | 128k | [EmbeddedLLM/Phi-3-vision-128k-instruct-onnx](https://huggingface.co/EmbeddedLLM/Phi-3-vision-128k-instruct-onnx/tree/main/onnx/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4) |
 
 ## Getting Started
 
@@ -70,12 +60,14 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E
      - **CUDA:** `$env:ELLM_TARGET_DEVICE='cuda'; pip install -e .[cuda]`
      - **IPEX:** `$env:ELLM_TARGET_DEVICE='ipex'; python setup.py develop`
      - **OpenVINO:** `$env:ELLM_TARGET_DEVICE='openvino'; pip install -e .[openvino]`
+     - **NPU:** `$env:ELLM_TARGET_DEVICE='npu'; pip install -e .[npu]`
      - **With Web UI**:
        - **DirectML:** `$env:ELLM_TARGET_DEVICE='directml'; pip install -e .[directml,webui]`
        - **CPU:** `$env:ELLM_TARGET_DEVICE='cpu'; pip install -e .[cpu,webui]`
        - **CUDA:** `$env:ELLM_TARGET_DEVICE='cuda'; pip install -e .[cuda,webui]`
        - **IPEX:** `$env:ELLM_TARGET_DEVICE='ipex'; python setup.py develop; pip install -r requirements-webui.txt`
        - **OpenVINO:** `$env:ELLM_TARGET_DEVICE='openvino'; pip install -e .[openvino,webui]`
+       - **NPU:** `$env:ELLM_TARGET_DEVICE='npu'; pip install -e .[npu,webui]`
 
 - **Linux**
 
@@ -91,12 +83,14 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E
      - **CUDA:** `ELLM_TARGET_DEVICE='cuda' pip install -e .[cuda]`
      - **IPEX:** `ELLM_TARGET_DEVICE='ipex' python setup.py develop`
      - **OpenVINO:** `ELLM_TARGET_DEVICE='openvino' pip install -e .[openvino]`
+     - **NPU:** `ELLM_TARGET_DEVICE='npu' pip install -e .[npu]`
      - **With Web UI**:
        - **DirectML:** `ELLM_TARGET_DEVICE='directml' pip install -e .[directml,webui]`
        - **CPU:** `ELLM_TARGET_DEVICE='cpu' pip install -e .[cpu,webui]`
        - **CUDA:** `ELLM_TARGET_DEVICE='cuda' pip install -e .[cuda,webui]`
        - **IPEX:** `ELLM_TARGET_DEVICE='ipex' python setup.py develop; pip install -r requirements-webui.txt`
        - **OpenVINO:** `ELLM_TARGET_DEVICE='openvino' pip install -e .[openvino,webui]`
+       - **NPU:** `ELLM_TARGET_DEVICE='npu' pip install -e .[npu,webui]`
 
 ### Launch OpenAI API Compatible Server
 
@@ -121,7 +115,7 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E
 
 ### Launch Chatbot Web UI
 
-1.  `ellm_chatbot --port 7788 --host localhost --server_port <ellm_server port> --server_host localhost`. **Note:** To find out more of the supported arguments. `ellm_chatbot --help`.
+1.  `ellm_chatbot --port 7788 --host localhost --server_port <ellm_server port> --server_host localhost --model_name <served model name>`. **Note:** Run `ellm_chatbot --help` to see the full list of supported arguments.
 
 
 
@@ -156,6 +150,9 @@ It is an interface that allows you to download and deploy OpenAI API compatible
 
    # OpenVINO
    ellm_server --model_path '.\meta-llama_Meta-Llama-3.1-8B-Instruct\'  --backend 'openvino' --device 'gpu' --port 5555 --served_model_name 'meta-llama_Meta/Llama-3.1-8B-Instruct'
+
+   # NPU
+   ellm_server --model_path 'microsoft/Phi-3-mini-4k-instruct'  --backend 'npu' --device 'npu' --port 5555 --served_model_name 'microsoft/Phi-3-mini-4k-instruct'
    ```
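+
+   Once the server is up, you can sanity-check it from any shell with `curl` (a minimal sketch, assuming the standard OpenAI-compatible `/v1/chat/completions` route and the `--served_model_name` used above):
+
+   ```shell
+   curl http://localhost:5555/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "microsoft/Phi-3-mini-4k-instruct", "messages": [{"role": "user", "content": "Hello"}]}'
+   ```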
 
 ## Prebuilt OpenAI API Compatible Windows Executable (Alpha)
@@ -168,13 +165,16 @@ _Powershell/Terminal Usage (Use it like `ellm_server`)_:
 .\ellm_api_server.exe --model_path <path/to/model/weight>
 
 # DirectML
-.\ellm_api_server.exe --model_path 'EmbeddedLLM_Phi-3-mini-4k-instruct-062024-onnx\onnx\directml\Phi-3-mini-4k-instruct-062024-int4' --port 5555
+.\ellm_api_server.exe --model_path 'EmbeddedLLM/Phi-3-mini-4k-instruct-onnx-directml' --port 5555
 
 # IPEX-LLM
 .\ellm_api_server.exe --model_path '.\meta-llama_Meta-Llama-3.1-8B-Instruct\'  --backend 'ipex' --device 'xpu' --port 5555 --served_model_name 'meta-llama_Meta/Llama-3.1-8B-Instruct'
 
 # OpenVINO
 .\ellm_api_server.exe --model_path '.\meta-llama_Meta-Llama-3.1-8B-Instruct\'  --backend 'openvino' --device 'gpu' --port 5555 --served_model_name 'meta-llama_Meta/Llama-3.1-8B-Instruct'
+
+# NPU
+.\ellm_api_server.exe --model_path 'microsoft/Phi-3-mini-4k-instruct'  --backend 'npu' --device 'npu' --port 5555 --served_model_name 'microsoft/Phi-3-mini-4k-instruct'
 ```
 
 ## Acknowledgements
diff --git a/docs/model/npu_models.md b/docs/model/npu_models.md
new file mode 100644
index 0000000..c1d2b06
--- /dev/null
+++ b/docs/model/npu_models.md
@@ -0,0 +1,15 @@
+# Model Powered by NPU-LLM
+
+## Verified Models
+Verified models can be found in the EmbeddedLLM NPU-LLM model collection.
+* EmbeddedLLM NPU-LLM model collection: [link](https://huggingface.co/collections/EmbeddedLLM/npu-llm-66d692817e6c9509bb8ead58)
+
+| Model | Model Link |
+| --- | --- |
+| Phi-3-mini-4k-instruct | [link](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) |
+| Phi-3-mini-128k-instruct | [link](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) |
+| Phi-3-medium-4k-instruct | [link](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct) |
+| Phi-3-medium-128k-instruct | [link](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct) |
+
+## Contribution
+We welcome contributions to the verified model list.
\ No newline at end of file
diff --git a/requirements-npu.txt b/requirements-npu.txt
new file mode 100644
index 0000000..dbcb8cf
--- /dev/null
+++ b/requirements-npu.txt
@@ -0,0 +1,3 @@
+intel-npu-acceleration-library
+torch>=2.4
+transformers>=4.42
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 4520ee6..50ce2f9 100644
--- a/setup.py
+++ b/setup.py
@@ -54,6 +54,10 @@ def _is_openvino() -> bool:
     return ELLM_TARGET_DEVICE == "openvino"
 
 
+def _is_npu() -> bool:
+    return ELLM_TARGET_DEVICE == "npu"
+
+
 class ELLMInstallCommand(install):
     def run(self):
         install.run(self)
@@ -198,6 +202,8 @@ def get_requirements() -> List[str]:
         requirements = _read_requirements("requirements-ipex.txt")
     elif _is_openvino():
         requirements = _read_requirements("requirements-openvino.txt")
+    elif _is_npu():
+        requirements = _read_requirements("requirements-npu.txt")
     else:
         raise ValueError("Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.")
     return requirements
@@ -216,6 +222,8 @@ def get_ellm_version() -> str:
         version += "+ipex"
     elif _is_openvino():
         version += "+openvino"
+    elif _is_npu():
+        version += "+npu"
     else:
         raise RuntimeError("Unknown runtime environment")
 
@@ -268,6 +276,7 @@ def get_ellm_version() -> str:
         "cuda": ["onnxruntime-genai-cuda==0.3.0rc2"],
         "ipex": [],
         "openvino": [],
+        "npu": [],
     },
     dependency_links=dependency_links,
     entry_points={
diff --git a/src/embeddedllm/backend/intel_npu_engine.py b/src/embeddedllm/backend/intel_npu_engine.py
new file mode 100644
index 0000000..c245e43
--- /dev/null
+++ b/src/embeddedllm/backend/intel_npu_engine.py
@@ -0,0 +1,268 @@
+import contextlib
+import time
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import AsyncIterator, List, Optional
+
+from loguru import logger
+from PIL import Image
+from transformers import (
+    AutoConfig,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+    TextIteratorStreamer,
+)
+
+from threading import Thread
+
+import intel_npu_acceleration_library as npu_lib
+
+from embeddedllm.inputs import PromptInputs
+from embeddedllm.protocol import CompletionOutput, RequestOutput
+from embeddedllm.sampling_params import SamplingParams
+from embeddedllm.backend.base_engine import BaseLLMEngine, _get_and_verify_max_len
+
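+# When True, per-request timing stats (prompt length, time to first token, decode tokens/s) are logged.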
+RECORD_TIMING = True
+
+
+class NPUEngine(BaseLLMEngine):
+    def __init__(self, model_path: str, vision: bool, device: str = "npu"):
+        self.model_path = model_path
+        self.model_config: AutoConfig = AutoConfig.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        self.device = device
+
+        # model_config is to find out the max length of the model
+        self.max_model_len = _get_and_verify_max_len(
+            hf_config=self.model_config,
+            max_model_len=None,
+            disable_sliding_window=False,
+            sliding_window_len=self.get_hf_config_sliding_window(),
+        )
+
+        logger.info("Model Context Length: " + str(self.max_model_len))
+
+        try:
+            logger.info("Attempt to load fast tokenizer")
+            self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.model_path)
+        except Exception:
+            logger.info("Attempt to load slower tokenizer")
+            self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path)
+
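+        # Load the model through intel-npu-acceleration-library, quantizing the weights to int4 for the NPU.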
+        self.model = npu_lib.NPUModelForCausalLM.from_pretrained(
+            self.model_path,
+            torch_dtype="auto",
+            dtype=npu_lib.int4,
+            trust_remote_code=True,
+            export=False,
+        )
+
+        logger.info("Model loaded")
+        self.tokenizer_stream = TextIteratorStreamer(
+            self.tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
+        logger.info("Tokenizer created")
+
+        self.vision = vision
+
+        # if self.vision:
+        #     self.onnx_processor = self.model.create_multimodal_processor()
+        #     self.processor = AutoImageProcessor.from_pretrained(
+        #         self.model_path, trust_remote_code=True
+        #     )
+        #     print(dir(self.processor))
+
+    async def generate_vision(
+        self,
+        inputs: PromptInputs,
+        sampling_params: SamplingParams,
+        request_id: str,
+        stream: bool = True,
+    ) -> AsyncIterator[RequestOutput]:
+        raise NotImplementedError("generate_vision is not yet implemented for the NPU backend.")
+
+    async def generate(
+        self,
+        inputs: PromptInputs,
+        sampling_params: SamplingParams,
+        request_id: str,
+        stream: bool = True,
+    ) -> AsyncIterator[RequestOutput]:
+        """Generate outputs for a request.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        """
+
+        prompt_text = inputs["prompt"]
+        input_token_length = None
+        input_tokens = None  # for text only use case
+        # logger.debug("inputs: " + prompt_text)
+
+        input_tokens = self.tokenizer.encode(prompt_text, return_tensors="pt")
+        # logger.debug(f"input_tokens: {input_tokens}")
+        input_token_length = len(input_tokens[0])
+
+        max_tokens = sampling_params.max_tokens
+
+        assert input_token_length is not None
+
+        if input_token_length + max_tokens > self.max_model_len:
+            raise ValueError("Exceed Context Length")
+
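+        # Collect the sampling parameters that the HF generate() API understands into keyword arguments.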
+        generation_options = {
+            name: getattr(sampling_params, name)
+            for name in [
+                "do_sample",
+                # "max_length",
+                "max_new_tokens",
+                "min_length",
+                "top_p",
+                "top_k",
+                "temperature",
+                "repetition_penalty",
+            ]
+            if hasattr(sampling_params, name)
+        }
+        generation_options["max_length"] = self.max_model_len
+        generation_options["input_ids"] = input_tokens.clone()
+        # generation_options["input_ids"] = input_tokens.clone().to(self.device)
+        generation_options["max_new_tokens"] = max_tokens
+        logger.debug(f"generation_options: {generation_options}")
+
+        token_list: List[int] = []
+        output_text: str = ""
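+        # Streaming path: run generate() in a background thread and yield partial RequestOutputs as text arrives.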
+        if stream:
+            generation_options["streamer"] = self.tokenizer_stream
+            if RECORD_TIMING:
+                started_timestamp = time.time()
+                first_token_timestamp = 0
+                first = True
+                new_tokens = []
+            try:
+                thread = Thread(target=self.model.generate, kwargs=generation_options)
+                started_timestamp = time.time()
+                first_token_timestamp = None
+                thread.start()
+                output_text = ""
+                first = True
+                for new_text in self.tokenizer_stream:
+                    if new_text == "":
+                        continue
+                    if RECORD_TIMING:
+                        if first:
+                            first_token_timestamp = time.time()
+                            first = False
+                    # logger.debug(f"new text: {new_text}")
+                    output_text += new_text
+                    token_list = self.tokenizer.encode(output_text, return_tensors="pt")
+
+                    output = RequestOutput(
+                        request_id=request_id,
+                        prompt=prompt_text,
+                        prompt_token_ids=input_tokens[0],
+                        finished=False,
+                        outputs=[
+                            CompletionOutput(
+                                index=0,
+                                text=output_text,
+                                token_ids=token_list[0],
+                                cumulative_logprob=-1.0,
+                            )
+                        ],
+                    )
+                    yield output
+                    # logits = generator.get_output("logits")
+                    # print(logits)
+                    if RECORD_TIMING:
+                        new_tokens = token_list[0]
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=prompt_text,
+                    prompt_token_ids=input_tokens[0],
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list[0],
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
+                )
+                if RECORD_TIMING:
+                    prompt_time = first_token_timestamp - started_timestamp
+                    run_time = time.time() - first_token_timestamp
+                    logger.info(
+                        f"Prompt length: {len(input_tokens[0])}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens[0])/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps"
+                    )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=prompt_text,
+                    prompt_token_ids=input_tokens[0],
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output
+        else:
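+            # Non-streaming path: run generate() synchronously and yield a single final RequestOutput.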
+            try:
+                token_list = self.model.generate(**generation_options)[0]
+
+                output_text = self.tokenizer.decode(
+                    token_list[input_token_length:], skip_special_tokens=True
+                )
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=prompt_text,
+                    prompt_token_ids=input_tokens[0],
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
+                )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=prompt_text,
+                    prompt_token_ids=input_tokens[0],
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output
\ No newline at end of file
diff --git a/src/embeddedllm/engine.py b/src/embeddedllm/engine.py
index e2c5a9d..b341472 100644
--- a/src/embeddedllm/engine.py
+++ b/src/embeddedllm/engine.py
@@ -56,6 +56,22 @@ def __init__(self, model_path: str, vision: bool, device: str = "xpu", backend:
 
             self.engine = OnnxruntimeEngine(self.model_path, self.vision, self.device)
             logger.info(f"Initializing onnxruntime backend ({backend.upper()}): OnnxruntimeEngine")
+
+        elif self.backend == "npu":
+            assert self.device == "npu", "To run `npu` backend, `device` must be `npu`."
+            processor = get_processor_type()
+            if processor == "Intel":
+                from embeddedllm.backend.intel_npu_engine import NPUEngine
+
+                self.engine = NPUEngine(self.model_path, self.vision, self.device)
+                logger.info("Initializing Intel NPU backend: NPUEngine")
+
+            elif processor == "AMD":
+                raise SystemError("NPU support on AMD platforms is not available yet.")
+
+            else:
+                raise SystemError("Unsupported processor type for the `npu` backend.")
+
         elif self.backend == "cpu":
             assert self.device == "cpu", f"To run `cpu` backend, `device` must be `cpu`."
             processor = get_processor_type()
@@ -80,7 +96,7 @@ def __init__(self, model_path: str, vision: bool, device: str = "xpu", backend:
 
         else:
             raise ValueError(
-                f"EmbeddedLLMEngine only supports `cpu`, `ipex`, `cuda`, `openvino` and `directml`."
+                f"EmbeddedLLMEngine only supports `cpu`, `npu`, `ipex`, `cuda`, `openvino` and `directml`."
             )
         self.tokenizer = self.engine.tokenizer
 
diff --git a/src/embeddedllm/entrypoints/modelui.py b/src/embeddedllm/entrypoints/modelui.py
index 9c82355..81cb681 100644
--- a/src/embeddedllm/entrypoints/modelui.py
+++ b/src/embeddedllm/entrypoints/modelui.py
@@ -20,7 +20,7 @@ def get_embeddedllm_backend():
         version = importlib.metadata.version("embeddedllm")
 
         # Use regex to extract the backend
-        match = re.search(r"\+(directml|cpu|cuda|ipex|openvino)$", version)
+        match = re.search(r"\+(directml|npu|cpu|cuda|ipex|openvino)$", version)
 
         if match:
             backend = match.group(1)
@@ -260,6 +260,41 @@ class ModelCard(BaseModel):
     ),
 }
 
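+# Models verified to run on the Intel NPU backend (see docs/model/npu_models.md).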
+npu_model_dict_list = {
+    "microsoft/Phi-3-mini-4k-instruct": ModelCard(
+        hf_url="https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/tree/main/",
+        repo_id="microsoft/Phi-3-mini-4k-instruct",
+        model_name="Phi-3-mini-4k-instruct",
+        subfolder=".",
+        repo_type="model",
+        context_length=4096,
+    ),
+    "microsoft/Phi-3-mini-128k-instruct": ModelCard(
+        hf_url="https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/tree/main",
+        repo_id="microsoft/Phi-3-mini-128k-instruct",
+        model_name="Phi-3-mini-128k-instruct",
+        subfolder=".",
+        repo_type="model",
+        context_length=131072,
+    ),
+    "microsoft/Phi-3-medium-4k-instruct": ModelCard(
+        hf_url="https://huggingface.co/microsoft/Phi-3-medium-4k-instruct/tree/main",
+        repo_id="microsoft/Phi-3-medium-4k-instruct",
+        model_name="Phi-3-medium-4k-instruct",
+        subfolder=".",
+        repo_type="model",
+        context_length=4096,
+    ),
+    "microsoft/Phi-3-medium-128k-instruct": ModelCard(
+        hf_url="https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/tree/main",
+        repo_id="microsoft/Phi-3-medium-128k-instruct",
+        model_name="Phi-3-medium-128k-instruct",
+        subfolder=".",
+        repo_type="model",
+        context_length=131072,
+    ),
+}
+
 ipex_model_dict_list = {
     "microsoft/Phi-3-mini-4k-instruct": ModelCard(
         hf_url="https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/tree/main/",
@@ -507,6 +542,11 @@ def compute_memory_size(repo_id, path_in_repo, repo_type: str = "model"):
         repo_id=v.repo_id, path_in_repo=v.subfolder, repo_type=v.repo_type
     )
 
+for k, v in npu_model_dict_list.items():
+    v.size = compute_memory_size(
+        repo_id=v.repo_id, path_in_repo=v.subfolder, repo_type=v.repo_type
+    )
+
 for k, v in ipex_model_dict_list.items():
     v.size = compute_memory_size(
         repo_id=v.repo_id, path_in_repo=v.subfolder, repo_type=v.repo_type
@@ -603,6 +643,9 @@ def update_model_list(engine_type):
     if engine_type == "DirectML":
         models = sorted(list(dml_model_dict_list.keys()))
         models_pandas = convert_to_dataframe(dml_model_dict_list)
+    elif backend == "npu":
+        models = sorted(list(npu_model_dict_list.keys()))
+        models_pandas = convert_to_dataframe(npu_model_dict_list)
     elif backend == "ipex":
         models = sorted(list(ipex_model_dict_list.keys()))
         models_pandas = convert_to_dataframe(ipex_model_dict_list)
@@ -631,6 +674,8 @@ def deploy_model(engine_type, model_name, port_number):
 
     if engine_type == "DirectML":
         llm_model_card = dml_model_dict_list[model_name]
+    elif backend == "npu":
+        llm_model_card = npu_model_dict_list[model_name]
     elif backend == "ipex":
         llm_model_card = ipex_model_dict_list[model_name]
     elif backend == "openvino":
@@ -654,7 +699,9 @@ def deploy_model(engine_type, model_name, port_number):
     model_path = llm_model_card.repo_id
     print("Model path:", model_path)
 
-    if engine_type == "Ipex":
+    if engine_type == "NPU":
+        device = "npu"
+    elif engine_type == "Ipex":
         device = "xpu"
     elif engine_type == "OpenVino":
         device = "gpu"
@@ -718,6 +765,8 @@ def download_model(engine_type, model_name):
 
     if engine_type == "DirectML":
         llm_model_card = dml_model_dict_list[model_name]
+    elif backend == "npu":
+        llm_model_card = npu_model_dict_list[model_name]
     elif backend == "ipex":
         llm_model_card = ipex_model_dict_list[model_name]
     elif backend == "openvino":
@@ -771,6 +820,8 @@ def main():
 
         if backend == "directml":
             default_value = "DirectML"
+        elif backend == "npu":
+            default_value = "NPU"
         elif backend == "ipex":
             default_value = "Ipex"
         elif backend == "openvino":