MetalCompute 1.0
An API to make GPU compute calls easier
Loading...
Searching...
No Matches
MTLComputeCommandManager.hpp
Go to the documentation of this file.
#pragma once

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
9
10namespace MTLCompute {
11
12 template< typename T>
14
15 private:
16 MTL::Device *gpu;
18 MTL::ComputePipelineState *pipeline;
19 MTL::CommandQueue *commandQueue;
20 MTL::CommandBuffer *commandBuffer;
21 MTL::ComputeCommandEncoder *commandEncoder;
22
23 std::vector<Buffer<T>> buffers = std::vector<Buffer<T>>(MAX_BUFFERS);
24 std::vector<Texture<T>> textures = std::vector<Texture<T>>(MAX_TEXTURES);
25 int bufferlength = -1;
26 int texwidth = -1;
27 int texheight = -1;
28
29 public:
30
42 this->gpu = gpu;
43 this->kernel = kernel;
44 this->pipeline = this->kernel->getPLS();
45
46 this->commandQueue = this->gpu->newCommandQueue();
47
48 }
49
54 CommandManager() = default;
55
63 this->commandQueue->autorelease();
64 }
65
75 void loadBuffer(Buffer<T> buffer, int index) {
76 if (this->bufferlength == -1) {
77 this->bufferlength = buffer.length;
78 } else if (this->bufferlength != buffer.length) {
79 throw std::invalid_argument("Buffer lengths do not match");
80 }
81
82 this->buffers[index] = buffer;
83 }
84
85
95 void loadTexture(Texture<T> texture, int index) {
96 if (this->texwidth == -1 && this->texheight == -1) {
97 this->texwidth = texture.getWidth();
98 this->texheight = texture.getHeight();
99 if (this->texwidth > MAX_TEXTURE_SIZE || this->texheight > MAX_TEXTURE_SIZE) {
100 throw std::invalid_argument("Texture size too large, max size is 16384");
101 }
102 } else if (this->texwidth != texture.getWidth() || this->texheight != texture.getHeight()) {
103 std::cout << this->texwidth << " " << this->texheight << std::endl;
104 std::cout << texture.getTexture()->width() << " " << texture.getTexture()->height() << std::endl;
105 throw std::invalid_argument("Texture sizes do not match");
106 }
107
108 this->textures[index] = texture;
109 }
110
118 void dispatch() {
119 if (this->kernel->getPLS() != this->pipeline) {
120 // Refresh the pipeline if it has changed
121 this->pipeline = this->kernel->getPLS();
122 }
123
124 // Create a new command buffer and command encoder
125 this->commandBuffer = this->commandQueue->commandBuffer();
126 this->commandEncoder = this->commandBuffer->computeCommandEncoder();
127 this->commandEncoder->setComputePipelineState(this->pipeline);
128 bool usingbuffers = false;
129 bool usingtextures = false;
130
131 // Load the buffers and textures into the commandEncoder
132 for (int i = 0; i < MAX_BUFFERS; i++) {
133 if (buffers[i].length == this->bufferlength && buffers[i].getBuffer() != nullptr) {
134 this->commandEncoder->setBuffer(buffers[i].getBuffer(), 0, i);
135 usingbuffers = true;
136 }
137 if (textures[i].getWidth() == this->texwidth && textures[i].getHeight() == this->texheight
138 && textures[i].getTexture() != nullptr) {
139 this->commandEncoder->setTexture(textures[i].getTexture(), i);
140 usingtextures = true;
141 }
142 }
143 for (int i = MAX_BUFFERS; i < MAX_TEXTURES; i++) {
144 if (textures[i].getWidth() == this->texwidth && textures[i].getHeight() == this->texheight
145 && textures[i].getTexture() != nullptr) {
146 this->commandEncoder->setTexture(textures[i].getTexture(), i);
147 usingtextures = true;
148 }
149 }
150
151 // Calculate the grid size and thread group size
152 MTL::Size threadsPerThreadgroup;
153 threadsPerThreadgroup.width = this->pipeline->threadExecutionWidth();
154 threadsPerThreadgroup.height = this->pipeline->maxTotalThreadsPerThreadgroup() / threadsPerThreadgroup.width;
155 threadsPerThreadgroup.depth = 1;
156
157 MTL::Size threadsPerGrid;
158 if (usingbuffers && usingtextures) {
159 if (this->bufferlength > this->texwidth)
160 threadsPerGrid = MTL::Size::Make(this->bufferlength, this->texheight, 1);
161 else
162 threadsPerGrid = MTL::Size::Make(this->texwidth, this->texheight, 1);
163 } else if (usingbuffers && !usingtextures) {
164 threadsPerGrid = MTL::Size::Make(this->bufferlength, 1, 1);
165
166 } else if (!usingbuffers && usingtextures) {
167 threadsPerGrid = MTL::Size::Make(this->texwidth, this->texheight, 1);
168
169 } else {
170 throw std::invalid_argument("No buffers or textures loaded");
171 }
172
173 // Use dispatchThreads NOT dispatchThreadgroups
174 this->commandEncoder->dispatchThreads(threadsPerGrid, threadsPerThreadgroup);
175 this->commandEncoder->endEncoding();
176 this->commandBuffer->commit();
177 this->commandBuffer->waitUntilCompleted();
178
179 // Release the command encoder and command buffer
180 this->commandEncoder->release();
181 this->commandBuffer->release();
182 }
183
189 this->buffers.clear();
190 this->buffers = std::vector<Buffer<T>>(MAX_BUFFERS);
191 this->bufferlength = -1;
192 }
193
199 this->textures.clear();
200 this->textures = std::vector<Texture<T>>(MAX_TEXTURES);
201 this->texwidth = -1;
202 this->texheight = -1;
203 }
204
211 void reset() {
212 this->resetBuffers();
213 this->resetTextures();
214 }
215
222 MTL::Device *getGPU() {
223 return this->gpu;
224 }
225
233 return this->kernel;
234 }
235
242 std::vector<Buffer<T>>& getBuffers() {
243 return this->buffers;
244 }
245
252 std::vector<Texture<T>>& getTextures() {
253 return this->textures;
254 }
255
256 };
257
258}
Definition MTLComputeBuffer.hpp:10
size_t length
The length of the buffer.
Definition MTLComputeBuffer.hpp:233
Definition MTLComputeCommandManager.hpp:13
MTL::ComputePipelineState * pipeline
The Metal compute pipeline state object.
Definition MTLComputeCommandManager.hpp:18
void loadBuffer(Buffer< T > buffer, int index)
Load a buffer into the CommandManager.
Definition MTLComputeCommandManager.hpp:75
CommandManager()=default
Default constructor for the CommandManager class.
std::vector< Buffer< T > > & getBuffers()
Get the loaded buffers.
Definition MTLComputeCommandManager.hpp:242
std::vector< Buffer< T > > buffers
The buffers.
Definition MTLComputeCommandManager.hpp:23
Kernel * getKernel()
Get the kernel object.
Definition MTLComputeCommandManager.hpp:232
std::vector< Texture< T > > & getTextures()
Get the loaded textures.
Definition MTLComputeCommandManager.hpp:252
int bufferlength
The length of the buffers.
Definition MTLComputeCommandManager.hpp:25
~CommandManager()
Destructor for the CommandManager class.
Definition MTLComputeCommandManager.hpp:62
Kernel * kernel
The kernel object.
Definition MTLComputeCommandManager.hpp:17
MTL::Device * getGPU()
Get the GPU device.
Definition MTLComputeCommandManager.hpp:222
void loadTexture(Texture< T > texture, int index)
Load a texture into the CommandManager.
Definition MTLComputeCommandManager.hpp:95
CommandManager(MTL::Device *gpu, MTLCompute::Kernel *kernel)
Constructor for the CommandManager class.
Definition MTLComputeCommandManager.hpp:41
MTL::ComputeCommandEncoder * commandEncoder
The Metal compute command encoder object.
Definition MTLComputeCommandManager.hpp:21
void resetBuffers()
Reset the buffers and the cached buffer length.
Definition MTLComputeCommandManager.hpp:188
void reset()
Reset the buffers and textures, including their cached sizes.
Definition MTLComputeCommandManager.hpp:211
MTL::CommandQueue * commandQueue
The Metal command queue object.
Definition MTLComputeCommandManager.hpp:19
std::vector< Texture< T > > textures
The textures.
Definition MTLComputeCommandManager.hpp:24
MTL::CommandBuffer * commandBuffer
The Metal command buffer object.
Definition MTLComputeCommandManager.hpp:20
void dispatch()
Dispatch the kernel.
Definition MTLComputeCommandManager.hpp:118
void resetTextures()
Reset the textures and the cached texture width and height.
Definition MTLComputeCommandManager.hpp:198
int texwidth
The width of the textures.
Definition MTLComputeCommandManager.hpp:26
int texheight
The height of the textures.
Definition MTLComputeCommandManager.hpp:27
MTL::Device * gpu
The Metal device object.
Definition MTLComputeCommandManager.hpp:16
Definition MTLComputeKernel.hpp:7
MTL::ComputePipelineState * getPLS()
Get the MTL::ComputePipelineState object.
Definition MTLComputeKernel.hpp:117
Definition MTLComputeTexture.hpp:10
MTL::Texture * getTexture()
Get the MTL::Texture object.
Definition MTLComputeTexture.hpp:270
int getWidth()
Get the width of the texture.
Definition MTLComputeTexture.hpp:310
int getHeight()
Get the height of the texture.
Definition MTLComputeTexture.hpp:320
Definition MTLComputeBuffer.hpp:7
constexpr long MAX_TEXTURE_SIZE
Definition MTLComputeGlobals.hpp:13
constexpr int MAX_TEXTURES
Definition MTLComputeGlobals.hpp:12
constexpr int MAX_BUFFERS
Definition MTLComputeGlobals.hpp:11