Ginkgo Generated from branch based on master. Ginkgo version 1.8.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
workspace.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_
6#define GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_
7
8
9#include <typeinfo>
10
11
12#include <ginkgo/core/matrix/dense.hpp>
13
14
15namespace gko {
16namespace solver {
17namespace detail {
18
19
23class any_array {
24public:
25 template <typename ValueType>
26 array<ValueType>& init(std::shared_ptr<const Executor> exec, size_type size)
27 {
28 auto container = std::make_unique<concrete_container<ValueType>>(
29 std::move(exec), size);
30 auto& arr = container->arr;
31 data_ = std::move(container);
32 return arr;
33 }
34
35 bool empty() const { return data_.get() == nullptr; }
36
37 template <typename ValueType>
38 bool contains() const
39 {
40 return dynamic_cast<const concrete_container<ValueType>*>(data_.get());
41 }
42
43 template <typename ValueType>
44 array<ValueType>& get()
45 {
46 GKO_ASSERT(this->template contains<ValueType>());
47 return dynamic_cast<concrete_container<ValueType>*>(data_.get())->arr;
48 }
49
50 template <typename ValueType>
51 const array<ValueType>& get() const
52 {
53 GKO_ASSERT(this->template contains<ValueType>());
54 return dynamic_cast<const concrete_container<ValueType>*>(data_.get())
55 ->arr;
56 }
57
58 void clear() { data_.reset(); }
59
60private:
61 struct generic_container {
62 virtual ~generic_container() = default;
63 };
64
65 template <typename ValueType>
66 struct concrete_container : generic_container {
67 template <typename... Args>
68 concrete_container(Args&&... args) : arr{std::forward<Args>(args)...}
69 {}
70
72 };
73
74 std::unique_ptr<generic_container> data_;
75};
76
77
78class workspace {
79public:
80 workspace(std::shared_ptr<const Executor> exec) : exec_{std::move(exec)} {}
81
82 workspace(const workspace& other) : workspace{other.get_executor()} {}
83
84 workspace(workspace&& other) : workspace{other.get_executor()}
85 {
86 other.clear();
87 }
88
89 workspace& operator=(const workspace& other) { return *this; }
90
91 workspace& operator=(workspace&& other)
92 {
93 other.clear();
94 return *this;
95 }
96
97 template <typename LinOpType, typename CreateOperation>
98 LinOpType* create_or_get_op(int op_id, CreateOperation create,
99 const std::type_info& expected_type,
100 dim<2> size, size_type stride)
101 {
102 GKO_ASSERT(op_id >= 0 && op_id < operators_.size());
103 // does the existing object have the wrong type?
104 // vector types may vary e.g. if users derive from Dense
105 auto stored_op = operators_[op_id].get();
106 LinOpType* op{};
107 if (!stored_op || typeid(*stored_op) != expected_type) {
108 auto new_op = create();
109 op = new_op.get();
110 operators_[op_id] = std::move(new_op);
111 return op;
112 }
113 // does the existing object have the wrong dimensions?
114 op = dynamic_cast<LinOpType*>(operators_[op_id].get());
115 GKO_ASSERT(op);
116 if (op->get_size() != size || op->get_stride() != stride) {
117 auto new_op = create();
118 op = new_op.get();
119 operators_[op_id] = std::move(new_op);
120 }
121 return op;
122 }
123
124 const LinOp* get_op(int op_id) const
125 {
126 GKO_ASSERT(op_id >= 0 && op_id < operators_.size());
127 return operators_[op_id].get();
128 }
129
130 template <typename ValueType>
131 array<ValueType>& init_or_get_array(int array_id)
132 {
133 GKO_ASSERT(array_id >= 0 && array_id < arrays_.size());
134 auto& array = arrays_[array_id];
135 if (array.empty()) {
136 auto& result =
137 array.template init<ValueType>(this->get_executor(), 0);
138 return result;
139 }
140 // array types should not change!
141 GKO_ASSERT(array.template contains<ValueType>());
142 return array.template get<ValueType>();
143 }
144
145 template <typename ValueType>
146 array<ValueType>& create_or_get_array(int array_id, size_type size)
147 {
148 auto& result = init_or_get_array<ValueType>(array_id);
149 if (result.get_size() != size) {
150 result.resize_and_reset(size);
151 }
152 return result;
153 }
154
155 std::shared_ptr<const Executor> get_executor() const { return exec_; }
156
157 void set_size(int num_operators, int num_arrays)
158 {
159 operators_.resize(num_operators);
160 arrays_.resize(num_arrays);
161 }
162
163 void clear()
164 {
165 for (auto& op : operators_) {
166 op.reset();
167 }
168 for (auto& array : arrays_) {
169 array.clear();
170 }
171 }
172
173private:
174 std::shared_ptr<const Executor> exec_;
175 std::vector<std::unique_ptr<LinOp>> operators_;
176 std::vector<any_array> arrays_;
177};
178
179
180} // namespace detail
181} // namespace solver
182} // namespace gko
183
184#endif // GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_
The Ginkgo namespace.
Definition abstract_factory.hpp:20
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:86
@ array
The matrix should be written as dense matrix in column-major order.