94181959e2d4495a0cc753a161587e4ca748790b
[gnome.gobject-introspection] / giscanner / transformer.py
1 # -*- Mode: Python -*-
2 # GObject-Introspection - a framework for introspecting GObject libraries
3 # Copyright (C) 2008  Johan Dahlin
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # of the License, or (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
18 # 02110-1301, USA.
19 #
20
21 import os
22 import re
23
24 from giscanner.ast import (Callback, Enum, Function, Namespace, Member,
25                            Parameter, Return, Array, Struct, Field,
26                            Type, Alias, Interface, Class, Node, Union,
27                            List, Map, Varargs, Constant, type_name_from_ctype,
28                            type_names, default_array_types, TYPE_STRING)
29 from giscanner.config import DATADIR
30 from .glibast import GLibBoxed
31 from giscanner.sourcescanner import (
32     SourceSymbol, ctype_name, CTYPE_POINTER,
33     CTYPE_BASIC_TYPE, CTYPE_UNION, CTYPE_ARRAY, CTYPE_TYPEDEF,
34     CTYPE_VOID, CTYPE_ENUM, CTYPE_FUNCTION, CTYPE_STRUCT,
35     CSYMBOL_TYPE_FUNCTION, CSYMBOL_TYPE_TYPEDEF, CSYMBOL_TYPE_STRUCT,
36     CSYMBOL_TYPE_ENUM, CSYMBOL_TYPE_UNION, CSYMBOL_TYPE_OBJECT,
37     CSYMBOL_TYPE_MEMBER, CSYMBOL_TYPE_ELLIPSIS, CSYMBOL_TYPE_CONST,
38     TYPE_QUALIFIER_CONST)
39 from .odict import odict
40 from .utils import strip_common_prefix, to_underscores
41
# Base data directories used when searching for .gir includes: every
# non-empty entry of $XDG_DATA_DIRS, followed by DATADIR and /usr/share.
_xdg_data_dirs = [
    datadir
    for datadir in (os.environ.get('XDG_DATA_DIRS', '').split(':')
                    + [DATADIR, '/usr/share'])
    if datadir]
44
45
class SkipError(Exception):
    """Signals that the symbol being transformed should be left out.

    Raised by Transformer._create_type for types it does not handle
    (va_list, FILE*) and swallowed by Transformer._traverse_one when
    creating a function.
    """
48
49
class Names(object):
    """Lookup tables for the nodes known to a Transformer.

    The four tables are exposed as read-only properties; the table
    objects themselves are mutable and are filled in by the Transformer.
    """

    def __init__(self):
        super(Names, self).__init__()
        # GIName -> (namespace, node), kept in insertion order
        self._names = odict()
        # GIName -> (namespace, alias node), as stored by
        # Transformer.register_include
        self._aliases = {}
        # GTName -> (namespace, node)
        self._type_names = {}
        # CType -> (namespace, node)
        self._ctypes = {}

    @property
    def names(self):
        return self._names

    @property
    def aliases(self):
        return self._aliases

    @property
    def type_names(self):
        return self._type_names

    @property
    def ctypes(self):
        return self._ctypes
