9 changes: 9 additions & 0 deletions RELEASENOTES.md
@@ -1,6 +1,15 @@
## TorchSharp Release Notes

Releases, starting with 9/2/2021, are listed with the most recent release at the top.
# NuGet Version 0.106.0 (Upcoming)

This release upgrades the libtorch backend to v2.9.0.

__API Changes__:

#1498 Add support for torch.export ExportedProgram models (.pt2 files)<br/>
TorchSharp now supports loading and executing PyTorch models exported via torch.export and compiled with AOTInductor. Use `torch.export.load()` to load `.pt2` model packages produced by `torch._inductor.aoti_compile_and_package()` in Python. In many cases this yields 30-40% lower inference latency than TorchScript models. Note: this is an inference-only API with no training support.<br/>
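
A minimal usage sketch from C# (hypothetical `model.pt2` path and input shape):

```csharp
using TorchSharp;

// Load a .pt2 package produced by torch._inductor.aoti_compile_and_package() in Python.
using var model = torch.export.load("model.pt2");

// Run inference; inputs must match the shapes and device the model was compiled for.
var input = torch.rand(1, 3, 224, 224);
var outputs = model.run(input);   // returns torch.Tensor[]
```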

# NuGet Version 0.105.2

This release upgrades the libtorch backend to v2.7.1, using CUDA 12.8.
2 changes: 1 addition & 1 deletion build/Dependencies.props
@@ -7,7 +7,7 @@

<!-- Other/Non-Core Product Dependencies -->
<PropertyGroup>
<LibTorchVersion>2.7.1</LibTorchVersion>
<LibTorchVersion>2.9.0</LibTorchVersion>
<LibTorchVersion Condition="'$(TargetArchitecture)' == 'x64' and '$(TargetOS)' == 'mac'">2.2.2</LibTorchVersion>
<CudaVersionDot>12.8</CudaVersionDot>
<CudaVersionNoDot>128</CudaVersionNoDot>
2 changes: 2 additions & 0 deletions src/Native/LibTorchSharp/CMakeLists.txt
@@ -11,6 +11,7 @@ set(SOURCES
crc32c.h
THSAutograd.h
THSData.h
THSExport.h
THSJIT.h
THSNN.h
THSStorage.h
@@ -23,6 +24,7 @@
THSActivation.cpp
THSAutograd.cpp
THSData.cpp
THSExport.cpp
THSFFT.cpp
THSJIT.cpp
THSLinearAlgebra.cpp
51 changes: 51 additions & 0 deletions src/Native/LibTorchSharp/THSExport.cpp
@@ -0,0 +1,51 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#include "THSExport.h"

// torch.export support via AOTInductor
// This uses torch::inductor::AOTIModelPackageLoader which is INFERENCE-ONLY
// Models must be compiled with torch._inductor.aoti_compile_and_package() in Python

ExportedProgramModule THSExport_load(const char* filename)
{
CATCH(
// Load .pt2 file using AOTIModelPackageLoader
// This requires models to be compiled with aoti_compile_and_package()
auto* loader = new torch::inductor::AOTIModelPackageLoader(filename);
return loader;
);

return nullptr;
}

void THSExport_Module_dispose(const ExportedProgramModule module)
{
delete module;
}

void THSExport_Module_run(
const ExportedProgramModule module,
const Tensor* input_tensors,
const int input_length,
Tensor** result_tensors,
int* result_length)
{
    // Make the out parameters well-defined even if the native call below throws.
    *result_tensors = nullptr;
    *result_length = 0;

    CATCH(
// Convert input tensor pointers to std::vector<torch::Tensor>
std::vector<torch::Tensor> inputs;
inputs.reserve(input_length);
for (int i = 0; i < input_length; i++) {
inputs.push_back(*input_tensors[i]);
}

// Run inference
std::vector<torch::Tensor> outputs = module->run(inputs);

        // Allocate the output handle array. The array itself must be released by the caller
        // via THSExport_Module_free_results; ownership of each tensor handle passes to the caller.
        *result_length = (int)outputs.size();
        *result_tensors = new Tensor[outputs.size()];

for (size_t i = 0; i < outputs.size(); i++) {
(*result_tensors)[i] = new torch::Tensor(outputs[i]);
}
);
}

// Release a result array previously returned by THSExport_Module_run.
// Only the array itself is freed; the tensor handles it held are owned by the caller.
void THSExport_Module_free_results(Tensor* result_tensors)
{
    delete[] result_tensors;
}
32 changes: 32 additions & 0 deletions src/Native/LibTorchSharp/THSExport.h
@@ -0,0 +1,32 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#pragma once

#include "../Stdafx.h"

#include "torch/torch.h"
#include "torch/csrc/inductor/aoti_package/model_package_loader.h"

#include "Utils.h"

// torch.export support via AOTInductor - Load and execute PyTorch ExportedProgram models (.pt2 files)
// ExportedProgram is PyTorch 2.x's recommended way to export models for production deployment
//
// IMPORTANT: This implementation uses torch::inductor::AOTIModelPackageLoader which is
// INFERENCE-ONLY. Training, parameter updates, and device movement are not supported.
// Models must be compiled with torch._inductor.aoti_compile_and_package() in Python.
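
// Typical call sequence from the managed layer (a sketch with hypothetical variable names;
// error handling omitted):
//
//   ExportedProgramModule module = THSExport_load("model.pt2");
//   Tensor* results = nullptr;
//   int count = 0;
//   THSExport_Module_run(module, input_handles, input_count, &results, &count);
//   /* wrap each results[i] in a managed Tensor */
//   THSExport_Module_free_results(results);
//   THSExport_Module_dispose(module);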

// Load an AOTInductor-compiled model package from a .pt2 file
EXPORT_API(ExportedProgramModule) THSExport_load(const char* filename);

// Dispose of an ExportedProgram module
EXPORT_API(void) THSExport_Module_dispose(const ExportedProgramModule module);

// Execute the ExportedProgram's forward method (inference only)
// Input: array of input tensor handles
// Output: a newly allocated array of result tensor handles; the caller takes ownership of each
// handle and must release the array itself with THSExport_Module_free_results
EXPORT_API(void) THSExport_Module_run(
    const ExportedProgramModule module,
    const Tensor* input_tensors,
    const int input_length,
    Tensor** result_tensors,
    int* result_length);

// Release a result array previously returned by THSExport_Module_run.
EXPORT_API(void) THSExport_Module_free_results(Tensor* result_tensors);
4 changes: 4 additions & 0 deletions src/Native/LibTorchSharp/THSJIT.h
@@ -98,3 +98,7 @@ EXPORT_API(TensorOrScalar*) THSJIT_AllocateTensorOrScalarArray(int32_t size);
EXPORT_API(void) THSJIT_FreeTensorOrScalarArray(TensorOrScalar* ptr);
EXPORT_API(void) THSJIT_SetTensorOrScalar(TensorOrScalar* array, int32_t index, int64_t type_code, int64_t array_index, ptrdiff_t handle);
EXPORT_API(TensorOrScalar*) THSJIT_GetTensorOrScalar(TensorOrScalar* array, int32_t index);

// Helper functions (shared with THSExport)
std::vector<c10::IValue> toIValue(const TensorOrScalar* tensorPtrs, const int length);
TensorOrScalar* ReturnHelper(c10::IValue result, TensorOrScalar* (*allocator)(int32_t idx, size_t length), int8_t* typeCode, int32_t* idx);
5 changes: 5 additions & 0 deletions src/Native/LibTorchSharp/Utils.h
@@ -4,6 +4,7 @@
#include <string>

#include "torch/torch.h"
#include "torch/csrc/inductor/aoti_package/model_package_loader.h"

extern thread_local char *torch_last_err;

@@ -24,6 +25,10 @@ typedef std::shared_ptr<torch::jit::Function> * JITFunction;
typedef std::shared_ptr<c10::Type> * JITType;
typedef std::shared_ptr<c10::TensorType>* JITTensorType;

// torch.export ExportedProgram module via AOTInductor
// Note: Uses torch::inductor::AOTIModelPackageLoader for inference-only execution
typedef torch::inductor::AOTIModelPackageLoader* ExportedProgramModule;

struct TensorArray {
Tensor *array;
int64_t size;
1 change: 1 addition & 0 deletions src/Redist/libtorch-cpu/libtorch-macos-arm64-2.9.0.zip.sha
@@ -0,0 +1 @@
6D6AF87CAB301FA25CB4909697A03C65ED234E784CD96C8743A9AD6586238D0E
@@ -0,0 +1 @@
22DE42ABDE933BE46CE843467930BD0190B72271BFA2C11F84DB95591A9834F1
@@ -0,0 +1 @@
C826069DA829550BD3F1205159F8A95EE906A447DD141D08F42C568D4EE9E05E
@@ -0,0 +1 @@
0892B92717B2396FE7ED62BE9AA6B78074C48BBB34D239F96FCCC70BE4560098
215 changes: 215 additions & 0 deletions src/TorchSharp/Export/ExportedProgram.cs
@@ -0,0 +1,215 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.

using System;
using System.Runtime.InteropServices;
using TorchSharp.PInvoke;
using static TorchSharp.PInvoke.NativeMethods;

namespace TorchSharp
{
public static partial class torch
{
public static partial class export
{
/// <summary>
/// Load a PyTorch ExportedProgram from a .pt2 file compiled with AOTInductor.
/// </summary>
/// <param name="filename">Path to the .pt2 file</param>
/// <returns>ExportedProgram model for inference</returns>
/// <remarks>
/// IMPORTANT: The .pt2 file must be compiled with torch._inductor.aoti_compile_and_package() in Python.
/// Models saved with torch.export.save() alone will NOT work - they require AOTInductor compilation.
///
/// This implementation is INFERENCE-ONLY. Training, parameter updates, and device movement
/// are not supported. The model is compiled for a specific device (CPU/CUDA) at compile time.
///
/// Example Python code to create compatible .pt2 files:
/// <code>
/// import torch
/// import torch._inductor
///
/// # Export the model
/// exported = torch.export.export(model, example_inputs)
///
/// # Compile with AOTInductor (required for C++ loading)
/// torch._inductor.aoti_compile_and_package(
/// exported,
/// package_path="model.pt2"
/// )
/// </code>
/// </remarks>
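            /// <example>
            /// A minimal C# usage sketch (hypothetical model path and input shape):
            /// <code>
            /// using var model = torch.export.load("model.pt2");
            /// var input = torch.rand(1, 3, 224, 224);
            /// var outputs = model.run(input);   // torch.Tensor[]
            /// </code>
            /// </example>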
public static ExportedProgram load(string filename)
{
return new ExportedProgram(filename);
}

/// <summary>
/// Load a PyTorch ExportedProgram with typed output.
/// </summary>
public static ExportedProgram<TResult> load<TResult>(string filename)
{
return new ExportedProgram<TResult>(filename);
}
}
}

/// <summary>
/// Represents a PyTorch ExportedProgram loaded from an AOTInductor-compiled .pt2 file.
/// This is an INFERENCE-ONLY implementation - training and parameter updates are not supported.
/// </summary>
/// <remarks>
/// Unlike TorchScript models, ExportedProgram models are ahead-of-time (AOT) compiled for
    /// a specific device and are optimized for inference performance. In many cases they provide
    /// 30-40% lower latency than TorchScript.
///
/// Key limitations:
/// - Inference only (no training, no gradients)
/// - No parameter access or updates
/// - No device movement (compiled for specific device)
/// - No dynamic model structure changes
///
/// Use torch.jit for models that require training or dynamic behavior.
/// </remarks>
public class ExportedProgram : IDisposable
{
private IntPtr handle;
private bool _disposed = false;

internal ExportedProgram(string filename)
{
handle = THSExport_load(filename);
if (handle == IntPtr.Zero)
torch.CheckForErrors();
}

/// <summary>
/// Run inference on the model with the given input tensors.
/// </summary>
/// <param name="inputs">Input tensors for the model</param>
/// <returns>Array of output tensors</returns>
/// <remarks>
/// The number and shapes of inputs must match what the model was exported with.
/// All inputs must be on the same device that the model was compiled for.
/// </remarks>
public torch.Tensor[] run(params torch.Tensor[] inputs)
{
if (_disposed)
throw new ObjectDisposedException(nameof(ExportedProgram));

// Convert managed tensors to IntPtr array
IntPtr[] input_handles = new IntPtr[inputs.Length];
for (int i = 0; i < inputs.Length; i++)
{
input_handles[i] = inputs[i].Handle;
}

            // Call the native run method.
            THSExport_Module_run(handle, input_handles, inputs.Length, out IntPtr result_ptr, out int result_length);
            torch.CheckForErrors();

            // Keep the managed input tensors reachable until the native call has returned,
            // so the GC cannot finalize them (and free their native handles) mid-call.
            GC.KeepAlive(inputs);

// Marshal result array
torch.Tensor[] results = new torch.Tensor[result_length];
IntPtr[] result_handles = new IntPtr[result_length];
Marshal.Copy(result_ptr, result_handles, 0, result_length);

for (int i = 0; i < result_length; i++)
{
results[i] = new torch.Tensor(result_handles[i]);
}

            // Release the native handle array with the matching native deallocator
            // (it was allocated with new[] in THSExport_Module_run, so Marshal.FreeHGlobal
            // is not a valid way to free it). The tensors themselves are now owned by the
            // managed Tensor objects.
            THSExport_Module_free_results(result_ptr);

return results;
}

/// <summary>
/// Synonym for run() - executes forward pass.
/// </summary>
public torch.Tensor[] forward(params torch.Tensor[] inputs) => run(inputs);

/// <summary>
/// Synonym for run() - executes the model.
/// </summary>
public torch.Tensor[] call(params torch.Tensor[] inputs) => run(inputs);

public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}

protected virtual void Dispose(bool disposing)
{
if (!_disposed)
{
if (handle != IntPtr.Zero)
{
THSExport_Module_dispose(handle);
handle = IntPtr.Zero;
}
_disposed = true;
}
}

~ExportedProgram()
{
Dispose(false);
}
}

/// <summary>
/// Generic version of ExportedProgram with typed output.
/// </summary>
/// <typeparam name="TResult">The return type (Tensor, Tensor[], or tuple of Tensors)</typeparam>
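    /// <example>
    /// A typed-usage sketch (assuming the exported model returns two tensors):
    /// <code>
    /// using var model = torch.export.load&lt;(torch.Tensor, torch.Tensor)&gt;("model.pt2");
    /// var input = torch.rand(8, 32);
    /// var (logits, hidden) = model.run(input);
    /// </code>
    /// </example>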
public class ExportedProgram<TResult> : ExportedProgram
{
internal ExportedProgram(string filename) : base(filename)
{
}

/// <summary>
/// Run inference with typed return value.
/// </summary>
public new TResult run(params torch.Tensor[] inputs)
{
var results = base.run(inputs);

// Handle different return types
if (typeof(TResult) == typeof(torch.Tensor))
{
if (results.Length != 1)
throw new InvalidOperationException($"Expected 1 output tensor, got {results.Length}");
return (TResult)(object)results[0];
}

if (typeof(TResult) == typeof(torch.Tensor[]))
{
return (TResult)(object)results;
}

// Handle tuple types
if (typeof(TResult).IsGenericType)
{
var genericType = typeof(TResult).GetGenericTypeDefinition();
if (genericType == typeof(ValueTuple<,>))
{
if (results.Length != 2)
throw new InvalidOperationException($"Expected 2 output tensors, got {results.Length}");
return (TResult)Activator.CreateInstance(typeof(TResult), results[0], results[1]);
}
if (genericType == typeof(ValueTuple<,,>))
{
if (results.Length != 3)
throw new InvalidOperationException($"Expected 3 output tensors, got {results.Length}");
return (TResult)Activator.CreateInstance(typeof(TResult), results[0], results[1], results[2]);
}
}

throw new NotSupportedException($"Return type {typeof(TResult)} is not supported");
}

public new TResult forward(params torch.Tensor[] inputs) => run(inputs);
public new TResult call(params torch.Tensor[] inputs) => run(inputs);
}
}