libflame  revision_anchor
Functions
FLA_Scalr_external_gpu.c File Reference

(r)

Functions

FLA_Error FLA_Scalr_external_gpu (FLA_Uplo uplo, FLA_Obj alpha, FLA_Obj A, void *A_gpu)
 

Function Documentation

◆ FLA_Scalr_external_gpu()

FLA_Error FLA_Scalr_external_gpu ( FLA_Uplo  uplo,
FLA_Obj  alpha,
FLA_Obj  A,
void *  A_gpu 
)

References FLA_Check_error_level(), FLA_Obj_datatype(), FLA_Obj_equals(), FLA_Obj_has_zero_dim(), FLA_Obj_length(), FLA_Obj_width(), FLA_ONE, and FLA_Scalr_check().

Referenced by FLASH_Queue_exec_task_gpu().

18 {
19  FLA_Datatype datatype;
20  int m_A, n_A;
21  int ldim_A, inc_A;
22  int i;
23 
24  if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
25  FLA_Scalr_check( uplo, alpha, A );
26 
27  if ( FLA_Obj_has_zero_dim( A ) ) return FLA_SUCCESS;
28 
29  if ( FLA_Obj_equals( alpha, FLA_ONE ) )
30  {
31  return FLA_SUCCESS;
32  }
33 
34  datatype = FLA_Obj_datatype( A );
35 
36  m_A = FLA_Obj_length( A );
37  n_A = FLA_Obj_width( A );
38  ldim_A = FLA_Obj_length( A );
39  inc_A = 1;
40 
41  if ( uplo == FLA_LOWER_TRIANGULAR ){
42 
43  switch ( datatype ){
44 
45  case FLA_FLOAT:
46  {
47  float* buff_alpha = ( float* ) FLA_FLOAT_PTR( alpha );
48  float* buff_A_gpu = ( float* ) A_gpu;
49 
50  for ( i = 0; i < min( n_A, m_A ); i++ )
51  cublasSscal( m_A - i,
52  *buff_alpha,
53  buff_A_gpu + i * ldim_A + i, inc_A );
54 
55  break;
56  }
57 
58  case FLA_DOUBLE:
59  {
60  double* buff_alpha = ( double* ) FLA_DOUBLE_PTR( alpha );
61  double* buff_A_gpu = ( double* ) A_gpu;
62 
63  for ( i = 0; i < min( n_A, m_A ); i++ )
64  cublasDscal( m_A - i,
65  *buff_alpha,
66  buff_A_gpu + i * ldim_A + i, inc_A );
67 
68  break;
69  }
70 
71  case FLA_COMPLEX:
72  {
73  cuComplex* buff_alpha = ( cuComplex* ) FLA_COMPLEX_PTR( alpha );
74  cuComplex* buff_A_gpu = ( cuComplex* ) A_gpu;
75 
76  for ( i = 0; i < min( n_A, m_A ); i++ )
77  cublasCscal( m_A - i,
78  *buff_alpha,
79  buff_A_gpu + i * ldim_A + i, inc_A );
80 
81  break;
82  }
83 
84  case FLA_DOUBLE_COMPLEX:
85  {
86  cuDoubleComplex* buff_alpha = ( cuDoubleComplex* ) FLA_DOUBLE_COMPLEX_PTR( alpha );
87  cuDoubleComplex* buff_A_gpu = ( cuDoubleComplex* ) A_gpu;
88 
89  for ( i = 0; i < min( n_A, m_A ); i++ )
90  cublasZscal( m_A - i,
91  *buff_alpha,
92  buff_A_gpu + i * ldim_A + i, inc_A );
93 
94  break;
95  }
96 
97  }
98 
99  }
100 
101  else if ( uplo == FLA_UPPER_TRIANGULAR ){
102 
103  switch ( datatype ){
104 
105  case FLA_FLOAT:
106  {
107  float* buff_alpha = ( float* ) FLA_FLOAT_PTR( alpha );
108  float* buff_A_gpu = ( float* ) A_gpu;
109 
110  for ( i = 0; i < n_A; i++ )
111  cublasSscal( min( i + 1, m_A ),
112  *buff_alpha,
113  buff_A_gpu + i * ldim_A, inc_A );
114 
115  break;
116  }
117 
118  case FLA_DOUBLE:
119  {
120  double* buff_alpha = ( double* ) FLA_DOUBLE_PTR( alpha );
121  double* buff_A_gpu = ( double* ) A_gpu;
122 
123  for ( i = 0; i < n_A; i++ )
124  cublasDscal( min( i + 1, m_A ),
125  *buff_alpha,
126  buff_A_gpu + i * ldim_A, inc_A );
127 
128  break;
129  }
130 
131  case FLA_COMPLEX:
132  {
133  cuComplex* buff_alpha = ( cuComplex* ) FLA_COMPLEX_PTR( alpha );
134  cuComplex* buff_A_gpu = ( cuComplex* ) A_gpu;
135 
136  for ( i = 0; i < n_A; i++ )
137  cublasCscal( min( i + 1, m_A ),
138  *buff_alpha,
139  buff_A_gpu + i * ldim_A, inc_A );
140 
141  break;
142  }
143 
144  case FLA_DOUBLE_COMPLEX:
145  {
146  cuDoubleComplex* buff_alpha = ( cuDoubleComplex* ) FLA_DOUBLE_COMPLEX_PTR( alpha );
147  cuDoubleComplex* buff_A_gpu = ( cuDoubleComplex* ) A_gpu;
148 
149  for ( i = 0; i < n_A; i++ )
150  cublasZscal( min( i + 1, m_A ),
151  *buff_alpha,
152  buff_A_gpu + i * ldim_A, inc_A );
153 
154  break;
155  }
156 
157  }
158 
159  }
160 
161  return FLA_SUCCESS;
162 }
FLA_Error FLA_Scalr_check(FLA_Uplo uplo, FLA_Obj alpha, FLA_Obj A)
Definition: FLA_Scalr_check.c:13
FLA_Obj FLA_ONE
Definition: FLA_Init.c:18
FLA_Datatype FLA_Obj_datatype(FLA_Obj obj)
Definition: FLA_Query.c:13
FLA_Bool FLA_Obj_has_zero_dim(FLA_Obj A)
Definition: FLA_Query.c:400
dim_t FLA_Obj_width(FLA_Obj obj)
Definition: FLA_Query.c:123
FLA_Bool FLA_Obj_equals(FLA_Obj A, FLA_Obj B)
Definition: FLA_Query.c:507
unsigned int FLA_Check_error_level(void)
Definition: FLA_Check.c:18
int FLA_Datatype
Definition: FLA_type_defs.h:49
int i
Definition: bl1_axmyv2.c:145
dim_t FLA_Obj_length(FLA_Obj obj)
Definition: FLA_Query.c:116