testChannelsLastConv2d() — pytorch Function Reference

Architecture documentation for the testChannelsLastConv2d() function in PytorchTestBase.java from the pytorch codebase.

Dependency Diagram

graph TD
  c8c5fea2_b14a_8035_e2bd_13d0bf71824a["testChannelsLastConv2d()"]
  fd3711c0_c2a1_0848_309d_ba2ef1be4466["loadModel()"]
  c8c5fea2_b14a_8035_e2bd_13d0bf71824a -->|calls| fd3711c0_c2a1_0848_309d_ba2ef1be4466
  076e553b_79aa_4059_bfb2_6682a21ae462["assertIValueTensor()"]
  c8c5fea2_b14a_8035_e2bd_13d0bf71824a -->|calls| 076e553b_79aa_4059_bfb2_6682a21ae462
  style c8c5fea2_b14a_8035_e2bd_13d0bf71824a fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java lines 348–406

  @Test
  public void testChannelsLastConv2d() throws IOException {
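    long[] inputShape = new long[] {1, 3, 2, 2};
    // Each value encodes its position as <channel><row><col> (1-based), stored contiguous (NCHW).
    long[] dataNCHW = new long[] {
      111, 112,
      121, 122,

      211, 212,
      221, 222,

      311, 312,
      321, 322};
    Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
    // The same logical tensor stored channels-last (NHWC): the channel index varies fastest.
    long[] dataNHWC = new long[] {
      111, 211, 311,       112, 212, 312,

      121, 221, 321,       122, 222, 322};
    Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
    // 1x1 kernel with weights diag(2, 1, -1): doubles channel 0, keeps channel 1, negates channel 2.
    long[] weightShape = new long[] {3, 3, 1, 1};
    long[] dataWeightOIHW = new long[] {
      2, 0, 0,
      0, 1, 0,
      0, 0, -1};
    Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
    // With a 1x1 kernel (H = W = 1), the OIHW and OHWI orderings coincide, so the data is identical.
    long[] dataWeightOHWI = new long[] {
      2, 0, 0,
      0, 1, 0,
      0, 0, -1};
    Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);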

    final Module module = loadModel(TEST_MODULE_ASSET_NAME);

    final IValue outputNCHW =
        module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
    assertIValueTensor(
        outputNCHW,
        MemoryFormat.CONTIGUOUS,
        new long[] {1, 3, 2, 2},
        new long[] {
          2*111, 2*112,
          2*121, 2*122,

          211, 212,
          221, 222,

          -311, -312,
          -321, -322});

    final IValue outputNHWC =
        module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
    assertIValueTensor(
        outputNHWC,
        MemoryFormat.CHANNELS_LAST,
        new long[] {1, 3, 2, 2},
        new long[] {
          2*111, 211, -311,      2*112, 212, -312,
          2*121, 221, -321,      2*122, 222, -322});
  }
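The two input arrays in the test hold the same logical 1x3x2x2 tensor; only the memory order differs. The following is a minimal plain-Java sketch, not pytorch_android API (the class `ChannelsLastLayout` and its methods are hypothetical), that derives `dataNHWC` from `dataNCHW` using the usual contiguous strides {C*H*W, H*W, W, 1} and channels-last strides {H*W*C, 1, W*C, C} for a logical NCHW shape.

```java
import java.util.Arrays;

// Hypothetical helper, for illustration only: not part of pytorch_android.
public class ChannelsLastLayout {
    // Element strides of a contiguous NCHW tensor: W varies fastest.
    static long[] contiguousStrides(long n, long c, long h, long w) {
        return new long[] {c * h * w, h * w, w, 1};
    }

    // Element strides of the same logical NCHW shape stored channels-last (NHWC): C varies fastest.
    static long[] channelsLastStrides(long n, long c, long h, long w) {
        return new long[] {h * w * c, 1, w * c, c};
    }

    // Copies every logical element (n, c, h, w) from its contiguous offset
    // to its channels-last offset, producing the NHWC buffer.
    static long[] toChannelsLast(long[] nchw, long n, long c, long h, long w) {
        long[] src = contiguousStrides(n, c, h, w);
        long[] dst = channelsLastStrides(n, c, h, w);
        long[] out = new long[nchw.length];
        for (int ni = 0; ni < n; ni++)
            for (int ci = 0; ci < c; ci++)
                for (int hi = 0; hi < h; hi++)
                    for (int wi = 0; wi < w; wi++) {
                        long s = ni * src[0] + ci * src[1] + hi * src[2] + wi * src[3];
                        long d = ni * dst[0] + ci * dst[1] + hi * dst[2] + wi * dst[3];
                        out[(int) d] = nchw[(int) s];
                    }
        return out;
    }

    public static void main(String[] args) {
        long[] dataNCHW = {111, 112, 121, 122, 211, 212, 221, 222, 311, 312, 321, 322};
        System.out.println(Arrays.toString(toChannelsLast(dataNCHW, 1, 3, 2, 2)));
        System.out.println(Arrays.toString(channelsLastStrides(1, 3, 2, 2)));
    }
}
```

Running `main` prints `[111, 211, 311, 112, 212, 312, 121, 221, 321, 122, 222, 322]`, which matches `dataNHWC` in the test, followed by the channels-last strides `[12, 1, 6, 3]`.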

Frequently Asked Questions

What does testChannelsLastConv2d() do?
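testChannelsLastConv2d() is a JUnit test in PytorchTestBase.java (under androidTest). It runs the test module's conv2d method twice on the same 1x3x2x2 input, once stored contiguous (NCHW) and once channels-last (NHWC), with a 1x1 weight that doubles channel 0, keeps channel 1, and negates channel 2. For each run it checks that the output has the expected memory format, shape {1, 3, 2, 2}, and values.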
What does testChannelsLastConv2d() call?
testChannelsLastConv2d() calls two functions: loadModel and assertIValueTensor.
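The expected values in the test follow from the fact that a 1x1 convolution (stride 1, no padding, no bias) is a per-pixel matrix multiply over channels: out[o] = sum over i of weight[o][i] * in[i]. With the diagonal weight diag(2, 1, -1), channel 0 is doubled, channel 1 passes through, and channel 2 is negated. The following is a minimal plain-Java reference sketch, assuming those convolution settings; the class `Pointwise1x1Conv` and the method `conv1x1` are hypothetical and not part of pytorch_android.

```java
import java.util.Arrays;

// Hypothetical reference implementation, for illustration only.
public class Pointwise1x1Conv {
    // 1x1 convolution on a contiguous NCHW input with OIHW weights of shape {cOut, cIn, 1, 1}:
    // out[b][o][p] = sum over i of weight[o][i] * input[b][i][p], where p indexes the H*W pixels.
    static long[] conv1x1(long[] input, long n, long cIn, long h, long w,
                          long[] weight, long cOut) {
        long hw = h * w;
        long[] out = new long[(int) (n * cOut * hw)];
        for (int b = 0; b < n; b++)
            for (int o = 0; o < cOut; o++)
                for (int p = 0; p < hw; p++) {
                    long acc = 0;
                    for (int i = 0; i < cIn; i++) {
                        acc += weight[(int) (o * cIn + i)] * input[(int) ((b * cIn + i) * hw + p)];
                    }
                    out[(int) ((b * cOut + o) * hw + p)] = acc;
                }
        return out;
    }

    public static void main(String[] args) {
        long[] dataNCHW = {111, 112, 121, 122, 211, 212, 221, 222, 311, 312, 321, 322};
        long[] weightOIHW = {2, 0, 0, 0, 1, 0, 0, 0, -1};
        System.out.println(Arrays.toString(conv1x1(dataNCHW, 1, 3, 2, 2, weightOIHW, 3)));
    }
}
```

Running `main` prints `[222, 224, 242, 244, 211, 212, 221, 222, -311, -312, -321, -322]`, the contiguous (NCHW) expected output in the test; the channels-last expectation holds the same values with the channel index varying fastest.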