PytorchTestBase Class — pytorch Architecture
Architecture documentation for the PytorchTestBase class in PytorchTestBase.java from the pytorch codebase.
Entity Profile
Source Code
android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java lines 15–694
/**
 * Shared tests for the PyTorch Android Java API.
 *
 * <p>Concrete subclasses implement {@link #loadModel(String)}; every test in this class obtains
 * its {@link Module} through that hook and then exercises {@link IValue} conversions, {@link
 * Tensor} helpers, memory formats, and a set of operator-coverage models.
 */
public abstract class PytorchTestBase {
  /** Asset providing the scripted helper methods (eqBool, eqInt, conv2d, ...) called below. */
  private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";

  /** Strings round-tripped by the string tests, including non-Latin and punctuation input. */
  private static final String[] STRING_TEST_VALUES =
      new String[] {
        "smoketest",
        "проверка не латинских символов", // non-Latin characters check
        "#@$!@#)($*!@#$)(!@*#$"
      };

  @Test
  public void testForwardNull() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
    assertTrue(input.isTensor());
    final IValue output = module.forward(input);
    assertTrue(output.isNull());
  }

  @Test
  public void testEqBool() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    for (boolean value : new boolean[] {false, true}) {
      final IValue input = IValue.from(value);
      assertTrue(input.isBool());
      assertEquals(value, input.toBool());
      final IValue output = module.runMethod("eqBool", input);
      assertTrue(output.isBool());
      assertEquals(value, output.toBool());
    }
  }

  @Test
  public void testEqInt() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
      final IValue input = IValue.from(value);
      assertTrue(input.isLong());
      assertEquals(value, input.toLong());
      final IValue output = module.runMethod("eqInt", input);
      assertTrue(output.isLong());
      assertEquals(value, output.toLong());
    }
  }

  @Test
  public void testEqFloat() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final double[] values =
        new double[] {
          -Double.MAX_VALUE,
          Double.MAX_VALUE,
          -Double.MIN_VALUE,
          Double.MIN_VALUE,
          -Math.exp(1.d),
          -Math.sqrt(2.d),
          -3.1415f,
          3.1415f,
          -1,
          0,
          1,
        };
    for (double value : values) {
      final IValue input = IValue.from(value);
      assertTrue(input.isDouble());
      // Zero delta: the round trip must be bit-exact.
      assertEquals(value, input.toDouble(), 0.0);
      final IValue output = module.runMethod("eqFloat", input);
      assertTrue(output.isDouble());
      assertEquals(value, output.toDouble(), 0.0);
    }
  }

  @Test
  public void testEqTensor() throws IOException {
    final long[] inputTensorShape = new long[] {1, 3, 224, 224};
    final long numElements = Tensor.numel(inputTensorShape);
    final float[] inputTensorData = new float[(int) numElements];
    for (int i = 0; i < numElements; ++i) {
      inputTensorData[i] = i;
    }
    final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final IValue input = IValue.from(inputTensor);
    assertTrue(input.isTensor());
    // Identity check: wrapping in an IValue must not copy the Java Tensor object.
    assertTrue(inputTensor == input.toTensor());
    final IValue output = module.runMethod("eqTensor", input);
    assertTrue(output.isTensor());
    final Tensor outputTensor = output.toTensor();
    assertNotNull(outputTensor);
    assertArrayEquals(inputTensorShape, outputTensor.shape());
    final float[] outputData = outputTensor.getDataAsFloatArray();
    for (int i = 0; i < numElements; i++) {
      assertEquals(inputTensorData[i], outputData[i], 0.0f);
    }
  }

  @Test
  public void testEqDictIntKeyIntValue() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final Map<Long, IValue> inputMap = new HashMap<>();
    // Each key maps to its negation. -Long.MIN_VALUE overflows back to Long.MIN_VALUE in two's
    // complement, so that entry is spelled out explicitly instead of writing a misleading negation.
    inputMap.put(Long.MIN_VALUE, IValue.from(Long.MIN_VALUE));
    inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
    inputMap.put(0L, IValue.from(0L));
    inputMap.put(1L, IValue.from(-1L));
    inputMap.put(-1L, IValue.from(1L));
    final IValue input = IValue.dictLongKeyFrom(inputMap);
    assertTrue(input.isDictLongKey());
    final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
    assertTrue(output.isDictLongKey());
    assertLongValuesEqual(inputMap, output.toDictLongKey());
  }

  @Test
  public void testEqDictStrKeyIntValue() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final Map<String, IValue> inputMap = new HashMap<>();
    inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
    inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
    inputMap.put("long_0", IValue.from(0L));
    inputMap.put("long_1", IValue.from(1L));
    inputMap.put("long_-1", IValue.from(-1L));
    final IValue input = IValue.dictStringKeyFrom(inputMap);
    assertTrue(input.isDictStringKey());
    final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
    assertTrue(output.isDictStringKey());
    assertLongValuesEqual(inputMap, output.toDictStringKey());
  }

  @Test
  public void testListIntSumReturnTuple() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    for (int n : new int[] {0, 1, 128}) {
      final long[] a = new long[n];
      long sum = 0;
      for (int i = 0; i < n; i++) {
        a[i] = i;
        sum += a[i];
      }
      final IValue input = IValue.listFrom(a);
      assertTrue(input.isLongList());
      final IValue output = module.runMethod("listIntSumReturnTuple", input);
      assertTrue(output.isTuple());
      final IValue[] tuple = output.toTuple();
      assertEquals(2, tuple.length);
      // Expected tuple: (the input list unchanged, the sum of its elements).
      assertArrayEquals(a, tuple[0].toLongList());
      assertEquals(sum, tuple[1].toLong());
    }
  }

  @Test
  public void testOptionalIntIsNone() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1L)).toBool());
    assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
  }

  @Test
  public void testIntEq0None() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    assertTrue(module.runMethod("intEq0None", IValue.from(0L)).isNull());
    assertEquals(1L, module.runMethod("intEq0None", IValue.from(1L)).toLong());
  }

  @Test(expected = IllegalArgumentException.class)
  public void testRunUndefinedMethod() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    module.runMethod("test_undefined_method_throws_exception");
  }

  @Test
  public void testTensorMethods() {
    final long[] shape = new long[] {1, 3, 224, 224};
    final int numel = (int) Tensor.numel(shape);
    final int[] ints = new int[numel];
    final float[] floats = new float[numel];
    final byte[] bytes = new byte[numel];
    for (int i = 0; i < numel; i++) {
      bytes[i] = (byte) ((i % 255) - 128);
      ints[i] = i;
      floats[i] = i / 1000.f;
    }

    final Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
    assertEquals(DType.INT8, tensorBytes.dtype());
    assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());

    final Tensor tensorInts = Tensor.fromBlob(ints, shape);
    assertEquals(DType.INT32, tensorInts.dtype());
    assertArrayEquals(ints, tensorInts.getDataAsIntArray());

    final Tensor tensorFloats = Tensor.fromBlob(floats, shape);
    assertEquals(DType.FLOAT32, tensorFloats.dtype());
    final float[] floatsOut = tensorFloats.getDataAsFloatArray();
    assertEquals(numel, floatsOut.length);
    for (int i = 0; i < numel; i++) {
      assertEquals(floats[i], floatsOut[i], 0.0f);
    }
  }

  @Test(expected = IllegalStateException.class)
  public void testTensorIllegalStateOnWrongType() {
    final long[] shape = new long[] {1, 3, 224, 224};
    final int numel = (int) Tensor.numel(shape);
    final float[] floats = new float[numel];
    final Tensor tensorFloats = Tensor.fromBlob(floats, shape);
    assertEquals(DType.FLOAT32, tensorFloats.dtype());
    // Reading a FLOAT32 tensor as bytes is expected to throw IllegalStateException.
    tensorFloats.getDataAsByteArray();
  }

  @Test
  public void testEqString() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    for (String value : STRING_TEST_VALUES) {
      final IValue input = IValue.from(value);
      assertTrue(input.isString());
      assertEquals(value, input.toStr());
      final IValue output = module.runMethod("eqStr", input);
      assertTrue(output.isString());
      assertEquals(value, output.toStr());
    }
  }

  @Test
  public void testStr3Concat() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    for (String value : STRING_TEST_VALUES) {
      final IValue input = IValue.from(value);
      assertTrue(input.isString());
      assertEquals(value, input.toStr());
      final IValue output = module.runMethod("str3Concat", input);
      assertTrue(output.isString());
      assertEquals(value + value + value, output.toStr());
    }
  }

  @Test
  public void testEmptyShape() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final long someNumber = 43;
    // An empty shape array describes a zero-dimensional tensor holding a single element.
    final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
    final IValue output = module.runMethod("newEmptyShapeWithItem", input);
    assertTrue(output.isTensor());
    final Tensor value = output.toTensor();
    assertArrayEquals(new long[] {}, value.shape());
    assertArrayEquals(new long[] {someNumber}, value.getDataAsLongArray());
  }

  @Test
  public void testAliasWithOffset() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final IValue output = module.runMethod("testAliasWithOffset");
    assertTrue(output.isTensorList());
    final Tensor[] tensors = output.toTensorList();
    assertEquals(100, tensors[0].getDataAsLongArray()[0]);
    assertEquals(200, tensors[1].getDataAsLongArray()[0]);
  }

  @Test
  public void testNonContiguous() throws IOException {
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final IValue output = module.runMethod("testNonContiguous");
    assertTrue(output.isTensor());
    final Tensor value = output.toTensor();
    assertArrayEquals(new long[] {2}, value.shape());
    assertArrayEquals(new long[] {100, 300}, value.getDataAsLongArray());
  }

  @Test
  public void testChannelsLast() throws IOException {
    final long[] inputShape = new long[] {1, 3, 2, 2};
    // Data laid out channels-last: the three channel values for each pixel are adjacent.
    final long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
    final Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
    assertIValueTensor(
        outputNCHW,
        MemoryFormat.CONTIGUOUS,
        new long[] {1, 3, 2, 2},
        new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104});
    final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC));
    assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data);
  }

  @Test
  public void testChannelsLast3d() throws IOException {
    final long[] shape = new long[] {1, 2, 2, 2, 2};
    final long[] dataNCHWD = new long[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    final long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};

    final Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
    assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);

    final Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS);
    final IValue outputNHWDC =
        module.runMethod("contiguousChannelsLast3d", IValue.from(inputNCHWD));
    assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC);
  }

  @Test
  public void testChannelsLastConv2d() throws IOException {
    final long[] inputShape = new long[] {1, 3, 2, 2};
    final long[] dataNCHW = new long[] {
      111, 112,
      121, 122,
      211, 212,
      221, 222,
      311, 312,
      321, 322};
    final Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
    final long[] dataNHWC = new long[] {
      111, 211, 311, 112, 212, 312,
      121, 221, 321, 122, 222, 322};
    final Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);

    // With 1x1 kernels the OIHW and OHWI orderings coincide, hence the identical arrays.
    // The weight scales channel 0 by 2, keeps channel 1, and negates channel 2.
    final long[] weightShape = new long[] {3, 3, 1, 1};
    final long[] dataWeightOIHW = new long[] {
      2, 0, 0,
      0, 1, 0,
      0, 0, -1};
    final Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
    final long[] dataWeightOHWI = new long[] {
      2, 0, 0,
      0, 1, 0,
      0, 0, -1};
    final Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);

    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final IValue outputNCHW =
        module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
    assertIValueTensor(
        outputNCHW,
        MemoryFormat.CONTIGUOUS,
        new long[] {1, 3, 2, 2},
        new long[] {
          2*111, 2*112,
          2*121, 2*122,
          211, 212,
          221, 222,
          -311, -312,
          -321, -322});
    final IValue outputNHWC =
        module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
    assertIValueTensor(
        outputNHWC,
        MemoryFormat.CHANNELS_LAST,
        new long[] {1, 3, 2, 2},
        new long[] {
          2*111, 211, -311, 2*112, 212, -312,
          2*121, 221, -321, 2*122, 222, -322});
  }

  @Test
  public void testChannelsLastConv3d() throws IOException {
    final long[] inputShape = new long[] {1, 3, 2, 2, 2};
    final long[] dataNCDHW = new long[] {
      1111, 1112,
      1121, 1122,
      1211, 1212,
      1221, 1222,
      2111, 2112,
      2121, 2122,
      2211, 2212,
      2221, 2222,
      3111, 3112,
      3121, 3122,
      3211, 3212,
      3221, 3222};
    final Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS);
    final long[] dataNDHWC = new long[] {
      1111, 2111, 3111,
      1112, 2112, 3112,
      1121, 2121, 3121,
      1122, 2122, 3122,
      1211, 2211, 3211,
      1212, 2212, 3212,
      1221, 2221, 3221,
      1222, 2222, 3222};
    final Tensor inputNDHWC =
        Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D);

    // With 1x1x1 kernels the OIDHW and ODHWI orderings coincide, hence the identical arrays.
    final long[] weightShape = new long[] {3, 3, 1, 1, 1};
    final long[] dataWeightOIDHW = new long[] {
      2, 0, 0,
      0, 1, 0,
      0, 0, -1,
    };
    final Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS);
    final long[] dataWeightODHWI = new long[] {
      2, 0, 0,
      0, 1, 0,
      0, 0, -1,
    };
    final Tensor wNDHWC =
        Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D);

    final Module module = loadModel(TEST_MODULE_ASSET_NAME);
    final IValue outputNCDHW =
        module.runMethod(
            "conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false));
    assertIValueTensor(
        outputNCDHW,
        MemoryFormat.CONTIGUOUS,
        new long[] {1, 3, 2, 2, 2},
        new long[] {
          2*1111, 2*1112, 2*1121, 2*1122,
          2*1211, 2*1212, 2*1221, 2*1222,
          2111, 2112, 2121, 2122,
          2211, 2212, 2221, 2222,
          -3111, -3112, -3121, -3122,
          -3211, -3212, -3221, -3222});
    final IValue outputNDHWC =
        module.runMethod(
            "conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true));
    assertIValueTensor(
        outputNDHWC,
        MemoryFormat.CHANNELS_LAST_3D,
        new long[] {1, 3, 2, 2, 2},
        new long[] {
          2*1111, 2111, -3111, 2*1112, 2112, -3112,
          2*1121, 2121, -3121, 2*1122, 2122, -3122,
          2*1211, 2211, -3211, 2*1212, 2212, -3212,
          2*1221, 2221, -3221, 2*1222, 2222, -3222});
  }

  /**
   * Smoke test: loads mobilenet_v2.ptl, fetches its bundled inputs, and runs forward() on the
   * first element of the first bundled input tuple. Passes if nothing throws.
   */
  @Test
  public void testMobileNetV2() throws IOException {
    try {
      final Module module = loadModel("mobilenet_v2.ptl");
      final IValue inputs = module.runMethod("get_all_bundled_inputs");
      assertTrue(inputs.isList());
      final IValue input = inputs.toList()[0];
      assertTrue(input.isTuple());
      module.forward(input.toTuple()[0]);
    } catch (Exception ex) {
      // Keep the original exception as the cause so its stack trace shows up in the report.
      throw new AssertionError("failed to run MobileNetV2 " + ex.getMessage(), ex);
    }
  }

  @Test
  public void testPointwiseOps() throws IOException {
    runModel("pointwise_ops");
  }

  @Test
  public void testReductionOps() throws IOException {
    runModel("reduction_ops");
  }

  @Test
  public void testComparisonOps() throws IOException {
    runModel("comparison_ops");
  }

  @Test
  public void testOtherMathOps() throws IOException {
    runModel("other_math_ops");
  }

  @Test
  @Ignore
  public void testSpectralOps() throws IOException {
    // NB: This model fails without lite interpreter. The error is as follows:
    // RuntimeError: stft requires the return_complex parameter be given for real inputs
    runModel("spectral_ops");
  }

  @Test
  public void testBlasLapackOps() throws IOException {
    runModel("blas_lapack_ops");
  }

  @Test
  public void testSamplingOps() throws IOException {
    runModel("sampling_ops");
  }

  @Test
  public void testTensorOps() throws IOException {
    runModel("tensor_general_ops");
  }

  @Test
  public void testTensorCreationOps() throws IOException {
    runModel("tensor_creation_ops");
  }

  @Test
  public void testTensorIndexingOps() throws IOException {
    runModel("tensor_indexing_ops");
  }

  @Test
  public void testTensorTypingOps() throws IOException {
    runModel("tensor_typing_ops");
  }

  @Test
  public void testTensorViewOps() throws IOException {
    runModel("tensor_view_ops");
  }

  @Test
  public void testConvolutionOps() throws IOException {
    runModel("convolution_ops");
  }

  @Test
  public void testPoolingOps() throws IOException {
    runModel("pooling_ops");
  }

  @Test
  public void testPaddingOps() throws IOException {
    runModel("padding_ops");
  }

  @Test
  public void testActivationOps() throws IOException {
    runModel("activation_ops");
  }

  @Test
  public void testNormalizationOps() throws IOException {
    runModel("normalization_ops");
  }

  @Test
  public void testRecurrentOps() throws IOException {
    runModel("recurrent_ops");
  }

  @Test
  public void testTransformerOps() throws IOException {
    runModel("transformer_ops");
  }

  @Test
  public void testLinearOps() throws IOException {
    runModel("linear_ops");
  }

  @Test
  public void testDropoutOps() throws IOException {
    runModel("dropout_ops");
  }

  @Test
  public void testSparseOps() throws IOException {
    runModel("sparse_ops");
  }

  @Test
  public void testDistanceFunctionOps() throws IOException {
    runModel("distance_function_ops");
  }

  @Test
  public void testLossFunctionOps() throws IOException {
    runModel("loss_function_ops");
  }

  @Test
  public void testVisionFunctionOps() throws IOException {
    runModel("vision_function_ops");
  }

  @Test
  public void testShuffleOps() throws IOException {
    runModel("shuffle_ops");
  }

  @Test
  public void testNNUtilsOps() throws IOException {
    runModel("nn_utils_ops");
  }

  @Test
  public void testQuantOps() throws IOException {
    runModel("general_quant_ops");
  }

  @Test
  public void testDynamicQuantOps() throws IOException {
    runModel("dynamic_quant_ops");
  }

  @Test
  public void testStaticQuantOps() throws IOException {
    runModel("static_quant_ops");
  }

  @Test
  public void testFusedQuantOps() throws IOException {
    runModel("fused_quant_ops");
  }

  @Test
  public void testTorchScriptBuiltinQuantOps() throws IOException {
    runModel("torchscript_builtin_ops");
  }

  @Test
  public void testTorchScriptCollectionQuantOps() throws IOException {
    runModel("torchscript_collection_ops");
  }

  /**
   * Asserts that {@code ivalue} holds a tensor with the given memory format, shape, and int64
   * contents.
   *
   * @param ivalue value returned from the module
   * @param memoryFormat expected memory format of the tensor
   * @param expectedShape expected tensor shape
   * @param expectedData expected contents, read via {@code getDataAsLongArray()}
   */
  static void assertIValueTensor(
      final IValue ivalue,
      final MemoryFormat memoryFormat,
      final long[] expectedShape,
      final long[] expectedData) {
    assertTrue(ivalue.isTensor());
    final Tensor t = ivalue.toTensor();
    assertEquals(memoryFormat, t.memoryFormat());
    assertArrayEquals(expectedShape, t.shape());
    assertArrayEquals(expectedData, t.getDataAsLongArray());
  }

  /**
   * Asserts that two maps have the same keys and equal {@code toLong()} values.
   *
   * <p>A key missing from {@code actual} is reported as an assertion failure rather than a
   * NullPointerException.
   */
  private static <K> void assertLongValuesEqual(
      final Map<K, IValue> expected, final Map<K, IValue> actual) {
    assertEquals(expected.size(), actual.size());
    for (Map.Entry<K, IValue> entry : expected.entrySet()) {
      final IValue actualValue = actual.get(entry.getKey());
      assertNotNull("missing key " + entry.getKey(), actualValue);
      assertEquals(entry.getValue().toLong(), actualValue.toLong());
    }
  }

  /**
   * Loads {@code name + ".ptl"} and calls its {@code forward()} with no arguments. The calling
   * test passes if no exception is thrown.
   *
   * @param name base asset name, without the {@code .ptl} extension
   * @throws IOException if {@link #loadModel(String)} fails to load the asset
   */
  void runModel(final String name) throws IOException {
    final Module module = loadModel(name + ".ptl");
    module.forward();
    // TODO: also load and run name + "_temp.ptl" once the on-the-fly script is ready.
  }

  /**
   * Loads a model asset by name. Implemented by subclasses; how the name is resolved to a model
   * is up to the implementation.
   *
   * @param assetName asset file name, e.g. {@code "android_api_module.ptl"}
   * @return the loaded module
   * @throws IOException if the asset cannot be read
   */
  protected abstract Module loadModel(String assetName) throws IOException;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free