62
63
64 class Transformer(object):
65
66     def __init__(self, generator, namespace_name, namespace_version):
67         self.generator = generator
68         self._namespace = Namespace(namespace_name, namespace_version)
69         self._names = Names()
70         self._typedefs_ns = {}
71         self._strip_prefix = ''
72         self._includes = set()
73         self._includepaths = []
74
75         self._list_ctypes = []
76         self._map_ctypes = []
77
78     def get_names(self):
79         return self._names
80
81     def get_includes(self):
82         return self._includes
83
84     def set_container_types(self, list_ctypes, map_ctypes):
85         self._list_ctypes = list_ctypes
86         self._map_ctypes = map_ctypes
87
88     def set_strip_prefix(self, strip_prefix):
89         self._strip_prefix = strip_prefix
90
91     def parse(self):
92         nodes = []
93         for symbol in self.generator.get_symbols():
94             node = self._traverse_one(symbol)
95             self._add_node(node)
96         return self._namespace
97
98     def register_include(self, filename):
99         (dirname, basename) = os.path.split(filename)
100         if dirname:
101             path = filename
102             (name, suffix) = os.path.splitext(basename)
103         else:
104             path = None
105             name = filename
106             if name.endswith('.gir'):
107                 (name, suffix) = os.path.splitext(name)
108         if name in self._includes:
109             return
110         source = filename
111         if path is None:
112             girname = name + '.gir'
113             searchdirs = [os.path.join(d, 'gir') for d \
114                               in _xdg_data_dirs]
115             searchdirs.extend(self._includepaths)
116             for d in searchdirs:
117                 path = os.path.join(d, girname)
118                 if os.path.exists(path):
119                     break
120                 path = None
121             if not path:
122                 raise ValueError("Couldn't find include %r (search path: %r)"\
123                                      % (girname, searchdirs))
124         d = os.path.dirname(path)
125         if d not in self._includepaths:
126             self._includepaths.append(d)
127         self._includes.add(name)
128         from .girparser import GIRParser
129         parser = GIRParser(path)
130         for include in parser.get_includes():
131             self.register_include(include)
132         nsname = parser.get_namespace().name
133         for node in parser.get_namespace().nodes:
134             if isinstance(node, Alias):
135                 self._names.aliases[node.name] = (nsname, node)
136             elif isinstance(node, (GLibBoxed, Interface, Class)):
137                 self._names.type_names[node.type_name] = (nsname, node)
138             self._names.names[node.name] = (nsname, node)
139             if hasattr(node, 'ctype'):
140                 self._names.ctypes[node.ctype] = (nsname, node)
141             elif hasattr(node, 'symbol'):
142                 self._names.ctypes[node.symbol] = (nsname, node)
143
144     def strip_namespace_object(self, name):
145         prefix = self._namespace.name.lower()
146         if len(name) > len(prefix) and name.lower().startswith(prefix):
147             return name[len(prefix):]
148         return self._remove_prefix(name)
149
150     # Private
151
152     def _add_node(self, node):
153         if node is None:
154             return
155         if node.name.startswith('_'):
156             return
157         self._namespace.nodes.append(node)
158         self._names.names[node.name] = (None, node)
159
160     def _strip_namespace_func(self, name):
161         prefix = self._namespace.name.lower() + '_'
162         if name.lower().startswith(prefix):
163             name = name[len(prefix):]
164         else:
165             prefix = to_underscores(self._namespace.name).lower() + '_'
166             if name.lower().startswith(prefix):
167                 name = name[len(prefix):]
168         return self._remove_prefix(name, isfunction=True)
169
170     def _remove_prefix(self, name, isfunction=False):
171         # when --strip-prefix=g:
172         #   GHashTable -> HashTable
173         #   g_hash_table_new -> hash_table_new
174         prefix = self._strip_prefix.lower()
175         if isfunction:
176             prefix += '_'
177         if name.lower().startswith(prefix):
178             name = name[len(prefix):]
179
180         while name.startswith('_'):
181             name = name[1:]
182         return name
183
184     def _traverse_one(self, symbol, stype=None):
185         assert isinstance(symbol, SourceSymbol), symbol
186
187         if stype is None:
188             stype = symbol.type
189         if stype == CSYMBOL_TYPE_FUNCTION:
190             try:
191                 return self._create_function(symbol)
192             except SkipError:
193                 return
194         elif stype == CSYMBOL_TYPE_TYPEDEF:
195             return self._create_typedef(symbol)
196         elif stype == CSYMBOL_TYPE_STRUCT:
197             return self._create_struct(symbol)
198         elif stype == CSYMBOL_TYPE_ENUM:
199             return self._create_enum(symbol)
200         elif stype == CSYMBOL_TYPE_OBJECT:
201             return self._create_object(symbol)
202         elif stype == CSYMBOL_TYPE_MEMBER:
203             return self._create_member(symbol)
204         elif stype == CSYMBOL_TYPE_UNION:
205             return self._create_union(symbol)
206         elif stype == CSYMBOL_TYPE_CONST:
207             return self._create_const(symbol)
208         else:
209             raise NotImplementedError(
210                 'Transformer: unhandled symbol: %r' % (symbol, ))
211
212     def _create_enum(self, symbol):
213         members = []
214         for child in symbol.base_type.child_list:
215             name = strip_common_prefix(symbol.ident, child.ident).lower()
216             members.append(Member(name,
217                                   child.const_int,
218                                   child.ident))
219
220         enum_name = self.strip_namespace_object(symbol.ident)
221         enum_name = symbol.ident[-len(enum_name):]
222         enum_name = self._remove_prefix(enum_name)
223         enum = Enum(enum_name, symbol.ident, members)
224         self._names.type_names[symbol.ident] = (None, enum)
225         return enum
226
227     def _create_object(self, symbol):
228         return Member(symbol.ident, symbol.base_type.name,
229                       symbol.ident)
230
231     def _parse_deprecated(self, node, directives):
232         deprecated = directives.get('deprecated', False)
233         if deprecated:
234             deprecated_value = deprecated[0]
235             if ':' in deprecated_value:
236                 # Split out gtk-doc version
237                 (node.deprecated_version, node.deprecated) = \
238                     [x.strip() for x in deprecated_value.split(':', 1)]
239             else:
240                 # No version, just include str
241                 node.deprecated = deprecated_value.strip()
242
243     def _pair_array(self, params, array):
244         if not array.type.length_param_name:
245             return
246         target_name = array.type.length_param_name
247         for i, param in enumerate(params):
248             if param.name == array.type.length_param_name:
249                 array.type.length_param_index = i
250                 return
251         raise ValueError("Unmatched length parameter name %r"\
252                              % (target_name, ))
253
254     def _pair_annotations(self, params):
255         names = {}
256         for param in params:
257             if param.name in names:
258                 raise ValueError("Duplicate parameter name %r"\
259                                      % (param.name, ))
260             names[param.name] = 1
261             if isinstance(param.type, Array):
262                 self._pair_array(params, param)
263
264     # We take the annotations from the parser as strings; here we
265     # want to split them into components, so:
266     # (transfer full) -> {'transfer' : [ 'full' ]}
267
268     def _parse_options(self, options):
269         ret = {}
270         ws_re = re.compile(r'\s+')
271         for opt in options:
272             items = ws_re.split(opt)
273             ret[items[0]] = items[1:]
274         return ret
275
276     def _create_function(self, symbol):
277         directives = symbol.directives()
278         parameters = list(self._create_parameters(
279             symbol.base_type, directives))
280         self._pair_annotations(parameters)
281         return_ = self._create_return(symbol.base_type.base_type,
282                                       directives.get('return', {}))
283         name = self._strip_namespace_func(symbol.ident)
284         func = Function(name, return_, parameters, symbol.ident)
285         self._parse_deprecated(func, directives)
286         return func
287
288     def _create_source_type(self, source_type):
289         if source_type is None:
290             return 'None'
291         if source_type.type == CTYPE_VOID:
292             value = 'void'
293         elif source_type.type == CTYPE_BASIC_TYPE:
294             value = source_type.name
295         elif source_type.type == CTYPE_TYPEDEF:
296             value = source_type.name
297         elif source_type.type == CTYPE_ARRAY:
298             return self._create_source_type(source_type.base_type)
299         elif source_type.type == CTYPE_POINTER:
300             value = self._create_source_type(source_type.base_type) + '*'
301         else:
302             print 'TRANSFORMER: Unhandled source type %r' % (
303                 source_type, )
304             value = 'any'
305         return value
306
307     def _create_parameters(self, base_type, directives=None):
308         if directives is None:
309             dirs = {}
310         else:
311             dirs = directives
312         for child in base_type.child_list:
313             yield self._create_parameter(
314                 child, dirs.get(child.ident, {}))
315
316     def _create_member(self, symbol):
317         ctype = symbol.base_type.type
318         if (ctype == CTYPE_POINTER and
319             symbol.base_type.base_type.type == CTYPE_FUNCTION):
320             node = self._create_callback(symbol)
321         else:
322             ftype = self._create_type(symbol.base_type, {})
323             node = Field(symbol.ident, ftype, symbol.ident, symbol.const_int)
324         return node
325
326     def _create_typedef(self, symbol):
327         ctype = symbol.base_type.type
328         if (ctype == CTYPE_POINTER and
329             symbol.base_type.base_type.type == CTYPE_FUNCTION):
330             node = self._create_callback(symbol)
331         elif ctype == CTYPE_STRUCT:
332             node = self._create_typedef_struct(symbol)
333         elif ctype == CTYPE_UNION:
334             node = self._create_typedef_union(symbol)
335         elif ctype == CTYPE_ENUM:
336             return self._create_enum(symbol)
337         elif ctype in (CTYPE_TYPEDEF,
338                        CTYPE_POINTER,
339                        CTYPE_BASIC_TYPE,
340                        CTYPE_VOID):
341             name = self.strip_namespace_object(symbol.ident)
342             if symbol.base_type.name:
343                 target = self.strip_namespace_object(symbol.base_type.name)
344             else:
345                 target = 'none'
346             if name in type_names:
347                 return None
348             return Alias(name, target, ctype=symbol.ident)
349         else:
350             raise NotImplementedError(
351                 "symbol %r of type %s" % (symbol.ident, ctype_name(ctype)))
352         return node
353
354     def _parse_ctype(self, ctype):
355         canonical = type_name_from_ctype(ctype)
356         derefed = canonical.replace('*', '')
357         return derefed
358
359     def _create_type(self, source_type, options):
360         ctype = self._create_source_type(source_type)
361         if ctype == 'va_list':
362             raise SkipError()
363         # FIXME: FILE* should not be skipped, it should be handled
364         #        properly instead
365         elif ctype == 'FILE*':
366             raise SkipError
367         if ctype in self._list_ctypes:
368             param = options.get('element-type')
369             if param:
370                 contained_type = self._parse_ctype(param[0])
371             else:
372                 contained_type = None
373             return List(ctype.replace('*', ''),
374                         ctype,
375                         contained_type)
376         if ctype in self._map_ctypes:
377             param = options.get('element-type')
378             if param:
379                 key_type = self._parse_ctype(param[0])
380                 value_type = self._parse_ctype(param[1])
381             else:
382                 key_type = None
383                 value_type = None
384             return Map(ctype.replace('*', ''),
385                        ctype,
386                        key_type, value_type)
387         if (ctype in default_array_types) or ('array' in options):
388             derefed = ctype[:-1] # strip the *
389             result = Array(ctype,
390                          self._parse_ctype(derefed))
391             array_opts = options.get('array')
392             if array_opts:
393                 (_, len_name) = array_opts[0].split('=')
394                 result.length_param_name = len_name
395             return result
396         resolved_type_name = self._parse_ctype(ctype)
397
398         # string memory management
399         if type_name_from_ctype(ctype) == TYPE_STRING:
400             if source_type.base_type.type_qualifier & TYPE_QUALIFIER_CONST:
401                 options['transfer'] = ['none']
402             else:
403                 options['transfer'] = ['full']
404
405         return Type(resolved_type_name, ctype)
406
407     def _handle_generic_param_options(self, param, options):
408         for option, data in options.iteritems():
409             if option == 'transfer':
410                 if data:
411                     depth = data[0]
412                     if depth not in ('none', 'container', 'full'):
413                         raise ValueError("Invalid transfer %r" % (depth, ))
414                 else:
415                     depth = 'full'
416                 param.transfer = depth
417
418     def _create_parameter(self, symbol, options):
419         options = self._parse_options(options)
420         if symbol.type == CSYMBOL_TYPE_ELLIPSIS:
421             ptype = Varargs()
422         else:
423             ptype = self._create_type(symbol.base_type, options)
424         param = Parameter(symbol.ident, ptype)
425         for option, data in options.iteritems():
426             if option in ['in-out', 'inout']:
427                 param.direction = 'inout'
428             elif option == 'in':
429                 param.direction = 'in'
430             elif option == 'out':
431                 param.direction = 'out'
432             elif option == 'allow-none':
433                 param.allow_none = True
434             elif option.startswith(('element-type', 'array')):
435                 pass
436             elif option == 'transfer':
437                 pass
438             else:
439                 print 'Unhandled parameter annotation option: %r' % (
440                     option, )
441         self._handle_generic_param_options(param, options)
442         return param
443
444     def _create_return(self, source_type, options=None):
445         if options is None:
446             options_map = {}
447         else:
448             options_map = self._parse_options(options)
449         rtype = self._create_type(source_type, options_map)
450         rtype = self.resolve_param_type(rtype)
451         return_ = Return(rtype)
452         self._handle_generic_param_options(return_, options_map)
453         for option, data in options_map.iteritems():
454             if option == 'transfer':
455                 pass
456             else:
457                 print 'Unhandled return type annotation option: %r' % (
458                     option, )
459         return return_
460
461     def _create_const(self, symbol):
462         name = self._remove_prefix(symbol.ident)
463         name = self._strip_namespace_func(name)
464         if symbol.const_string is None:
465             type_name = 'int'
466             value = symbol.const_int
467         else:
468             type_name = 'utf8'
469             value = symbol.const_string
470         const = Constant(name, type_name, value)
471         return const
472
473     def _create_typedef_struct(self, symbol):
474         name = self.strip_namespace_object(symbol.ident)
475         struct = Struct(name, symbol.ident)
476         self._typedefs_ns[symbol.ident] = struct
477         self._create_struct(symbol)
478         return struct
479
480     def _create_typedef_union(self, symbol):
481         name = self._remove_prefix(symbol.ident)
482         name = self.strip_namespace_object(name)
483         union = Union(name, symbol.ident)
484         self._typedefs_ns[symbol.ident] = union
485         self._create_union(symbol)
486         return union
487
488     def _create_struct(self, symbol):
489         struct = self._typedefs_ns.get(symbol.ident, None)
490         if struct is None:
491             # This is a bit of a hack; really we should try
492             # to resolve through the typedefs to find the real
493             # name
494             if symbol.ident.startswith('_'):
495                 name = symbol.ident[1:]
496             else:
497                 name = symbol.ident
498             name = self.strip_namespace_object(name)
499             name = self.resolve_type_name(name)
500             struct = Struct(name, symbol.ident)
501
502         for child in symbol.base_type.child_list:
503             field = self._traverse_one(child)
504             if field:
505                 struct.fields.append(field)
506
507         return struct
508
509     def _create_union(self, symbol):
510         union = self._typedefs_ns.get(symbol.ident, None)
511         if union is None:
512             # This is a bit of a hack; really we should try
513             # to resolve through the typedefs to find the real
514             # name
515             if symbol.ident.startswith('_'):
516                 name = symbol.ident[1:]
517             else:
518                 name = symbol.ident
519             name = self.strip_namespace_object(name)
520             name = self.resolve_type_name(name)
521             union = Union(name, symbol.ident)
522
523         for child in symbol.base_type.child_list:
524             field = self._traverse_one(child)
525             if field:
526                 union.fields.append(field)
527
528         return union
529
530     def _create_callback(self, symbol):
531         parameters = self._create_parameters(symbol.base_type.base_type)
532         retval = self._create_return(symbol.base_type.base_type.base_type)
533         if symbol.ident.find('_') > 0:
534             name = self._strip_namespace_func(symbol.ident)
535         else:
536             name = self.strip_namespace_object(symbol.ident)
537         return Callback(name, retval, list(parameters), symbol.ident)
538
539     def _typepair_to_str(self, item):
540         nsname, item = item
541         if nsname is None:
542             return item.name
543         return '%s.%s' % (nsname, item.name)
544
545     def _resolve_type_name_1(self, type_name, ctype, names):
546         # First look using the built-in names
547         if ctype:
548             try:
549                 return type_names[ctype]
550             except KeyError, e:
551                 pass
552         try:
553             return type_names[type_name]
554         except KeyError, e:
555             pass
556         type_name = self.strip_namespace_object(type_name)
557         resolved = names.aliases.get(type_name)
558         if resolved:
559             return self._typepair_to_str(resolved)
560         resolved = names.names.get(type_name)
561         if resolved:
562             return self._typepair_to_str(resolved)
563         if ctype:
564             ctype = ctype.replace('*', '')
565             resolved = names.ctypes.get(ctype)
566             if resolved:
567                 return self._typepair_to_str(resolved)
568         resolved = names.type_names.get(type_name)
569         if resolved:
570             return self._typepair_to_str(resolved)
571         raise KeyError("failed to find %r" % (type_name, ))
572
573     def resolve_type_name_full(self, type_name, ctype,
574                                names):
575         try:
576             return self._resolve_type_name_1(type_name, ctype, names)
577         except KeyError, e:
578             try:
579                 return self._resolve_type_name_1(type_name, ctype, self._names)
580             except KeyError, e:
581                 return type_name
582
583     def resolve_type_name(self, type_name, ctype=None):
584         try:
585             return self.resolve_type_name_full(type_name, ctype, self._names)
586         except KeyError, e:
587             return type_name
588
589     def gtypename_to_giname(self, gtname, names):
590         resolved = names.type_names.get(gtname)
591         if resolved:
592             return self._typepair_to_str(resolved)
593         resolved = self._names.type_names.get(gtname)
594         if resolved:
595             return self._typepair_to_str(resolved)
596         raise KeyError("Failed to resolve GType name: %r" % (gtname, ))
597
598     def ctype_of(self, obj):
599         if hasattr(obj, 'ctype'):
600             return obj.ctype
601         elif hasattr(obj, 'symbol'):
602             return obj.symbol
603         else:
604             return None
605
606     def resolve_param_type_full(self, ptype, names):
607         if isinstance(ptype, Node):
608             ptype.name = self.resolve_type_name_full(ptype.name,
609                                                      self.ctype_of(ptype),
610                                                      names)
611             if isinstance(ptype, (Array, List)):
612                 if ptype.element_type is not None:
613                     ptype.element_type = \
614                         self.resolve_param_type_full(ptype.element_type, names)
615             if isinstance(ptype, Map):
616                 if ptype.key_type is not None:
617                     ptype.key_type = \
618                         self.resolve_param_type_full(ptype.key_type, names)
619                     ptype.value_type = \
620                         self.resolve_param_type_full(ptype.value_type, names)
621         elif isinstance(ptype, basestring):
622             return self.resolve_type_name_full(ptype, None, names)
623         else:
624             raise AssertionError("Unhandled param: %r" % (ptype, ))
625         return ptype
626
627     def resolve_param_type(self, ptype):
628         try:
629             return self.resolve_param_type_full(ptype, self._names)
630         except KeyError, e:
631             return ptype
632
633     def follow_aliases(self, type_name, names):
634         while True:
635             resolved = names.aliases.get(type_name)
636             if resolved:
637                 (ns, alias) = resolved
638                 type_name = alias.target
639             else:
640                 break
641         return type_name