model_variant_handle.h Source File

model_variant_handle.h Source File

SDK qb Runtime Library: model_variant_handle.h Source File
SDK qb Runtime Library v1.1
MCS001-
model_variant_handle.h
Go to the documentation of this file.
1
4
5#ifndef QBRUNTIME_MODEL_VARIANT_HANDLE_H_
6#define QBRUNTIME_MODEL_VARIANT_HANDLE_H_
7
8#include <stdint.h>
9
10#include <vector>
11
12#include "qbruntime/export.h"
14#include "qbruntime/type.h"
15
16namespace mobilint {
17
22
23class ModelImpl;
24
34class QBRUNTIME_EXPORT ModelVariantHandle {
35public:
36 ModelVariantHandle(const ModelVariantHandle& other) = delete;
37 ModelVariantHandle(ModelVariantHandle&& other) = delete;
38 ModelVariantHandle& operator=(const ModelVariantHandle& rhs) = delete;
39 ModelVariantHandle& operator=(ModelVariantHandle&& rhs) noexcept = delete;
40 ~ModelVariantHandle();
41
47 int getVariantIdx() const;
48
54 const std::vector<std::vector<int64_t>>& getModelInputShape() const;
55
61 const std::vector<std::vector<int64_t>>& getModelOutputShape() const;
62
68 const std::vector<BufferInfo>& getInputBufferInfo() const;
69
75 const std::vector<BufferInfo>& getOutputBufferInfo() const;
76
82 std::vector<Scale> getInputScale() const;
83
89 std::vector<Scale> getOutputScale() const;
90
97
104
127
128 // Acquire buffer
129 std::vector<Buffer> acquireInputBuffer(
130 const std::vector<std::vector<int>>& seqlens = {}) const;
131 std::vector<Buffer> acquireOutputBuffer(
132 const std::vector<std::vector<int>>& seqlens = {}) const;
133 std::vector<std::vector<Buffer>> acquireInputBuffers(
134 int batch_size, const std::vector<std::vector<int>>& seqlens = {}) const;
135 std::vector<std::vector<Buffer>> acquireOutputBuffers(
136 int batch_size, const std::vector<std::vector<int>>& seqlens = {}) const;
137
138 // Deallocate acquired Input/Output buffer
139 StatusCode releaseBuffer(std::vector<Buffer>& buffer) const;
140 StatusCode releaseBuffers(std::vector<std::vector<Buffer>>& buffers) const;
141
142 // Reposition single batch
143 StatusCode repositionInputs(const std::vector<float*>& input,
144 std::vector<Buffer>& input_buf,
145 const std::vector<std::vector<int>>& seqlens = {}) const;
146 StatusCode repositionOutputs(const std::vector<Buffer>& output_buf,
147 std::vector<float*>& output,
148 const std::vector<std::vector<int>>& seqlens = {}) const;
149 StatusCode repositionOutputs(const std::vector<Buffer>& output_buf,
150 std::vector<std::vector<float>>& output,
151 const std::vector<std::vector<int>>& seqlens = {}) const;
152 StatusCode repositionInputs(const std::vector<uint8_t*>& input,
153 std::vector<Buffer>& input_buf,
154 const std::vector<std::vector<int>>& seqlens = {}) const;
155
156 // Reposition multiple batches
157 StatusCode repositionInputs(const std::vector<float*>& input,
158 std::vector<std::vector<Buffer>>& input_buf,
159 const std::vector<std::vector<int>>& seqlens = {}) const;
160 StatusCode repositionOutputs(const std::vector<std::vector<Buffer>>& output_buf,
161 std::vector<float*>& output,
162 const std::vector<std::vector<int>>& seqlens = {}) const;
163 StatusCode repositionOutputs(const std::vector<std::vector<Buffer>>& output_buf,
164 std::vector<std::vector<float>>& output,
165 const std::vector<std::vector<int>>& seqlens = {}) const;
166 StatusCode repositionInputs(const std::vector<uint8_t*>& input,
167 std::vector<std::vector<Buffer>>& input_buf,
168 const std::vector<std::vector<int>>& seqlens = {}) const;
170
171private:
172 ModelVariantHandle(int variant_idx, const ModelImpl& model_impl);
173
174 const int mIdx;
175 const ModelImpl& mModelImpl;
176
177 friend class ModelImpl;
178};
179
180} // namespace mobilint
181
182#endif // QBRUNTIME_MODEL_VARIANT_HANDLE_H_
DataType
DataType.
Definition type.h:508
const std::vector< std::vector< int64_t > > & getModelOutputShape() const
Returns the output shape for this model variant.
const std::vector< BufferInfo > & getOutputBufferInfo() const
Returns the output buffer information for this variant.
std::vector< Scale > getInputScale() const
Returns the input quantization scale(s) for this variant.
StatusCode
Enumerates status codes for the qbruntime.
Definition status_code.h:26
std::vector< Scale > getOutputScale() const
Returns the output quantization scale(s) for this variant.
DataType getModelOutputDataType() const
Returns a data type for model outputs.
DataType getModelInputDataType() const
Returns a data type for model inputs.
const std::vector< BufferInfo > & getInputBufferInfo() const
Returns the input buffer information for this variant.
const std::vector< std::vector< int64_t > > & getModelInputShape() const
Returns the input shape for this model variant.
int getVariantIdx() const
Returns the index of this model variant.